diff --git a/cmd/agent/cmd.go b/cmd/agent/cmd.go deleted file mode 100644 index 1d75369..0000000 --- a/cmd/agent/cmd.go +++ /dev/null @@ -1,437 +0,0 @@ -package agent - -import ( - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - "time" - - "forge.lthn.ai/core/cli/pkg/cli" - agentic "forge.lthn.ai/core/agent/pkg/lifecycle" - coreerr "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-scm/agentci" - "forge.lthn.ai/core/config" -) - -func init() { - cli.RegisterCommands(AddAgentCommands) -} - -// Style aliases from shared package. -var ( - successStyle = cli.SuccessStyle - errorStyle = cli.ErrorStyle - dimStyle = cli.DimStyle - taskPriorityMediumStyle = cli.NewStyle().Foreground(cli.ColourAmber500) -) - -const defaultWorkDir = "ai-work" - -// AddAgentCommands registers the 'agent' subcommand group under 'ai'. -func AddAgentCommands(parent *cli.Command) { - agentCmd := &cli.Command{ - Use: "agent", - Short: "Manage AgentCI dispatch targets", - } - - agentCmd.AddCommand(agentAddCmd()) - agentCmd.AddCommand(agentListCmd()) - agentCmd.AddCommand(agentStatusCmd()) - agentCmd.AddCommand(agentLogsCmd()) - agentCmd.AddCommand(agentSetupCmd()) - agentCmd.AddCommand(agentRemoveCmd()) - agentCmd.AddCommand(agentFleetCmd()) - - parent.AddCommand(agentCmd) -} - -func loadConfig() (*config.Config, error) { - return config.New() -} - -func agentAddCmd() *cli.Command { - cmd := &cli.Command{ - Use: "add ", - Short: "Add an agent to the config and verify SSH", - Args: cli.ExactArgs(2), - RunE: func(cmd *cli.Command, args []string) error { - name := args[0] - host := args[1] - - forgejoUser, _ := cmd.Flags().GetString("forgejo-user") - if forgejoUser == "" { - forgejoUser = name - } - queueDir, _ := cmd.Flags().GetString("queue-dir") - if queueDir == "" { - queueDir = "" // resolved by orchestrator config - } - model, _ := cmd.Flags().GetString("model") - dualRun, _ := cmd.Flags().GetBool("dual-run") - - // Scan and add host key to known_hosts. - parts := strings.Split(host, "@") - hostname := parts[len(parts)-1] - - fmt.Printf("Scanning host key for %s... ", hostname) - scanCmd := exec.Command("ssh-keyscan", "-H", hostname) - keys, err := scanCmd.Output() - if err != nil { - fmt.Println(errorStyle.Render("FAILED")) - return coreerr.E("agent.add", "failed to scan host keys", err) - } - - home, _ := os.UserHomeDir() - knownHostsPath := filepath.Join(home, ".ssh", "known_hosts") - f, err := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { - return coreerr.E("agent.add", "failed to open known_hosts", err) - } - if _, err := f.Write(keys); err != nil { - f.Close() - return coreerr.E("agent.add", "failed to write known_hosts", err) - } - f.Close() - fmt.Println(successStyle.Render("OK")) - - // Test SSH with strict host key checking. - fmt.Printf("Testing SSH to %s... ", host) - testCmd := agentci.SecureSSHCommand(host, "echo ok") - out, err := testCmd.CombinedOutput() - if err != nil { - fmt.Println(errorStyle.Render("FAILED")) - return coreerr.E("agent.add", "SSH failed: "+strings.TrimSpace(string(out)), nil) - } - fmt.Println(successStyle.Render("OK")) - - cfg, err := loadConfig() - if err != nil { - return err - } - - ac := agentci.AgentConfig{ - Host: host, - QueueDir: queueDir, - ForgejoUser: forgejoUser, - Model: model, - DualRun: dualRun, - Active: true, - } - if err := agentci.SaveAgent(cfg, name, ac); err != nil { - return err - } - - fmt.Printf("Agent %s added (%s)\n", successStyle.Render(name), host) - return nil - }, - } - cmd.Flags().String("forgejo-user", "", "Forgejo username (defaults to agent name)") - cmd.Flags().String("queue-dir", "", "Remote queue directory (default: ~/.core/queue)") - cmd.Flags().String("model", "sonnet", "Primary AI model") - cmd.Flags().Bool("dual-run", false, "Enable Clotho dual-run verification") - return cmd -} - -func agentListCmd() *cli.Command { - return &cli.Command{ - Use: "list", - Short: "List configured agents", - RunE: func(cmd *cli.Command, args []string) error { - cfg, err := loadConfig() - if err != nil { - return err - } - - agents, err := agentci.ListAgents(cfg) - if err != nil { - return err - } - - if len(agents) == 0 { - fmt.Println(dimStyle.Render("No agents configured. Use 'core ai agent add' to add one.")) - return nil - } - - table := cli.NewTable("NAME", "HOST", "MODEL", "DUAL", "ACTIVE", "QUEUE") - for name, ac := range agents { - active := dimStyle.Render("no") - if ac.Active { - active = successStyle.Render("yes") - } - dual := dimStyle.Render("no") - if ac.DualRun { - dual = successStyle.Render("yes") - } - - // Quick SSH check for queue depth. - queue := dimStyle.Render("-") - checkCmd := agentci.SecureSSHCommand(ac.Host, fmt.Sprintf("ls %s/ticket-*.json 2>/dev/null | wc -l", ac.QueueDir)) - out, err := checkCmd.Output() - if err == nil { - n := strings.TrimSpace(string(out)) - if n != "0" { - queue = n - } else { - queue = "0" - } - } - - table.AddRow(name, ac.Host, ac.Model, dual, active, queue) - } - table.Render() - return nil - }, - } -} - -func agentStatusCmd() *cli.Command { - return &cli.Command{ - Use: "status ", - Short: "Check agent status via SSH", - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - name := args[0] - cfg, err := loadConfig() - if err != nil { - return err - } - - agents, err := agentci.ListAgents(cfg) - if err != nil { - return err - } - ac, ok := agents[name] - if !ok { - return coreerr.E("agent.status", "agent not found: "+name, nil) - } - - script := ` - echo "=== Queue ===" - ls ~/ai-work/queue/ticket-*.json 2>/dev/null | wc -l - echo "=== Active ===" - ls ~/ai-work/active/ticket-*.json 2>/dev/null || echo "none" - echo "=== Done ===" - ls ~/ai-work/done/ticket-*.json 2>/dev/null | wc -l - echo "=== Lock ===" - if [ -f ~/ai-work/.runner.lock ]; then - PID=$(cat ~/ai-work/.runner.lock) - if kill -0 "$PID" 2>/dev/null; then - echo "RUNNING (PID $PID)" - else - echo "STALE (PID $PID)" - fi - else - echo "IDLE" - fi - ` - - sshCmd := agentci.SecureSSHCommand(ac.Host, script) - sshCmd.Stdout = os.Stdout - sshCmd.Stderr = os.Stderr - return sshCmd.Run() - }, - } -} - -func agentLogsCmd() *cli.Command { - cmd := &cli.Command{ - Use: "logs ", - Short: "Stream agent runner logs", - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - name := args[0] - follow, _ := cmd.Flags().GetBool("follow") - lines, _ := cmd.Flags().GetInt("lines") - - cfg, err := loadConfig() - if err != nil { - return err - } - - agents, err := agentci.ListAgents(cfg) - if err != nil { - return err - } - ac, ok := agents[name] - if !ok { - return coreerr.E("agent.status", "agent not found: "+name, nil) - } - - remoteCmd := fmt.Sprintf("tail -n %d ~/ai-work/logs/runner.log", lines) - if follow { - remoteCmd = fmt.Sprintf("tail -f -n %d ~/ai-work/logs/runner.log", lines) - } - - sshCmd := agentci.SecureSSHCommand(ac.Host, remoteCmd) - sshCmd.Stdout = os.Stdout - sshCmd.Stderr = os.Stderr - sshCmd.Stdin = os.Stdin - return sshCmd.Run() - }, - } - cmd.Flags().BoolP("follow", "f", false, "Follow log output") - cmd.Flags().IntP("lines", "n", 50, "Number of lines to show") - return cmd -} - -func agentSetupCmd() *cli.Command { - return &cli.Command{ - Use: "setup ", - Short: "Bootstrap agent machine (create dirs, copy runner, install cron)", - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - name := args[0] - cfg, err := loadConfig() - if err != nil { - return err - } - - agents, err := agentci.ListAgents(cfg) - if err != nil { - return err - } - ac, ok := agents[name] - if !ok { - return coreerr.E("agent.setup", "agent not found: "+name+" — use 'core ai agent add' first", nil) - } - - // Find the setup script relative to the binary or in known locations. - scriptPath := findSetupScript() - if scriptPath == "" { - return coreerr.E("agent.setup", "agent-setup.sh not found — expected in scripts/ directory", nil) - } - - fmt.Printf("Setting up %s on %s...\n", name, ac.Host) - setupCmd := exec.Command("bash", scriptPath, ac.Host) - setupCmd.Stdout = os.Stdout - setupCmd.Stderr = os.Stderr - if err := setupCmd.Run(); err != nil { - return coreerr.E("agent.setup", "setup failed", err) - } - - fmt.Println(successStyle.Render("Setup complete!")) - return nil - }, - } -} - -func agentRemoveCmd() *cli.Command { - return &cli.Command{ - Use: "remove ", - Short: "Remove an agent from config", - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - name := args[0] - cfg, err := loadConfig() - if err != nil { - return err - } - - if err := agentci.RemoveAgent(cfg, name); err != nil { - return err - } - - fmt.Printf("Agent %s removed.\n", name) - return nil - }, - } -} - -func agentFleetCmd() *cli.Command { - cmd := &cli.Command{ - Use: "fleet", - Short: "Show fleet status from the go-agentic registry", - RunE: func(cmd *cli.Command, args []string) error { - workDir, _ := cmd.Flags().GetString("work-dir") - if workDir == "" { - home, _ := os.UserHomeDir() - workDir = filepath.Join(home, defaultWorkDir) - } - dbPath := filepath.Join(workDir, "registry.db") - - if _, err := os.Stat(dbPath); os.IsNotExist(err) { - fmt.Println(dimStyle.Render("No registry found. Start a dispatch watcher first: core ai dispatch watch")) - return nil - } - - registry, err := agentic.NewSQLiteRegistry(dbPath) - if err != nil { - return coreerr.E("agent.fleet", "failed to open registry", err) - } - defer registry.Close() - - // Reap stale agents (no heartbeat for 10 minutes). - reaped := registry.Reap(10 * time.Minute) - if len(reaped) > 0 { - for _, id := range reaped { - fmt.Printf(" Reaped stale agent: %s\n", dimStyle.Render(id)) - } - fmt.Println() - } - - agents := registry.List() - if len(agents) == 0 { - fmt.Println(dimStyle.Render("No agents registered.")) - return nil - } - - table := cli.NewTable("ID", "STATUS", "LOAD", "LAST HEARTBEAT", "CAPABILITIES") - for _, a := range agents { - status := dimStyle.Render(string(a.Status)) - switch a.Status { - case agentic.AgentAvailable: - status = successStyle.Render("available") - case agentic.AgentBusy: - status = taskPriorityMediumStyle.Render("busy") - case agentic.AgentOffline: - status = errorStyle.Render("offline") - } - - load := fmt.Sprintf("%d/%d", a.CurrentLoad, a.MaxLoad) - hb := a.LastHeartbeat.Format("15:04:05") - ago := time.Since(a.LastHeartbeat).Truncate(time.Second) - hbStr := fmt.Sprintf("%s (%s ago)", hb, ago) - - caps := "-" - if len(a.Capabilities) > 0 { - caps = strings.Join(a.Capabilities, ", ") - } - - table.AddRow(a.ID, status, load, hbStr, caps) - } - table.Render() - return nil - }, - } - cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)") - return cmd -} - -// findSetupScript looks for agent-setup.sh in common locations. -func findSetupScript() string { - exe, _ := os.Executable() - if exe != "" { - dir := filepath.Dir(exe) - candidates := []string{ - filepath.Join(dir, "scripts", "agent-setup.sh"), - filepath.Join(dir, "..", "scripts", "agent-setup.sh"), - } - for _, c := range candidates { - if _, err := os.Stat(c); err == nil { - return c - } - } - } - - cwd, _ := os.Getwd() - if cwd != "" { - p := filepath.Join(cwd, "scripts", "agent-setup.sh") - if _, err := os.Stat(p); err == nil { - return p - } - } - - return "" -} diff --git a/cmd/dispatch/cmd.go b/cmd/dispatch/cmd.go deleted file mode 100644 index 2f15942..0000000 --- a/cmd/dispatch/cmd.go +++ /dev/null @@ -1,877 +0,0 @@ -package dispatch - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "os" - "os/exec" - "os/signal" - "path/filepath" - "slices" - "strconv" - "strings" - "syscall" - "time" - - "forge.lthn.ai/core/cli/pkg/cli" - coreio "forge.lthn.ai/core/go-io" - "forge.lthn.ai/core/go-log" - - agentic "forge.lthn.ai/core/agent/pkg/lifecycle" -) - -func init() { - cli.RegisterCommands(AddDispatchCommands) -} - -// AddDispatchCommands registers the 'dispatch' subcommand group under 'ai'. -// These commands run ON the agent machine to process the work queue. -func AddDispatchCommands(parent *cli.Command) { - dispatchCmd := &cli.Command{ - Use: "dispatch", - Short: "Agent work queue processor (runs on agent machine)", - } - - dispatchCmd.AddCommand(dispatchRunCmd()) - dispatchCmd.AddCommand(dispatchWatchCmd()) - dispatchCmd.AddCommand(dispatchStatusCmd()) - - parent.AddCommand(dispatchCmd) -} - -// dispatchTicket represents the work item JSON structure. -type dispatchTicket struct { - ID string `json:"id"` - RepoOwner string `json:"repo_owner"` - RepoName string `json:"repo_name"` - IssueNumber int `json:"issue_number"` - IssueTitle string `json:"issue_title"` - IssueBody string `json:"issue_body"` - TargetBranch string `json:"target_branch"` - EpicNumber int `json:"epic_number"` - ForgeURL string `json:"forge_url"` - ForgeToken string `json:"forge_token"` - ForgeUser string `json:"forgejo_user"` - Model string `json:"model"` - Runner string `json:"runner"` - Timeout string `json:"timeout"` - CreatedAt string `json:"created_at"` -} - -const ( - defaultWorkDir = "ai-work" - lockFileName = ".runner.lock" -) - -type runnerPaths struct { - root string - queue string - active string - done string - logs string - jobs string - lock string -} - -func getPaths(baseDir string) runnerPaths { - if baseDir == "" { - home, _ := os.UserHomeDir() - baseDir = filepath.Join(home, defaultWorkDir) - } - return runnerPaths{ - root: baseDir, - queue: filepath.Join(baseDir, "queue"), - active: filepath.Join(baseDir, "active"), - done: filepath.Join(baseDir, "done"), - logs: filepath.Join(baseDir, "logs"), - jobs: filepath.Join(baseDir, "jobs"), - lock: filepath.Join(baseDir, lockFileName), - } -} - -func dispatchRunCmd() *cli.Command { - cmd := &cli.Command{ - Use: "run", - Short: "Process a single ticket from the queue", - RunE: func(cmd *cli.Command, args []string) error { - workDir, _ := cmd.Flags().GetString("work-dir") - paths := getPaths(workDir) - - if err := ensureDispatchDirs(paths); err != nil { - return err - } - - if err := acquireLock(paths.lock); err != nil { - log.Info("Runner locked, skipping run", "lock", paths.lock) - return nil - } - defer releaseLock(paths.lock) - - ticketFile, err := pickOldestTicket(paths.queue) - if err != nil { - return err - } - if ticketFile == "" { - return nil - } - - _, err = processTicket(paths, ticketFile) - return err - }, - } - cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)") - return cmd -} - -// fastFailThreshold is how quickly a job must fail to be considered rate-limited. -// Real work always takes longer than 30 seconds; a 3-second exit means the CLI -// was rejected before it could start (rate limit, auth error, etc.). -const fastFailThreshold = 30 * time.Second - -// maxBackoffMultiplier caps the exponential backoff at 8x the base interval. -const maxBackoffMultiplier = 8 - -func dispatchWatchCmd() *cli.Command { - cmd := &cli.Command{ - Use: "watch", - Short: "Poll the PHP agentic API for work", - RunE: func(cmd *cli.Command, args []string) error { - workDir, _ := cmd.Flags().GetString("work-dir") - interval, _ := cmd.Flags().GetDuration("interval") - agentID, _ := cmd.Flags().GetString("agent-id") - agentType, _ := cmd.Flags().GetString("agent-type") - apiURL, _ := cmd.Flags().GetString("api-url") - apiKey, _ := cmd.Flags().GetString("api-key") - - paths := getPaths(workDir) - if err := ensureDispatchDirs(paths); err != nil { - return err - } - - // Create the go-agentic API client. - client := agentic.NewClient(apiURL, apiKey) - client.AgentID = agentID - - // Verify connectivity. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - if err := client.Ping(ctx); err != nil { - return log.E("dispatch.watch", "API ping failed (url="+apiURL+")", err) - } - log.Info("Connected to agentic API", "url", apiURL, "agent", agentID) - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - - // Backoff state. - backoffMultiplier := 1 - currentInterval := interval - - ticker := time.NewTicker(currentInterval) - defer ticker.Stop() - - adjustTicker := func(fastFail bool) { - if fastFail { - if backoffMultiplier < maxBackoffMultiplier { - backoffMultiplier *= 2 - } - currentInterval = interval * time.Duration(backoffMultiplier) - log.Warn("Fast failure detected, backing off", - "multiplier", backoffMultiplier, "next_poll", currentInterval) - } else { - if backoffMultiplier > 1 { - log.Info("Job succeeded, resetting backoff") - } - backoffMultiplier = 1 - currentInterval = interval - } - ticker.Reset(currentInterval) - } - - log.Info("Starting API poller", "interval", interval, "agent", agentID, "type", agentType) - - // Initial poll. - ff := pollAndExecute(ctx, client, agentID, agentType, paths) - adjustTicker(ff) - - for { - select { - case <-ticker.C: - ff := pollAndExecute(ctx, client, agentID, agentType, paths) - adjustTicker(ff) - case <-sigChan: - log.Info("Shutting down watcher...") - return nil - case <-ctx.Done(): - return nil - } - } - }, - } - cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)") - cmd.Flags().Duration("interval", 2*time.Minute, "Polling interval") - cmd.Flags().String("agent-id", defaultAgentID(), "Agent identifier") - cmd.Flags().String("agent-type", "opus", "Agent type (opus, sonnet, gemini)") - cmd.Flags().String("api-url", "https://api.lthn.sh", "Agentic API base URL") - cmd.Flags().String("api-key", os.Getenv("AGENTIC_API_KEY"), "Agentic API key") - return cmd -} - -// pollAndExecute checks the API for workable plans and executes one phase per cycle. -// Returns true if a fast failure occurred (signals backoff). -func pollAndExecute(ctx context.Context, client *agentic.Client, agentID, agentType string, paths runnerPaths) bool { - // List active plans. - plans, err := client.ListPlans(ctx, agentic.ListPlanOptions{Status: agentic.PlanActive}) - if err != nil { - log.Error("Failed to list plans", "error", err) - return false - } - if len(plans) == 0 { - log.Debug("No active plans") - return false - } - - // Find the first workable phase across all plans. - for _, plan := range plans { - // Fetch full plan with phases. - fullPlan, err := client.GetPlan(ctx, plan.Slug) - if err != nil { - log.Error("Failed to get plan", "slug", plan.Slug, "error", err) - continue - } - - // Find first workable phase. - var targetPhase *agentic.Phase - for i := range fullPlan.Phases { - p := &fullPlan.Phases[i] - switch p.Status { - case agentic.PhaseInProgress: - targetPhase = p - case agentic.PhasePending: - if p.CanStart { - targetPhase = p - } - } - if targetPhase != nil { - break - } - } - if targetPhase == nil { - continue - } - - log.Info("Found workable phase", - "plan", fullPlan.Slug, "phase", targetPhase.Name, "status", targetPhase.Status) - - // Start session. - session, err := client.StartSession(ctx, agentic.StartSessionRequest{ - AgentType: agentType, - PlanSlug: fullPlan.Slug, - Context: map[string]any{ - "agent_id": agentID, - "phase": targetPhase.Name, - }, - }) - if err != nil { - log.Error("Failed to start session", "error", err) - return false - } - log.Info("Session started", "session_id", session.SessionID) - - // Mark phase in-progress if pending. - if targetPhase.Status == agentic.PhasePending { - if err := client.UpdatePhaseStatus(ctx, fullPlan.Slug, targetPhase.Name, agentic.PhaseInProgress, ""); err != nil { - log.Warn("Failed to mark phase in-progress", "error", err) - } - } - - // Extract repo info from plan metadata. - fastFail := executePhaseWork(ctx, client, fullPlan, targetPhase, session.SessionID, paths) - return fastFail - } - - log.Debug("No workable phases found across active plans") - return false -} - -// executePhaseWork does the actual repo prep + agent run for a phase. -// Returns true if the execution was a fast failure. -func executePhaseWork(ctx context.Context, client *agentic.Client, plan *agentic.Plan, phase *agentic.Phase, sessionID string, paths runnerPaths) bool { - // Extract repo metadata from the plan. - meta, _ := plan.Metadata.(map[string]any) - repoOwner, _ := meta["repo_owner"].(string) - repoName, _ := meta["repo_name"].(string) - issueNumFloat, _ := meta["issue_number"].(float64) // JSON numbers are float64 - issueNumber := int(issueNumFloat) - - forgeURL, _ := meta["forge_url"].(string) - forgeToken, _ := meta["forge_token"].(string) - forgeUser, _ := meta["forgejo_user"].(string) - targetBranch, _ := meta["target_branch"].(string) - runner, _ := meta["runner"].(string) - model, _ := meta["model"].(string) - timeout, _ := meta["timeout"].(string) - - if targetBranch == "" { - targetBranch = "main" - } - if runner == "" { - runner = "claude" - } - - // Build a dispatchTicket from the metadata so existing functions work. - t := dispatchTicket{ - ID: fmt.Sprintf("%s-%s", plan.Slug, phase.Name), - RepoOwner: repoOwner, - RepoName: repoName, - IssueNumber: issueNumber, - IssueTitle: plan.Title, - IssueBody: phase.Description, - TargetBranch: targetBranch, - ForgeURL: forgeURL, - ForgeToken: forgeToken, - ForgeUser: forgeUser, - Model: model, - Runner: runner, - Timeout: timeout, - } - - if t.RepoOwner == "" || t.RepoName == "" { - log.Error("Plan metadata missing repo_owner or repo_name", "plan", plan.Slug) - _ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), "missing repo metadata") - return false - } - - // Prepare the repository. - jobDir := filepath.Join(paths.jobs, fmt.Sprintf("%s-%s-%d", t.RepoOwner, t.RepoName, t.IssueNumber)) - repoDir := filepath.Join(jobDir, t.RepoName) - if err := coreio.Local.EnsureDir(jobDir); err != nil { - log.Error("Failed to create job dir", "error", err) - _ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), fmt.Sprintf("mkdir failed: %v", err)) - return false - } - - if err := prepareRepo(t, repoDir); err != nil { - log.Error("Repo preparation failed", "error", err) - _ = client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, agentic.PhaseBlocked, fmt.Sprintf("git setup failed: %v", err)) - _ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), fmt.Sprintf("repo prep failed: %v", err)) - return false - } - - // Build prompt and run. - prompt := buildPrompt(t) - logFile := filepath.Join(paths.logs, fmt.Sprintf("%s-%s.log", plan.Slug, phase.Name)) - - start := time.Now() - success, exitCode, runErr := runAgent(t, prompt, repoDir, logFile) - elapsed := time.Since(start) - - // Detect fast failure. - if !success && elapsed < fastFailThreshold { - log.Warn("Agent rejected fast, likely rate-limited", - "elapsed", elapsed.Round(time.Second), "plan", plan.Slug, "phase", phase.Name) - _ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), "fast failure — likely rate-limited") - return true - } - - // Report results. - if success { - _ = client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, agentic.PhaseCompleted, - fmt.Sprintf("completed in %s", elapsed.Round(time.Second))) - _ = client.EndSession(ctx, sessionID, string(agentic.SessionCompleted), - fmt.Sprintf("Phase %q completed successfully (exit %d, %s)", phase.Name, exitCode, elapsed.Round(time.Second))) - } else { - note := fmt.Sprintf("failed with exit code %d after %s", exitCode, elapsed.Round(time.Second)) - if runErr != nil { - note += fmt.Sprintf(": %v", runErr) - } - _ = client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, agentic.PhaseBlocked, note) - _ = client.EndSession(ctx, sessionID, string(agentic.SessionFailed), note) - } - - // Also report to Forge issue if configured. - msg := fmt.Sprintf("Agent completed phase %q of plan %q. Exit code: %d.", phase.Name, plan.Slug, exitCode) - if !success { - msg = fmt.Sprintf("Agent failed phase %q of plan %q (exit code: %d).", phase.Name, plan.Slug, exitCode) - } - reportToForge(t, success, msg) - - log.Info("Phase complete", "plan", plan.Slug, "phase", phase.Name, "success", success, "elapsed", elapsed.Round(time.Second)) - return false -} - -// defaultAgentID returns a sensible agent ID from hostname. -func defaultAgentID() string { - host, _ := os.Hostname() - if host == "" { - return "unknown" - } - return host -} - -// --- Legacy registry/heartbeat functions (replaced by PHP API poller) --- - -// registerAgent creates a SQLite registry and registers this agent. -// DEPRECATED: The watch command now uses the PHP agentic API instead. -// Kept for reference; remove once the API poller is proven stable. -/* -func registerAgent(agentID string, paths runnerPaths) (agentic.AgentRegistry, agentic.EventEmitter, func()) { - dbPath := filepath.Join(paths.root, "registry.db") - registry, err := agentic.NewSQLiteRegistry(dbPath) - if err != nil { - log.Warn("Failed to create agent registry", "error", err, "path", dbPath) - return nil, nil, nil - } - - info := agentic.AgentInfo{ - ID: agentID, - Name: agentID, - Status: agentic.AgentAvailable, - LastHeartbeat: time.Now().UTC(), - MaxLoad: 1, - } - if err := registry.Register(info); err != nil { - log.Warn("Failed to register agent", "error", err) - } else { - log.Info("Agent registered", "id", agentID) - } - - events := agentic.NewChannelEmitter(64) - - // Drain events to log. - go func() { - for ev := range events.Events() { - log.Debug("Event", "type", string(ev.Type), "task", ev.TaskID, "agent", ev.AgentID) - } - }() - - return registry, events, func() { - events.Close() - } -} -*/ - -// heartbeatLoop sends periodic heartbeats to keep the agent status fresh. -// DEPRECATED: Replaced by PHP API poller. -/* -func heartbeatLoop(ctx context.Context, registry agentic.AgentRegistry, agentID string, interval time.Duration) { - if interval < 30*time.Second { - interval = 30 * time.Second - } - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - _ = registry.Heartbeat(agentID) - } - } -} -*/ - -// runCycleWithEvents wraps runCycle with registry status updates and event emission. -// DEPRECATED: Replaced by pollAndExecute. -/* -func runCycleWithEvents(paths runnerPaths, registry agentic.AgentRegistry, events agentic.EventEmitter, agentID string) bool { - if registry != nil { - if agent, err := registry.Get(agentID); err == nil { - agent.Status = agentic.AgentBusy - _ = registry.Register(agent) - } - } - - fastFail := runCycle(paths) - - if registry != nil { - if agent, err := registry.Get(agentID); err == nil { - agent.Status = agentic.AgentAvailable - agent.LastHeartbeat = time.Now().UTC() - _ = registry.Register(agent) - } - } - return fastFail -} -*/ - -func dispatchStatusCmd() *cli.Command { - cmd := &cli.Command{ - Use: "status", - Short: "Show runner status", - RunE: func(cmd *cli.Command, args []string) error { - workDir, _ := cmd.Flags().GetString("work-dir") - paths := getPaths(workDir) - - lockStatus := "IDLE" - if data, err := coreio.Local.Read(paths.lock); err == nil { - pidStr := strings.TrimSpace(data) - pid, _ := strconv.Atoi(pidStr) - if isProcessAlive(pid) { - lockStatus = fmt.Sprintf("RUNNING (PID %d)", pid) - } else { - lockStatus = fmt.Sprintf("STALE (PID %d)", pid) - } - } - - countFiles := func(dir string) int { - entries, _ := os.ReadDir(dir) - count := 0 - for _, e := range entries { - if !e.IsDir() && strings.HasPrefix(e.Name(), "ticket-") { - count++ - } - } - return count - } - - fmt.Println("=== Agent Dispatch Status ===") - fmt.Printf("Work Dir: %s\n", paths.root) - fmt.Printf("Status: %s\n", lockStatus) - fmt.Printf("Queue: %d\n", countFiles(paths.queue)) - fmt.Printf("Active: %d\n", countFiles(paths.active)) - fmt.Printf("Done: %d\n", countFiles(paths.done)) - - return nil - }, - } - cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)") - return cmd -} - -// runCycle picks and processes one ticket. Returns true if the job fast-failed -// (likely rate-limited), signalling the caller to back off. -func runCycle(paths runnerPaths) bool { - if err := acquireLock(paths.lock); err != nil { - log.Debug("Runner locked, skipping cycle") - return false - } - defer releaseLock(paths.lock) - - ticketFile, err := pickOldestTicket(paths.queue) - if err != nil { - log.Error("Failed to pick ticket", "error", err) - return false - } - if ticketFile == "" { - return false // empty queue, no backoff needed - } - - start := time.Now() - success, err := processTicket(paths, ticketFile) - elapsed := time.Since(start) - - if err != nil { - log.Error("Failed to process ticket", "file", ticketFile, "error", err) - } - - // Detect fast failure: job failed in under 30s → likely rate-limited. - if !success && elapsed < fastFailThreshold { - log.Warn("Job finished too fast, likely rate-limited", - "elapsed", elapsed.Round(time.Second), "file", filepath.Base(ticketFile)) - return true - } - - return false -} - -// processTicket processes a single ticket. Returns (success, error). -// On fast failure the caller is responsible for detecting the timing and backing off. -// The ticket is moved active→done on completion, or active→queue on fast failure. -func processTicket(paths runnerPaths, ticketPath string) (bool, error) { - fileName := filepath.Base(ticketPath) - log.Info("Processing ticket", "file", fileName) - - activePath := filepath.Join(paths.active, fileName) - if err := os.Rename(ticketPath, activePath); err != nil { - return false, log.E("processTicket", "failed to move ticket to active", err) - } - - data, err := coreio.Local.Read(activePath) - if err != nil { - return false, log.E("processTicket", "failed to read ticket", err) - } - var t dispatchTicket - if err := json.Unmarshal([]byte(data), &t); err != nil { - return false, log.E("processTicket", "failed to unmarshal ticket", err) - } - - jobDir := filepath.Join(paths.jobs, fmt.Sprintf("%s-%s-%d", t.RepoOwner, t.RepoName, t.IssueNumber)) - repoDir := filepath.Join(jobDir, t.RepoName) - if err := coreio.Local.EnsureDir(jobDir); err != nil { - return false, err - } - - if err := prepareRepo(t, repoDir); err != nil { - reportToForge(t, false, fmt.Sprintf("Git setup failed: %v", err)) - moveToDone(paths, activePath, fileName) - return false, err - } - - prompt := buildPrompt(t) - - logFile := filepath.Join(paths.logs, fmt.Sprintf("%s-%s-%d.log", t.RepoOwner, t.RepoName, t.IssueNumber)) - start := time.Now() - success, exitCode, runErr := runAgent(t, prompt, repoDir, logFile) - elapsed := time.Since(start) - - // Fast failure: agent exited in <30s without success → likely rate-limited. - // Requeue the ticket so it's retried after the backoff period. - if !success && elapsed < fastFailThreshold { - log.Warn("Agent rejected fast, requeuing ticket", "elapsed", elapsed.Round(time.Second), "file", fileName) - requeuePath := filepath.Join(paths.queue, fileName) - if err := os.Rename(activePath, requeuePath); err != nil { - // Fallback: move to done if requeue fails. - moveToDone(paths, activePath, fileName) - } - return false, runErr - } - - msg := fmt.Sprintf("Agent completed work on #%d. Exit code: %d.", t.IssueNumber, exitCode) - if !success { - msg = fmt.Sprintf("Agent failed on #%d (exit code: %d). Check logs on agent machine.", t.IssueNumber, exitCode) - if runErr != nil { - msg += fmt.Sprintf(" Error: %v", runErr) - } - } - reportToForge(t, success, msg) - - moveToDone(paths, activePath, fileName) - log.Info("Ticket complete", "id", t.ID, "success", success, "elapsed", elapsed.Round(time.Second)) - return success, nil -} - -func prepareRepo(t dispatchTicket, repoDir string) error { - user := t.ForgeUser - if user == "" { - host, _ := os.Hostname() - user = fmt.Sprintf("%s-%s", host, os.Getenv("USER")) - } - - cleanURL := strings.TrimPrefix(t.ForgeURL, "https://") - cleanURL = strings.TrimPrefix(cleanURL, "http://") - cloneURL := fmt.Sprintf("https://%s:%s@%s/%s/%s.git", user, t.ForgeToken, cleanURL, t.RepoOwner, t.RepoName) - - if _, err := os.Stat(filepath.Join(repoDir, ".git")); err == nil { - log.Info("Updating existing repo", "dir", repoDir) - cmds := [][]string{ - {"git", "fetch", "origin"}, - {"git", "checkout", t.TargetBranch}, - {"git", "pull", "origin", t.TargetBranch}, - } - for _, args := range cmds { - cmd := exec.Command(args[0], args[1:]...) - cmd.Dir = repoDir - if out, err := cmd.CombinedOutput(); err != nil { - if args[1] == "checkout" { - createCmd := exec.Command("git", "checkout", "-b", t.TargetBranch, "origin/"+t.TargetBranch) - createCmd.Dir = repoDir - if _, err2 := createCmd.CombinedOutput(); err2 == nil { - continue - } - } - return log.E("prepareRepo", "git command failed: "+string(out), err) - } - } - } else { - log.Info("Cloning repo", "url", t.RepoOwner+"/"+t.RepoName) - cmd := exec.Command("git", "clone", "-b", t.TargetBranch, cloneURL, repoDir) - if out, err := cmd.CombinedOutput(); err != nil { - return log.E("prepareRepo", "git clone failed: "+string(out), err) - } - } - return nil -} - -func buildPrompt(t dispatchTicket) string { - return fmt.Sprintf(`You are working on issue #%d in %s/%s. - -Title: %s - -Description: -%s - -The repo is cloned at the current directory on branch '%s'. -Create a feature branch from '%s', make minimal targeted changes, commit referencing #%d, and push. -Then create a PR targeting '%s' using the forgejo MCP tools or git push.`, - t.IssueNumber, t.RepoOwner, t.RepoName, - t.IssueTitle, - t.IssueBody, - t.TargetBranch, - t.TargetBranch, t.IssueNumber, - t.TargetBranch, - ) -} - -func runAgent(t dispatchTicket, prompt, dir, logPath string) (bool, int, error) { - timeout := 30 * time.Minute - if t.Timeout != "" { - if d, err := time.ParseDuration(t.Timeout); err == nil { - timeout = d - } - } - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - model := t.Model - if model == "" { - model = "sonnet" - } - - log.Info("Running agent", "runner", t.Runner, "model", model) - - // For Gemini runner, wrap with rate limiting. - if t.Runner == "gemini" { - return executeWithRateLimit(ctx, model, prompt, func() (bool, int, error) { - return execAgent(ctx, t.Runner, model, prompt, dir, logPath) - }) - } - - return execAgent(ctx, t.Runner, model, prompt, dir, logPath) -} - -func execAgent(ctx context.Context, runner, model, prompt, dir, logPath string) (bool, int, error) { - var cmd *exec.Cmd - - switch runner { - case "codex": - cmd = exec.CommandContext(ctx, "codex", "exec", "--full-auto", prompt) - case "gemini": - args := []string{"-p", "-", "-y", "-m", model} - cmd = exec.CommandContext(ctx, "gemini", args...) - cmd.Stdin = strings.NewReader(prompt) - default: // claude - cmd = exec.CommandContext(ctx, "claude", "-p", "--model", model, "--dangerously-skip-permissions", "--output-format", "text") - cmd.Stdin = strings.NewReader(prompt) - } - - cmd.Dir = dir - - f, err := os.Create(logPath) - if err != nil { - return false, -1, err - } - defer f.Close() - - cmd.Stdout = f - cmd.Stderr = f - - if err := cmd.Run(); err != nil { - exitCode := -1 - if exitErr, ok := err.(*exec.ExitError); ok { - exitCode = exitErr.ExitCode() - } - return false, exitCode, err - } - - return true, 0, nil -} - -func reportToForge(t dispatchTicket, success bool, body string) { - token := t.ForgeToken - if token == "" { - token = os.Getenv("FORGE_TOKEN") - } - if token == "" { - log.Warn("No forge token available, skipping report") - return - } - - url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d/comments", - strings.TrimSuffix(t.ForgeURL, "/"), t.RepoOwner, t.RepoName, t.IssueNumber) - - payload := map[string]string{"body": body} - jsonBody, _ := json.Marshal(payload) - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody)) - if err != nil { - log.Error("Failed to create request", "err", err) - return - } - req.Header.Set("Authorization", "token "+token) - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - log.Error("Failed to report to Forge", "err", err) - return - } - defer resp.Body.Close() - - if resp.StatusCode >= 300 { - log.Warn("Forge reported error", "status", resp.Status) - } -} - -func moveToDone(paths runnerPaths, activePath, fileName string) { - donePath := filepath.Join(paths.done, fileName) - if err := os.Rename(activePath, donePath); err != nil { - log.Error("Failed to move ticket to done", "err", err) - } -} - -func ensureDispatchDirs(p runnerPaths) error { - dirs := []string{p.queue, p.active, p.done, p.logs, p.jobs} - for _, d := range dirs { - if err := coreio.Local.EnsureDir(d); err != nil { - return log.E("ensureDispatchDirs", "mkdir "+d+" failed", err) - } - } - return nil -} - -func acquireLock(lockPath string) error { - if data, err := coreio.Local.Read(lockPath); err == nil { - pidStr := strings.TrimSpace(data) - pid, _ := strconv.Atoi(pidStr) - if isProcessAlive(pid) { - return log.E("acquireLock", fmt.Sprintf("locked by PID %d", pid), nil) - } - log.Info("Removing stale lock", "pid", pid) - _ = coreio.Local.Delete(lockPath) - } - - return coreio.Local.Write(lockPath, fmt.Sprintf("%d", os.Getpid())) -} - -func releaseLock(lockPath string) { - _ = coreio.Local.Delete(lockPath) -} - -func isProcessAlive(pid int) bool { - if pid <= 0 { - return false - } - process, err := os.FindProcess(pid) - if err != nil { - return false - } - return process.Signal(syscall.Signal(0)) == nil -} - -func pickOldestTicket(queueDir string) (string, error) { - entries, err := os.ReadDir(queueDir) - if err != nil { - return "", err - } - - var tickets []string - for _, e := range entries { - if !e.IsDir() && strings.HasPrefix(e.Name(), "ticket-") && strings.HasSuffix(e.Name(), ".json") { - tickets = append(tickets, filepath.Join(queueDir, e.Name())) - } - } - - if len(tickets) == 0 { - return "", nil - } - - slices.Sort(tickets) - return tickets[0], nil -} diff --git a/cmd/dispatch/ratelimit.go b/cmd/dispatch/ratelimit.go deleted file mode 100644 index 0eabcc4..0000000 --- a/cmd/dispatch/ratelimit.go +++ /dev/null @@ -1,46 +0,0 @@ -package dispatch - -import ( - "context" - - "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-ratelimit" -) - -// executeWithRateLimit wraps an agent execution with rate limiting logic. -// It estimates token usage, waits for capacity, executes the runner, and records usage. -func executeWithRateLimit(ctx context.Context, model, prompt string, runner func() (bool, int, error)) (bool, int, error) { - rl, err := ratelimit.New() - if err != nil { - log.Warn("Failed to initialize rate limiter, proceeding without limits", "error", err) - return runner() - } - - if err := rl.Load(); err != nil { - log.Warn("Failed to load rate limit state", "error", err) - } - - // Estimate tokens from prompt length (1 token ≈ 4 chars) - estTokens := len(prompt) / 4 - if estTokens == 0 { - estTokens = 1 - } - - log.Info("Checking rate limits", "model", model, "est_tokens", estTokens) - - if err := rl.WaitForCapacity(ctx, model, estTokens); err != nil { - return false, -1, err - } - - success, exitCode, runErr := runner() - - // Record usage with conservative output estimate (actual tokens unknown from shell runner). - outputEst := max(estTokens/10, 50) - rl.RecordUsage(model, estTokens, outputEst) - - if err := rl.Persist(); err != nil { - log.Warn("Failed to persist rate limit state", "error", err) - } - - return success, exitCode, runErr -} diff --git a/cmd/taskgit/cmd.go b/cmd/taskgit/cmd.go deleted file mode 100644 index 9354569..0000000 --- a/cmd/taskgit/cmd.go +++ /dev/null @@ -1,256 +0,0 @@ -// Package taskgit implements git integration commands for task commits and PRs. - -package taskgit - -import ( - "bytes" - "context" - "os" - "os/exec" - "strings" - "time" - - agentic "forge.lthn.ai/core/agent/pkg/lifecycle" - "forge.lthn.ai/core/cli/pkg/cli" - "forge.lthn.ai/core/go-i18n" -) - -func init() { - cli.RegisterCommands(AddTaskGitCommands) -} - -// Style aliases from shared package. -var ( - successStyle = cli.SuccessStyle - dimStyle = cli.DimStyle -) - -// task:commit command flags -var ( - taskCommitMessage string - taskCommitScope string - taskCommitPush bool -) - -// task:pr command flags -var ( - taskPRTitle string - taskPRDraft bool - taskPRLabels string - taskPRBase string -) - -var taskCommitCmd = &cli.Command{ - Use: "task:commit [task-id]", - Short: i18n.T("cmd.ai.task_commit.short"), - Long: i18n.T("cmd.ai.task_commit.long"), - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - taskID := args[0] - - if taskCommitMessage == "" { - return cli.Err("commit message required") - } - - cfg, err := agentic.LoadConfig("") - if err != nil { - return cli.WrapVerb(err, "load", "config") - } - - client := agentic.NewClientFromConfig(cfg) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Get task details - task, err := client.GetTask(ctx, taskID) - if err != nil { - return cli.WrapVerb(err, "get", "task") - } - - // Build commit message with optional scope - commitType := inferCommitType(task.Labels) - var fullMessage string - if taskCommitScope != "" { - fullMessage = cli.Sprintf("%s(%s): %s", commitType, taskCommitScope, taskCommitMessage) - } else { - fullMessage = cli.Sprintf("%s: %s", commitType, taskCommitMessage) - } - - // Get current directory - cwd, err := os.Getwd() - if err != nil { - return cli.WrapVerb(err, "get", "working directory") - } - - // Check for uncommitted changes - hasChanges, err := agentic.HasUncommittedChanges(ctx, cwd) - if err != nil { - return cli.WrapVerb(err, "check", "git status") - } - - if !hasChanges { - cli.Println("No changes to commit") - return nil - } - - // Create commit - cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.ProgressSubject("create", "commit for "+taskID)) - if err := agentic.AutoCommit(ctx, task, cwd, fullMessage); err != nil { - return cli.WrapAction(err, "commit") - } - - cli.Print("%s %s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.commit")+":", fullMessage) - - // Push if requested - if taskCommitPush { - cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.Progress("push")) - if err := agentic.PushChanges(ctx, cwd); err != nil { - return cli.WrapAction(err, "push") - } - cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.push", "changes")) - } - - return nil - }, -} - -var taskPRCmd = &cli.Command{ - Use: "task:pr [task-id]", - Short: i18n.T("cmd.ai.task_pr.short"), - Long: i18n.T("cmd.ai.task_pr.long"), - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - taskID := args[0] - - cfg, err := agentic.LoadConfig("") - if err != nil { - return cli.WrapVerb(err, "load", "config") - } - - client := agentic.NewClientFromConfig(cfg) - - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - // Get task details - task, err := client.GetTask(ctx, taskID) - if err != nil { - return cli.WrapVerb(err, "get", "task") - } - - // Get current directory - cwd, err := os.Getwd() - if err != nil { - return cli.WrapVerb(err, "get", "working directory") - } - - // Check current branch - branch, err := agentic.GetCurrentBranch(ctx, cwd) - if err != nil { - return cli.WrapVerb(err, "get", "branch") - } - - if branch == "main" || branch == "master" { - return cli.Err("cannot create PR from %s branch", branch) - } - - // Push current branch - cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.ProgressSubject("push", branch)) - if err := agentic.PushChanges(ctx, cwd); err != nil { - // Try setting upstream - if _, err := runGitCommand(cwd, "push", "-u", "origin", branch); err != nil { - return cli.WrapVerb(err, "push", "branch") - } - } - - // Build PR options - opts := agentic.PROptions{ - Title: taskPRTitle, - Draft: taskPRDraft, - Base: taskPRBase, - } - - if taskPRLabels != "" { - opts.Labels = strings.Split(taskPRLabels, ",") - } - - // Create PR - cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.ProgressSubject("create", "PR")) - prURL, err := agentic.CreatePR(ctx, task, cwd, opts) - if err != nil { - return cli.WrapVerb(err, "create", "PR") - } - - cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.create", "PR")) - cli.Print(" %s %s\n", i18n.Label("url"), prURL) - - return nil - }, -} - -func initGitFlags() { - // task:commit command flags - taskCommitCmd.Flags().StringVarP(&taskCommitMessage, "message", "m", "", i18n.T("cmd.ai.task_commit.flag.message")) - taskCommitCmd.Flags().StringVar(&taskCommitScope, "scope", "", i18n.T("cmd.ai.task_commit.flag.scope")) - taskCommitCmd.Flags().BoolVar(&taskCommitPush, "push", false, i18n.T("cmd.ai.task_commit.flag.push")) - - // task:pr command flags - taskPRCmd.Flags().StringVar(&taskPRTitle, "title", "", i18n.T("cmd.ai.task_pr.flag.title")) - taskPRCmd.Flags().BoolVar(&taskPRDraft, "draft", false, i18n.T("cmd.ai.task_pr.flag.draft")) - taskPRCmd.Flags().StringVar(&taskPRLabels, "labels", "", i18n.T("cmd.ai.task_pr.flag.labels")) - taskPRCmd.Flags().StringVar(&taskPRBase, "base", "", i18n.T("cmd.ai.task_pr.flag.base")) -} - -// AddTaskGitCommands registers the task:commit and task:pr commands under a parent. -func AddTaskGitCommands(parent *cli.Command) { - initGitFlags() - parent.AddCommand(taskCommitCmd) - parent.AddCommand(taskPRCmd) -} - -// inferCommitType infers the commit type from task labels. -func inferCommitType(labels []string) string { - for _, label := range labels { - switch strings.ToLower(label) { - case "bug", "bugfix", "fix": - return "fix" - case "docs", "documentation": - return "docs" - case "refactor", "refactoring": - return "refactor" - case "test", "tests", "testing": - return "test" - case "chore": - return "chore" - case "style": - return "style" - case "perf", "performance": - return "perf" - case "ci": - return "ci" - case "build": - return "build" - } - } - return "feat" -} - -// runGitCommand runs a git command in the specified directory. -func runGitCommand(dir string, args ...string) (string, error) { - cmd := exec.Command("git", args...) - cmd.Dir = dir - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - if stderr.Len() > 0 { - return "", cli.Wrap(err, stderr.String()) - } - return "", err - } - - return stdout.String(), nil -} diff --git a/cmd/tasks/cmd.go b/cmd/tasks/cmd.go deleted file mode 100644 index 7e0c742..0000000 --- a/cmd/tasks/cmd.go +++ /dev/null @@ -1,328 +0,0 @@ -// Package tasks implements task listing, viewing, and claiming commands. - -package tasks - -import ( - "context" - "os" - "slices" - "strings" - "time" - - "forge.lthn.ai/core/cli/pkg/cli" - agentic "forge.lthn.ai/core/agent/pkg/lifecycle" - "forge.lthn.ai/core/go-ai/ai" - "forge.lthn.ai/core/go-i18n" -) - -// Style aliases from shared package -var ( - successStyle = cli.SuccessStyle - errorStyle = cli.ErrorStyle - dimStyle = cli.DimStyle - truncate = cli.Truncate - formatAge = cli.FormatAge -) - -// Task priority/status styles from shared -var ( - taskPriorityHighStyle = cli.NewStyle().Foreground(cli.ColourRed500) - taskPriorityMediumStyle = cli.NewStyle().Foreground(cli.ColourAmber500) - taskPriorityLowStyle = cli.NewStyle().Foreground(cli.ColourBlue400) - taskStatusPendingStyle = cli.DimStyle - taskStatusInProgressStyle = cli.NewStyle().Foreground(cli.ColourBlue500) - taskStatusCompletedStyle = cli.SuccessStyle - taskStatusBlockedStyle = cli.ErrorStyle -) - -// Task-specific styles (aliases to shared where possible) -var ( - taskIDStyle = cli.TitleStyle // Bold + blue - taskTitleStyle = cli.ValueStyle // Light gray - taskLabelStyle = cli.NewStyle().Foreground(cli.ColourViolet500) // Violet for labels -) - -// tasks command flags -var ( - tasksStatus string - tasksPriority string - tasksLabels string - tasksLimit int - tasksProject string -) - -// task command flags -var ( - taskAutoSelect bool - taskClaim bool - taskShowContext bool -) - -var tasksCmd = &cli.Command{ - Use: "tasks", - Short: i18n.T("cmd.ai.tasks.short"), - Long: i18n.T("cmd.ai.tasks.long"), - RunE: func(cmd *cli.Command, args []string) error { - limit := tasksLimit - if limit == 0 { - limit = 20 - } - - cfg, err := agentic.LoadConfig("") - if err != nil { - return cli.WrapVerb(err, "load", "config") - } - - client := agentic.NewClientFromConfig(cfg) - - opts := agentic.ListOptions{ - Limit: limit, - Project: tasksProject, - } - - if tasksStatus != "" { - opts.Status = agentic.TaskStatus(tasksStatus) - } - if tasksPriority != "" { - opts.Priority = agentic.TaskPriority(tasksPriority) - } - if tasksLabels != "" { - opts.Labels = strings.Split(tasksLabels, ",") - } - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - tasks, err := client.ListTasks(ctx, opts) - if err != nil { - return cli.WrapVerb(err, "list", "tasks") - } - - if len(tasks) == 0 { - cli.Text(i18n.T("cmd.ai.tasks.none_found")) - return nil - } - - printTaskList(tasks) - return nil - }, -} - -var taskCmd = &cli.Command{ - Use: "task [task-id]", - Short: i18n.T("cmd.ai.task.short"), - Long: i18n.T("cmd.ai.task.long"), - RunE: func(cmd *cli.Command, args []string) error { - cfg, err := agentic.LoadConfig("") - if err != nil { - return cli.WrapVerb(err, "load", "config") - } - - client := agentic.NewClientFromConfig(cfg) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - var task *agentic.Task - - // Get the task ID from args - var taskID string - if len(args) > 0 { - taskID = args[0] - } - - if taskAutoSelect { - // Auto-select: find highest priority pending task - tasks, err := client.ListTasks(ctx, agentic.ListOptions{ - Status: agentic.StatusPending, - Limit: 50, - }) - if err != nil { - return cli.WrapVerb(err, "list", "tasks") - } - - if len(tasks) == 0 { - cli.Text(i18n.T("cmd.ai.task.no_pending")) - return nil - } - - // Sort by priority (critical > high > medium > low) - priorityOrder := map[agentic.TaskPriority]int{ - agentic.PriorityCritical: 0, - agentic.PriorityHigh: 1, - agentic.PriorityMedium: 2, - agentic.PriorityLow: 3, - } - - slices.SortFunc(tasks, func(a, b agentic.Task) int { - return priorityOrder[a.Priority] - priorityOrder[b.Priority] - }) - - task = &tasks[0] - taskClaim = true // Auto-select implies claiming - } else { - if taskID == "" { - return cli.Err("%s", i18n.T("cmd.ai.task.id_required")) - } - - task, err = client.GetTask(ctx, taskID) - if err != nil { - return cli.WrapVerb(err, "get", "task") - } - } - - // Show context if requested - if taskShowContext { - cwd, _ := os.Getwd() - taskCtx, err := agentic.BuildTaskContext(task, cwd) - if err != nil { - cli.Print("%s %s: %s\n", errorStyle.Render(">>"), i18n.T("i18n.fail.build", "context"), err) - } else { - cli.Text(taskCtx.FormatContext()) - } - } else { - printTaskDetails(task) - } - - if taskClaim && task.Status == agentic.StatusPending { - cli.Blank() - cli.Print("%s %s\n", dimStyle.Render(">>"), i18n.T("cmd.ai.task.claiming")) - - claimedTask, err := client.ClaimTask(ctx, task.ID) - if err != nil { - return cli.WrapVerb(err, "claim", "task") - } - - // Record task claim event - _ = ai.Record(ai.Event{ - Type: "task.claimed", - AgentID: cfg.AgentID, - Data: map[string]any{"task_id": task.ID, "title": task.Title}, - }) - - cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.claim", "task")) - cli.Print(" %s %s\n", i18n.Label("status"), formatTaskStatus(claimedTask.Status)) - } - - return nil - }, -} - -func initTasksFlags() { - // tasks command flags - tasksCmd.Flags().StringVar(&tasksStatus, "status", "", i18n.T("cmd.ai.tasks.flag.status")) - tasksCmd.Flags().StringVar(&tasksPriority, "priority", "", i18n.T("cmd.ai.tasks.flag.priority")) - tasksCmd.Flags().StringVar(&tasksLabels, "labels", "", i18n.T("cmd.ai.tasks.flag.labels")) - tasksCmd.Flags().IntVar(&tasksLimit, "limit", 20, i18n.T("cmd.ai.tasks.flag.limit")) - tasksCmd.Flags().StringVar(&tasksProject, "project", "", i18n.T("cmd.ai.tasks.flag.project")) - - // task command flags - taskCmd.Flags().BoolVar(&taskAutoSelect, "auto", false, i18n.T("cmd.ai.task.flag.auto")) - taskCmd.Flags().BoolVar(&taskClaim, "claim", false, i18n.T("cmd.ai.task.flag.claim")) - taskCmd.Flags().BoolVar(&taskShowContext, "context", false, i18n.T("cmd.ai.task.flag.context")) -} - -// AddTaskCommands adds the task management commands to a parent command. -func AddTaskCommands(parent *cli.Command) { - // Task listing and viewing - initTasksFlags() - parent.AddCommand(tasksCmd) - parent.AddCommand(taskCmd) - - // Task updates - initUpdatesFlags() - parent.AddCommand(taskUpdateCmd) - parent.AddCommand(taskCompleteCmd) -} - -func printTaskList(tasks []agentic.Task) { - cli.Print("\n%s\n\n", i18n.T("cmd.ai.tasks.found", map[string]any{"Count": len(tasks)})) - - for _, task := range tasks { - id := taskIDStyle.Render(task.ID) - title := taskTitleStyle.Render(truncate(task.Title, 50)) - priority := formatTaskPriority(task.Priority) - status := formatTaskStatus(task.Status) - - line := cli.Sprintf(" %s %s %s %s", id, priority, status, title) - - if len(task.Labels) > 0 { - labels := taskLabelStyle.Render("[" + strings.Join(task.Labels, ", ") + "]") - line += " " + labels - } - - cli.Text(line) - } - - cli.Blank() - cli.Print("%s\n", dimStyle.Render(i18n.T("cmd.ai.tasks.hint"))) -} - -func printTaskDetails(task *agentic.Task) { - cli.Blank() - cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.id")), taskIDStyle.Render(task.ID)) - cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.title")), taskTitleStyle.Render(task.Title)) - cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.priority")), formatTaskPriority(task.Priority)) - cli.Print("%s %s\n", dimStyle.Render(i18n.Label("status")), formatTaskStatus(task.Status)) - - if task.Project != "" { - cli.Print("%s %s\n", dimStyle.Render(i18n.Label("project")), task.Project) - } - - if len(task.Labels) > 0 { - cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.labels")), taskLabelStyle.Render(strings.Join(task.Labels, ", "))) - } - - if task.ClaimedBy != "" { - cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.claimed_by")), task.ClaimedBy) - } - - cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.created")), formatAge(task.CreatedAt)) - - cli.Blank() - cli.Print("%s\n", dimStyle.Render(i18n.T("cmd.ai.label.description"))) - cli.Text(task.Description) - - if len(task.Files) > 0 { - cli.Blank() - cli.Print("%s\n", dimStyle.Render(i18n.T("cmd.ai.label.related_files"))) - for _, f := range task.Files { - cli.Print(" - %s\n", f) - } - } - - if len(task.Dependencies) > 0 { - cli.Blank() - cli.Print("%s %s\n", dimStyle.Render(i18n.T("cmd.ai.label.blocked_by")), strings.Join(task.Dependencies, ", ")) - } -} - -func formatTaskPriority(p agentic.TaskPriority) string { - switch p { - case agentic.PriorityCritical: - return taskPriorityHighStyle.Render("[" + i18n.T("cmd.ai.priority.critical") + "]") - case agentic.PriorityHigh: - return taskPriorityHighStyle.Render("[" + i18n.T("cmd.ai.priority.high") + "]") - case agentic.PriorityMedium: - return taskPriorityMediumStyle.Render("[" + i18n.T("cmd.ai.priority.medium") + "]") - case agentic.PriorityLow: - return taskPriorityLowStyle.Render("[" + i18n.T("cmd.ai.priority.low") + "]") - default: - return dimStyle.Render("[" + string(p) + "]") - } -} - -func formatTaskStatus(s agentic.TaskStatus) string { - switch s { - case agentic.StatusPending: - return taskStatusPendingStyle.Render(i18n.T("cmd.ai.status.pending")) - case agentic.StatusInProgress: - return taskStatusInProgressStyle.Render(i18n.T("cmd.ai.status.in_progress")) - case agentic.StatusCompleted: - return taskStatusCompletedStyle.Render(i18n.T("cmd.ai.status.completed")) - case agentic.StatusBlocked: - return taskStatusBlockedStyle.Render(i18n.T("cmd.ai.status.blocked")) - default: - return dimStyle.Render(string(s)) - } -} diff --git a/cmd/tasks/updates.go b/cmd/tasks/updates.go deleted file mode 100644 index 06047d2..0000000 --- a/cmd/tasks/updates.go +++ /dev/null @@ -1,122 +0,0 @@ -// updates.go implements task update and completion commands. - -package tasks - -import ( - "context" - "time" - - agentic "forge.lthn.ai/core/agent/pkg/lifecycle" - "forge.lthn.ai/core/go-ai/ai" - "forge.lthn.ai/core/cli/pkg/cli" - "forge.lthn.ai/core/go-i18n" -) - -// task:update command flags -var ( - taskUpdateStatus string - taskUpdateProgress int - taskUpdateNotes string -) - -// task:complete command flags -var ( - taskCompleteOutput string - taskCompleteFailed bool - taskCompleteErrorMsg string -) - -var taskUpdateCmd = &cli.Command{ - Use: "task:update [task-id]", - Short: i18n.T("cmd.ai.task_update.short"), - Long: i18n.T("cmd.ai.task_update.long"), - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - taskID := args[0] - - if taskUpdateStatus == "" && taskUpdateProgress == 0 && taskUpdateNotes == "" { - return cli.Err("%s", i18n.T("cmd.ai.task_update.flag_required")) - } - - cfg, err := agentic.LoadConfig("") - if err != nil { - return cli.WrapVerb(err, "load", "config") - } - - client := agentic.NewClientFromConfig(cfg) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - update := agentic.TaskUpdate{ - Progress: taskUpdateProgress, - Notes: taskUpdateNotes, - } - if taskUpdateStatus != "" { - update.Status = agentic.TaskStatus(taskUpdateStatus) - } - - if err := client.UpdateTask(ctx, taskID, update); err != nil { - return cli.WrapVerb(err, "update", "task") - } - - cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.update", "task")) - return nil - }, -} - -var taskCompleteCmd = &cli.Command{ - Use: "task:complete [task-id]", - Short: i18n.T("cmd.ai.task_complete.short"), - Long: i18n.T("cmd.ai.task_complete.long"), - Args: cli.ExactArgs(1), - RunE: func(cmd *cli.Command, args []string) error { - taskID := args[0] - - cfg, err := agentic.LoadConfig("") - if err != nil { - return cli.WrapVerb(err, "load", "config") - } - - client := agentic.NewClientFromConfig(cfg) - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - result := agentic.TaskResult{ - Success: !taskCompleteFailed, - Output: taskCompleteOutput, - ErrorMessage: taskCompleteErrorMsg, - } - - if err := client.CompleteTask(ctx, taskID, result); err != nil { - return cli.WrapVerb(err, "complete", "task") - } - - // Record task completion event - _ = ai.Record(ai.Event{ - Type: "task.completed", - AgentID: cfg.AgentID, - Data: map[string]any{"task_id": taskID, "success": !taskCompleteFailed}, - }) - - if taskCompleteFailed { - cli.Print("%s %s\n", errorStyle.Render(">>"), i18n.T("cmd.ai.task_complete.failed", map[string]any{"ID": taskID})) - } else { - cli.Print("%s %s\n", successStyle.Render(">>"), i18n.T("i18n.done.complete", "task")) - } - return nil - }, -} - -func initUpdatesFlags() { - // task:update command flags - taskUpdateCmd.Flags().StringVar(&taskUpdateStatus, "status", "", i18n.T("cmd.ai.task_update.flag.status")) - taskUpdateCmd.Flags().IntVar(&taskUpdateProgress, "progress", 0, i18n.T("cmd.ai.task_update.flag.progress")) - taskUpdateCmd.Flags().StringVar(&taskUpdateNotes, "notes", "", i18n.T("cmd.ai.task_update.flag.notes")) - - // task:complete command flags - taskCompleteCmd.Flags().StringVar(&taskCompleteOutput, "output", "", i18n.T("cmd.ai.task_complete.flag.output")) - taskCompleteCmd.Flags().BoolVar(&taskCompleteFailed, "failed", false, i18n.T("cmd.ai.task_complete.flag.failed")) - taskCompleteCmd.Flags().StringVar(&taskCompleteErrorMsg, "error", "", i18n.T("cmd.ai.task_complete.flag.error")) -} diff --git a/cmd/workspace/cmd.go b/cmd/workspace/cmd.go deleted file mode 100644 index d9031cd..0000000 --- a/cmd/workspace/cmd.go +++ /dev/null @@ -1 +0,0 @@ -package workspace diff --git a/cmd/workspace/cmd_agent.go b/cmd/workspace/cmd_agent.go deleted file mode 100644 index 4054103..0000000 --- a/cmd/workspace/cmd_agent.go +++ /dev/null @@ -1,289 +0,0 @@ -// cmd_agent.go manages persistent agent context within task workspaces. -// -// Each agent gets a directory at: -// -// .core/workspace/p{epic}/i{issue}/agents/{provider}/{agent-name}/ -// -// This directory persists across invocations, allowing agents to build -// understanding over time — QA agents accumulate findings, reviewers -// track patterns, implementors record decisions. -// -// Layout: -// -// agents/ -// ├── claude-opus/implementor/ -// │ ├── memory.md # Persistent notes, decisions, context -// │ └── artifacts/ # Generated artifacts (reports, diffs, etc.) -// ├── claude-opus/qa/ -// │ ├── memory.md -// │ └── artifacts/ -// └── gemini/reviewer/ -// └── memory.md -package workspace - -import ( - "encoding/json" - "fmt" - "path/filepath" - "strings" - "time" - - "forge.lthn.ai/core/cli/pkg/cli" - coreio "forge.lthn.ai/core/go-io" - coreerr "forge.lthn.ai/core/go-log" -) - -var ( - agentProvider string - agentName string -) - -func addAgentCommands(parent *cli.Command) { - agentCmd := &cli.Command{ - Use: "agent", - Short: "Manage persistent agent context within task workspaces", - } - - initCmd := &cli.Command{ - Use: "init ", - Short: "Initialize an agent's context directory in the task workspace", - Long: `Creates agents/{provider}/{agent-name}/ with memory.md and artifacts/ -directory. The agent can read/write memory.md across invocations to -build understanding over time.`, - Args: cli.ExactArgs(1), - RunE: runAgentInit, - } - initCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number") - initCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number") - _ = initCmd.MarkFlagRequired("epic") - _ = initCmd.MarkFlagRequired("issue") - - agentListCmd := &cli.Command{ - Use: "list", - Short: "List agents in a task workspace", - RunE: runAgentList, - } - agentListCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number") - agentListCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number") - _ = agentListCmd.MarkFlagRequired("epic") - _ = agentListCmd.MarkFlagRequired("issue") - - pathCmd := &cli.Command{ - Use: "path ", - Short: "Print the agent's context directory path", - Args: cli.ExactArgs(1), - RunE: runAgentPath, - } - pathCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number") - pathCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number") - _ = pathCmd.MarkFlagRequired("epic") - _ = pathCmd.MarkFlagRequired("issue") - - agentCmd.AddCommand(initCmd, agentListCmd, pathCmd) - parent.AddCommand(agentCmd) -} - -// agentContextPath returns the path for an agent's context directory. -func agentContextPath(wsPath, provider, name string) string { - return filepath.Join(wsPath, "agents", provider, name) -} - -// parseAgentID splits "provider/agent-name" into parts. -func parseAgentID(id string) (provider, name string, err error) { - parts := strings.SplitN(id, "/", 2) - if len(parts) != 2 || parts[0] == "" || parts[1] == "" { - return "", "", coreerr.E("parseAgentID", "agent ID must be provider/agent-name (e.g. claude-opus/qa)", nil) - } - return parts[0], parts[1], nil -} - -// AgentManifest tracks agent metadata for a task workspace. -type AgentManifest struct { - Provider string `json:"provider"` - Name string `json:"name"` - CreatedAt time.Time `json:"created_at"` - LastSeen time.Time `json:"last_seen"` -} - -func runAgentInit(cmd *cli.Command, args []string) error { - provider, name, err := parseAgentID(args[0]) - if err != nil { - return err - } - - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - wsPath := taskWorkspacePath(root, taskEpic, taskIssue) - if !coreio.Local.IsDir(wsPath) { - return cli.Err("task workspace does not exist: p%d/i%d — create it first with `core workspace task create`", taskEpic, taskIssue) - } - - agentDir := agentContextPath(wsPath, provider, name) - - if coreio.Local.IsDir(agentDir) { - // Update last_seen - updateAgentManifest(agentDir, provider, name) - cli.Print("Agent %s/%s already initialized at p%d/i%d\n", - cli.ValueStyle.Render(provider), cli.ValueStyle.Render(name), taskEpic, taskIssue) - cli.Print("Path: %s\n", cli.DimStyle.Render(agentDir)) - return nil - } - - // Create directory structure - if err := coreio.Local.EnsureDir(agentDir); err != nil { - return coreerr.E("agentInit", "failed to create agent directory", err) - } - if err := coreio.Local.EnsureDir(filepath.Join(agentDir, "artifacts")); err != nil { - return coreerr.E("agentInit", "failed to create artifacts directory", err) - } - - // Create initial memory.md - memoryContent := fmt.Sprintf(`# %s/%s — Issue #%d (EPIC #%d) - -## Context -- **Task workspace:** p%d/i%d -- **Initialized:** %s - -## Notes - - -`, provider, name, taskIssue, taskEpic, taskEpic, taskIssue, time.Now().Format(time.RFC3339)) - - if err := coreio.Local.Write(filepath.Join(agentDir, "memory.md"), memoryContent); err != nil { - return coreerr.E("agentInit", "failed to create memory.md", err) - } - - // Write manifest - updateAgentManifest(agentDir, provider, name) - - cli.Print("%s Agent %s/%s initialized at p%d/i%d\n", - cli.SuccessStyle.Render("Done:"), - cli.ValueStyle.Render(provider), cli.ValueStyle.Render(name), - taskEpic, taskIssue) - cli.Print("Memory: %s\n", cli.DimStyle.Render(filepath.Join(agentDir, "memory.md"))) - - return nil -} - -func runAgentList(cmd *cli.Command, args []string) error { - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - wsPath := taskWorkspacePath(root, taskEpic, taskIssue) - agentsDir := filepath.Join(wsPath, "agents") - - if !coreio.Local.IsDir(agentsDir) { - cli.Println("No agents in this workspace.") - return nil - } - - providers, err := coreio.Local.List(agentsDir) - if err != nil { - return coreerr.E("agentList", "failed to list agents", err) - } - - found := false - for _, providerEntry := range providers { - if !providerEntry.IsDir() { - continue - } - providerDir := filepath.Join(agentsDir, providerEntry.Name()) - agents, err := coreio.Local.List(providerDir) - if err != nil { - continue - } - - for _, agentEntry := range agents { - if !agentEntry.IsDir() { - continue - } - found = true - agentDir := filepath.Join(providerDir, agentEntry.Name()) - - // Read manifest for last_seen - lastSeen := "" - manifestPath := filepath.Join(agentDir, "manifest.json") - if data, err := coreio.Local.Read(manifestPath); err == nil { - var m AgentManifest - if json.Unmarshal([]byte(data), &m) == nil { - lastSeen = m.LastSeen.Format("2006-01-02 15:04") - } - } - - // Check if memory has content beyond the template - memorySize := "" - if content, err := coreio.Local.Read(filepath.Join(agentDir, "memory.md")); err == nil { - lines := len(strings.Split(content, "\n")) - memorySize = fmt.Sprintf("%d lines", lines) - } - - cli.Print(" %s/%s %s", - cli.ValueStyle.Render(providerEntry.Name()), - cli.ValueStyle.Render(agentEntry.Name()), - cli.DimStyle.Render(memorySize)) - if lastSeen != "" { - cli.Print(" last: %s", cli.DimStyle.Render(lastSeen)) - } - cli.Print("\n") - } - } - - if !found { - cli.Println("No agents in this workspace.") - } - - return nil -} - -func runAgentPath(cmd *cli.Command, args []string) error { - provider, name, err := parseAgentID(args[0]) - if err != nil { - return err - } - - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - wsPath := taskWorkspacePath(root, taskEpic, taskIssue) - agentDir := agentContextPath(wsPath, provider, name) - - if !coreio.Local.IsDir(agentDir) { - return cli.Err("agent %s/%s not initialized — run `core workspace agent init %s/%s`", provider, name, provider, name) - } - - // Print just the path (useful for scripting: cd $(core workspace agent path ...)) - cli.Text(agentDir) - return nil -} - -func updateAgentManifest(agentDir, provider, name string) { - now := time.Now() - manifest := AgentManifest{ - Provider: provider, - Name: name, - CreatedAt: now, - LastSeen: now, - } - - // Try to preserve created_at from existing manifest - manifestPath := filepath.Join(agentDir, "manifest.json") - if data, err := coreio.Local.Read(manifestPath); err == nil { - var existing AgentManifest - if json.Unmarshal([]byte(data), &existing) == nil { - manifest.CreatedAt = existing.CreatedAt - } - } - - data, err := json.MarshalIndent(manifest, "", " ") - if err != nil { - return - } - _ = coreio.Local.Write(manifestPath, string(data)) -} diff --git a/cmd/workspace/cmd_agent_test.go b/cmd/workspace/cmd_agent_test.go deleted file mode 100644 index e414cb0..0000000 --- a/cmd/workspace/cmd_agent_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package workspace - -import ( - "encoding/json" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseAgentID_Good(t *testing.T) { - provider, name, err := parseAgentID("claude-opus/qa") - require.NoError(t, err) - assert.Equal(t, "claude-opus", provider) - assert.Equal(t, "qa", name) -} - -func TestParseAgentID_Bad(t *testing.T) { - tests := []string{ - "noslash", - "/missing-provider", - "missing-name/", - "", - } - for _, id := range tests { - _, _, err := parseAgentID(id) - assert.Error(t, err, "expected error for: %q", id) - } -} - -func TestAgentContextPath(t *testing.T) { - path := agentContextPath("/ws/p101/i343", "claude-opus", "qa") - assert.Equal(t, "/ws/p101/i343/agents/claude-opus/qa", path) -} - -func TestUpdateAgentManifest_Good(t *testing.T) { - tmp := t.TempDir() - agentDir := filepath.Join(tmp, "agents", "test-provider", "test-agent") - require.NoError(t, os.MkdirAll(agentDir, 0755)) - - updateAgentManifest(agentDir, "test-provider", "test-agent") - - data, err := os.ReadFile(filepath.Join(agentDir, "manifest.json")) - require.NoError(t, err) - - var m AgentManifest - require.NoError(t, json.Unmarshal(data, &m)) - assert.Equal(t, "test-provider", m.Provider) - assert.Equal(t, "test-agent", m.Name) - assert.False(t, m.CreatedAt.IsZero()) - assert.False(t, m.LastSeen.IsZero()) -} - -func TestUpdateAgentManifest_PreservesCreatedAt(t *testing.T) { - tmp := t.TempDir() - agentDir := filepath.Join(tmp, "agents", "p", "a") - require.NoError(t, os.MkdirAll(agentDir, 0755)) - - // First call sets created_at - updateAgentManifest(agentDir, "p", "a") - - data, err := os.ReadFile(filepath.Join(agentDir, "manifest.json")) - require.NoError(t, err) - var first AgentManifest - require.NoError(t, json.Unmarshal(data, &first)) - - // Second call should preserve created_at - updateAgentManifest(agentDir, "p", "a") - - data, err = os.ReadFile(filepath.Join(agentDir, "manifest.json")) - require.NoError(t, err) - var second AgentManifest - require.NoError(t, json.Unmarshal(data, &second)) - - assert.Equal(t, first.CreatedAt, second.CreatedAt) - assert.True(t, second.LastSeen.After(first.CreatedAt) || second.LastSeen.Equal(first.CreatedAt)) -} diff --git a/cmd/workspace/cmd_prep.go b/cmd/workspace/cmd_prep.go deleted file mode 100644 index 0a0c412..0000000 --- a/cmd/workspace/cmd_prep.go +++ /dev/null @@ -1,543 +0,0 @@ -// cmd_prep.go implements the `workspace prep` command. -// -// Prepares an agent workspace with wiki KB, protocol specs, a TODO from a -// Forge issue, and vector-recalled context from OpenBrain. All output goes -// to .core/ in the current directory, matching the convention used by -// KBConfig (go-scm) and build/release config. - -package workspace - -import ( - "context" - "encoding/base64" - "encoding/json" - "fmt" - "net/http" - "net/url" - "os" - "path/filepath" - "regexp" - "strings" - "time" - - "forge.lthn.ai/core/agent/pkg/lifecycle" - "forge.lthn.ai/core/cli/pkg/cli" - coreio "forge.lthn.ai/core/go-io" - "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-scm/forge" -) - -var ( - prepRepo string - prepIssue int - prepOrg string - prepOutput string - prepSpecsPath string - prepDryRun bool -) - -func addPrepCommands(parent *cli.Command) { - prepCmd := &cli.Command{ - Use: "prep", - Short: "Prepare agent workspace with wiki KB, specs, TODO, and vector context", - Long: `Fetches wiki pages from Forge, copies protocol specs, generates a task -file from a Forge issue, and queries OpenBrain for relevant context. -All output is written to .core/ in the current directory.`, - RunE: runPrep, - } - - prepCmd.Flags().StringVar(&prepRepo, "repo", "", "Forge repo name (e.g. go-ai)") - prepCmd.Flags().IntVar(&prepIssue, "issue", 0, "Issue number to build TODO from") - prepCmd.Flags().StringVar(&prepOrg, "org", "core", "Forge organisation") - prepCmd.Flags().StringVar(&prepOutput, "output", "", "Output directory (default: ./.core)") - prepCmd.Flags().StringVar(&prepSpecsPath, "specs-path", "", "Path to specs dir") - prepCmd.Flags().BoolVar(&prepDryRun, "dry-run", false, "Preview without writing files") - _ = prepCmd.MarkFlagRequired("repo") - - parent.AddCommand(prepCmd) -} - -func runPrep(cmd *cli.Command, args []string) error { - ctx := context.Background() - - // Resolve output directory - outputDir := prepOutput - if outputDir == "" { - cwd, err := os.Getwd() - if err != nil { - return cli.Err("failed to get working directory") - } - outputDir = filepath.Join(cwd, ".core") - } - - // Resolve specs path - specsPath := prepSpecsPath - if specsPath == "" { - home, err := os.UserHomeDir() - if err == nil { - specsPath = filepath.Join(home, "Code", "specs") - } - } - - // Resolve Forge connection - forgeURL, forgeToken, err := forge.ResolveConfig("", "") - if err != nil { - return log.E("workspace.prep", "failed to resolve Forge config", err) - } - if forgeToken == "" { - return log.E("workspace.prep", "no Forge token configured — set FORGE_TOKEN or run: core forge login", nil) - } - - cli.Print("Preparing workspace for %s/%s\n", cli.ValueStyle.Render(prepOrg), cli.ValueStyle.Render(prepRepo)) - cli.Print("Output: %s\n", cli.DimStyle.Render(outputDir)) - - if prepDryRun { - cli.Print("%s No files will be written.\n", cli.WarningStyle.Render("[DRY RUN]")) - } - fmt.Println() - - // Create output directory structure - if !prepDryRun { - if err := coreio.Local.EnsureDir(filepath.Join(outputDir, "kb")); err != nil { - return log.E("workspace.prep", "failed to create kb directory", err) - } - if err := coreio.Local.EnsureDir(filepath.Join(outputDir, "specs")); err != nil { - return log.E("workspace.prep", "failed to create specs directory", err) - } - } - - // Step 1: Pull wiki pages - wikiCount, err := prepPullWiki(ctx, forgeURL, forgeToken, prepOrg, prepRepo, outputDir, prepDryRun) - if err != nil { - cli.Print("%s wiki: %v\n", cli.WarningStyle.Render("warn"), err) - } - - // Step 2: Copy spec files - specsCount := prepCopySpecs(specsPath, outputDir, prepDryRun) - - // Step 3: Generate TODO from issue - var issueTitle, issueBody string - if prepIssue > 0 { - issueTitle, issueBody, err = prepGenerateTodo(ctx, forgeURL, forgeToken, prepOrg, prepRepo, prepIssue, outputDir, prepDryRun) - if err != nil { - cli.Print("%s todo: %v\n", cli.WarningStyle.Render("warn"), err) - prepGenerateTodoSkeleton(prepOrg, prepRepo, outputDir, prepDryRun) - } - } else { - prepGenerateTodoSkeleton(prepOrg, prepRepo, outputDir, prepDryRun) - } - - // Step 4: Generate context from OpenBrain - contextCount := prepGenerateContext(ctx, prepRepo, issueTitle, issueBody, outputDir, prepDryRun) - - // Summary - fmt.Println() - prefix := "" - if prepDryRun { - prefix = "[DRY RUN] " - } - cli.Print("%s%s\n", prefix, cli.SuccessStyle.Render("Workspace prep complete:")) - cli.Print(" Wiki pages: %s\n", cli.ValueStyle.Render(fmt.Sprintf("%d", wikiCount))) - cli.Print(" Spec files: %s\n", cli.ValueStyle.Render(fmt.Sprintf("%d", specsCount))) - if issueTitle != "" { - cli.Print(" TODO: %s\n", cli.ValueStyle.Render(fmt.Sprintf("from issue #%d", prepIssue))) - } else { - cli.Print(" TODO: %s\n", cli.DimStyle.Render("skeleton")) - } - cli.Print(" Context: %s\n", cli.ValueStyle.Render(fmt.Sprintf("%d memories", contextCount))) - - return nil -} - -// --- Step 1: Pull wiki pages from Forge API --- - -type wikiPageRef struct { - Title string `json:"title"` - SubURL string `json:"sub_url"` -} - -type wikiPageContent struct { - ContentBase64 string `json:"content_base64"` -} - -func prepPullWiki(ctx context.Context, forgeURL, token, org, repo, outputDir string, dryRun bool) (int, error) { - cli.Print("Fetching wiki pages for %s/%s...\n", org, repo) - - endpoint := fmt.Sprintf("%s/api/v1/repos/%s/%s/wiki/pages", forgeURL, org, repo) - resp, err := forgeGet(ctx, endpoint, token) - if err != nil { - return 0, log.E("workspace.prep.wiki", "API request failed", err) - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusNotFound { - cli.Print(" %s No wiki found for %s\n", cli.WarningStyle.Render("warn"), repo) - if !dryRun { - content := fmt.Sprintf("# No wiki found for %s\n\nThis repo has no wiki pages on Forge.\n", repo) - _ = coreio.Local.Write(filepath.Join(outputDir, "kb", "README.md"), content) - } - return 0, nil - } - - if resp.StatusCode != http.StatusOK { - return 0, log.E("workspace.prep.wiki", fmt.Sprintf("API error: %d", resp.StatusCode), nil) - } - - var pages []wikiPageRef - if err := json.NewDecoder(resp.Body).Decode(&pages); err != nil { - return 0, log.E("workspace.prep.wiki", "failed to decode pages", err) - } - - if len(pages) == 0 { - cli.Print(" %s Wiki exists but has no pages.\n", cli.WarningStyle.Render("warn")) - return 0, nil - } - - count := 0 - for _, page := range pages { - title := page.Title - if title == "" { - title = "Untitled" - } - subURL := page.SubURL - if subURL == "" { - subURL = title - } - - if dryRun { - cli.Print(" [would fetch] %s\n", title) - count++ - continue - } - - pageEndpoint := fmt.Sprintf("%s/api/v1/repos/%s/%s/wiki/page/%s", - forgeURL, org, repo, url.PathEscape(subURL)) - pageResp, err := forgeGet(ctx, pageEndpoint, token) - if err != nil || pageResp.StatusCode != http.StatusOK { - cli.Print(" %s Failed to fetch: %s\n", cli.WarningStyle.Render("warn"), title) - if pageResp != nil { - pageResp.Body.Close() - } - continue - } - - var pageData wikiPageContent - if err := json.NewDecoder(pageResp.Body).Decode(&pageData); err != nil { - pageResp.Body.Close() - continue - } - pageResp.Body.Close() - - if pageData.ContentBase64 == "" { - continue - } - - decoded, err := base64.StdEncoding.DecodeString(pageData.ContentBase64) - if err != nil { - continue - } - - filename := sanitiseFilename(title) + ".md" - _ = coreio.Local.Write(filepath.Join(outputDir, "kb", filename), string(decoded)) - cli.Print(" %s\n", title) - count++ - } - - cli.Print(" %d wiki page(s) saved to kb/\n", count) - return count, nil -} - -// --- Step 2: Copy protocol spec files --- - -func prepCopySpecs(specsPath, outputDir string, dryRun bool) int { - cli.Print("Copying spec files...\n") - - specFiles := []string{"AGENT_CONTEXT.md", "TASK_PROTOCOL.md"} - count := 0 - - for _, file := range specFiles { - source := filepath.Join(specsPath, file) - if !coreio.Local.IsFile(source) { - cli.Print(" %s Not found: %s\n", cli.WarningStyle.Render("warn"), source) - continue - } - - if dryRun { - cli.Print(" [would copy] %s\n", file) - count++ - continue - } - - content, err := coreio.Local.Read(source) - if err != nil { - cli.Print(" %s Failed to read: %s\n", cli.WarningStyle.Render("warn"), file) - continue - } - - dest := filepath.Join(outputDir, "specs", file) - if err := coreio.Local.Write(dest, content); err != nil { - cli.Print(" %s Failed to write: %s\n", cli.WarningStyle.Render("warn"), file) - continue - } - - cli.Print(" %s\n", file) - count++ - } - - cli.Print(" %d spec file(s) copied.\n", count) - return count -} - -// --- Step 3: Generate TODO from Forge issue --- - -type forgeIssue struct { - Title string `json:"title"` - Body string `json:"body"` -} - -func prepGenerateTodo(ctx context.Context, forgeURL, token, org, repo string, issueNum int, outputDir string, dryRun bool) (string, string, error) { - cli.Print("Generating TODO from issue #%d...\n", issueNum) - - endpoint := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d", forgeURL, org, repo, issueNum) - resp, err := forgeGet(ctx, endpoint, token) - if err != nil { - return "", "", log.E("workspace.prep.todo", "issue API request failed", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return "", "", log.E("workspace.prep.todo", fmt.Sprintf("failed to fetch issue #%d: %d", issueNum, resp.StatusCode), nil) - } - - var issue forgeIssue - if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil { - return "", "", log.E("workspace.prep.todo", "failed to decode issue", err) - } - - title := issue.Title - if title == "" { - title = "Untitled" - } - - objective := extractObjective(issue.Body) - checklist := extractChecklist(issue.Body) - - var b strings.Builder - fmt.Fprintf(&b, "# TASK: %s\n\n", title) - fmt.Fprintf(&b, "**Status:** ready\n") - fmt.Fprintf(&b, "**Source:** %s/%s/%s/issues/%d\n", forgeURL, org, repo, issueNum) - fmt.Fprintf(&b, "**Created:** %s\n", time.Now().Format("2006-01-02 15:04:05")) - fmt.Fprintf(&b, "**Repo:** %s/%s\n", org, repo) - b.WriteString("\n---\n\n") - - fmt.Fprintf(&b, "## Objective\n\n%s\n", objective) - b.WriteString("\n---\n\n") - - b.WriteString("## Acceptance Criteria\n\n") - if len(checklist) > 0 { - for _, item := range checklist { - fmt.Fprintf(&b, "- [ ] %s\n", item) - } - } else { - b.WriteString("_No checklist items found in issue. Agent should define acceptance criteria._\n") - } - b.WriteString("\n---\n\n") - - b.WriteString("## Implementation Checklist\n\n") - b.WriteString("_To be filled by the agent during planning._\n") - b.WriteString("\n---\n\n") - - b.WriteString("## Notes\n\n") - b.WriteString("Full issue body preserved below for reference.\n\n") - b.WriteString("
\nOriginal Issue\n\n") - b.WriteString(issue.Body) - b.WriteString("\n\n
\n") - - if dryRun { - cli.Print(" [would write] todo.md from: %s\n", title) - } else { - if err := coreio.Local.Write(filepath.Join(outputDir, "todo.md"), b.String()); err != nil { - return title, issue.Body, log.E("workspace.prep.todo", "failed to write todo.md", err) - } - cli.Print(" todo.md generated from: %s\n", title) - } - - return title, issue.Body, nil -} - -func prepGenerateTodoSkeleton(org, repo, outputDir string, dryRun bool) { - var b strings.Builder - b.WriteString("# TASK: [Define task]\n\n") - fmt.Fprintf(&b, "**Status:** ready\n") - fmt.Fprintf(&b, "**Created:** %s\n", time.Now().Format("2006-01-02 15:04:05")) - fmt.Fprintf(&b, "**Repo:** %s/%s\n", org, repo) - b.WriteString("\n---\n\n") - b.WriteString("## Objective\n\n_Define the objective._\n") - b.WriteString("\n---\n\n") - b.WriteString("## Acceptance Criteria\n\n- [ ] _Define criteria_\n") - b.WriteString("\n---\n\n") - b.WriteString("## Implementation Checklist\n\n_To be filled by the agent._\n") - - if dryRun { - cli.Print(" [would write] todo.md skeleton\n") - } else { - _ = coreio.Local.Write(filepath.Join(outputDir, "todo.md"), b.String()) - cli.Print(" todo.md skeleton generated (no --issue provided)\n") - } -} - -// --- Step 4: Generate context from OpenBrain --- - -func prepGenerateContext(ctx context.Context, repo, issueTitle, issueBody, outputDir string, dryRun bool) int { - cli.Print("Querying vector DB for context...\n") - - apiURL := os.Getenv("CORE_API_URL") - if apiURL == "" { - apiURL = "http://localhost:8000" - } - apiToken := os.Getenv("CORE_API_TOKEN") - - client := lifecycle.NewClient(apiURL, apiToken) - - // Query 1: Repo-specific knowledge - repoResult, err := client.Recall(ctx, lifecycle.RecallRequest{ - Query: "How does " + repo + " work? Architecture and key interfaces.", - TopK: 10, - Project: repo, - }) - if err != nil { - cli.Print(" %s BrainService unavailable: %v\n", cli.WarningStyle.Render("warn"), err) - writeBrainUnavailable(repo, outputDir, dryRun) - return 0 - } - - repoMemories := repoResult.Memories - repoScores := repoResult.Scores - - // Query 2: Issue-specific context - var issueMemories []lifecycle.Memory - var issueScores map[string]float64 - if issueTitle != "" { - query := issueTitle - if len(issueBody) > 500 { - query += " " + issueBody[:500] - } else if issueBody != "" { - query += " " + issueBody - } - - issueResult, err := client.Recall(ctx, lifecycle.RecallRequest{ - Query: query, - TopK: 5, - }) - if err == nil { - issueMemories = issueResult.Memories - issueScores = issueResult.Scores - } - } - - totalMemories := len(repoMemories) + len(issueMemories) - - var b strings.Builder - fmt.Fprintf(&b, "# Agent Context — %s\n\n", repo) - b.WriteString("> Auto-generated by `core workspace prep`. Query the vector DB for more.\n\n") - - b.WriteString("## Repo Knowledge\n\n") - if len(repoMemories) > 0 { - for i, mem := range repoMemories { - score := repoScores[mem.ID] - project := mem.Project - if project == "" { - project = "unknown" - } - memType := mem.Type - if memType == "" { - memType = "memory" - } - fmt.Fprintf(&b, "### %d. %s [%s] (score: %.3f)\n\n", i+1, project, memType, score) - fmt.Fprintf(&b, "%s\n\n", mem.Content) - } - } else { - b.WriteString("_No repo-specific memories found. The vector DB may not have been seeded for this repo._\n\n") - } - - b.WriteString("## Task-Relevant Context\n\n") - if len(issueMemories) > 0 { - for i, mem := range issueMemories { - score := issueScores[mem.ID] - project := mem.Project - if project == "" { - project = "unknown" - } - memType := mem.Type - if memType == "" { - memType = "memory" - } - fmt.Fprintf(&b, "### %d. %s [%s] (score: %.3f)\n\n", i+1, project, memType, score) - fmt.Fprintf(&b, "%s\n\n", mem.Content) - } - } else if issueTitle != "" { - b.WriteString("_No task-relevant memories found._\n\n") - } else { - b.WriteString("_No issue provided — skipped task-specific recall._\n\n") - } - - if dryRun { - cli.Print(" [would write] context.md with %d memories\n", totalMemories) - } else { - _ = coreio.Local.Write(filepath.Join(outputDir, "context.md"), b.String()) - cli.Print(" context.md generated with %d memories\n", totalMemories) - } - - return totalMemories -} - -func writeBrainUnavailable(repo, outputDir string, dryRun bool) { - var b strings.Builder - fmt.Fprintf(&b, "# Agent Context — %s\n\n", repo) - b.WriteString("> Vector DB was unavailable when this workspace was prepared.\n") - b.WriteString("> Run `core workspace prep` again once Ollama/Qdrant are reachable.\n") - - if !dryRun { - _ = coreio.Local.Write(filepath.Join(outputDir, "context.md"), b.String()) - } -} - -// --- Helpers --- - -func forgeGet(ctx context.Context, endpoint, token string) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, err - } - req.Header.Set("Authorization", "token "+token) - client := &http.Client{Timeout: 30 * time.Second} - return client.Do(req) -} - -var nonAlphanumeric = regexp.MustCompile(`[^a-zA-Z0-9_\-.]`) - -func sanitiseFilename(title string) string { - return nonAlphanumeric.ReplaceAllString(title, "-") -} - -func extractObjective(body string) string { - if body == "" { - return "_No description provided._" - } - parts := strings.SplitN(body, "\n\n", 2) - first := strings.TrimSpace(parts[0]) - if len(first) > 500 { - return first[:497] + "..." - } - return first -} - -func extractChecklist(body string) []string { - re := regexp.MustCompile(`- \[[ xX]\] (.+)`) - matches := re.FindAllStringSubmatch(body, -1) - var items []string - for _, m := range matches { - items = append(items, strings.TrimSpace(m[1])) - } - return items -} diff --git a/cmd/workspace/cmd_task.go b/cmd/workspace/cmd_task.go deleted file mode 100644 index 7640f80..0000000 --- a/cmd/workspace/cmd_task.go +++ /dev/null @@ -1,465 +0,0 @@ -// cmd_task.go implements task workspace isolation using git worktrees. -// -// Each task gets an isolated workspace at .core/workspace/p{epic}/i{issue}/ -// containing git worktrees of required repos. This prevents agents from -// writing to the implementor's working tree. -// -// Safety checks enforce that workspaces cannot be removed if they contain -// uncommitted changes or unpushed branches. -package workspace - -import ( - "context" - "fmt" - "os/exec" - "path/filepath" - "strconv" - "strings" - - "forge.lthn.ai/core/cli/pkg/cli" - coreio "forge.lthn.ai/core/go-io" - coreerr "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-scm/repos" -) - -var ( - taskEpic int - taskIssue int - taskRepos []string - taskForce bool - taskBranch string -) - -func addTaskCommands(parent *cli.Command) { - taskCmd := &cli.Command{ - Use: "task", - Short: "Manage isolated task workspaces for agents", - } - - createCmd := &cli.Command{ - Use: "create", - Short: "Create an isolated task workspace with git worktrees", - Long: `Creates a workspace at .core/workspace/p{epic}/i{issue}/ with git -worktrees for each specified repo. Each worktree gets a fresh branch -(issue/{id} by default) so agents work in isolation.`, - RunE: runTaskCreate, - } - createCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number") - createCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number") - createCmd.Flags().StringSliceVar(&taskRepos, "repo", nil, "Repos to include (default: all from registry)") - createCmd.Flags().StringVar(&taskBranch, "branch", "", "Branch name (default: issue/{issue})") - _ = createCmd.MarkFlagRequired("epic") - _ = createCmd.MarkFlagRequired("issue") - - removeCmd := &cli.Command{ - Use: "remove", - Short: "Remove a task workspace (with safety checks)", - Long: `Removes a task workspace after checking for uncommitted changes and -unpushed branches. Use --force to skip safety checks.`, - RunE: runTaskRemove, - } - removeCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number") - removeCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number") - removeCmd.Flags().BoolVar(&taskForce, "force", false, "Skip safety checks") - _ = removeCmd.MarkFlagRequired("epic") - _ = removeCmd.MarkFlagRequired("issue") - - listCmd := &cli.Command{ - Use: "list", - Short: "List all task workspaces", - RunE: runTaskList, - } - - statusCmd := &cli.Command{ - Use: "status", - Short: "Show status of a task workspace", - RunE: runTaskStatus, - } - statusCmd.Flags().IntVar(&taskEpic, "epic", 0, "Epic/project number") - statusCmd.Flags().IntVar(&taskIssue, "issue", 0, "Issue number") - _ = statusCmd.MarkFlagRequired("epic") - _ = statusCmd.MarkFlagRequired("issue") - - addAgentCommands(taskCmd) - - taskCmd.AddCommand(createCmd, removeCmd, listCmd, statusCmd) - parent.AddCommand(taskCmd) -} - -// taskWorkspacePath returns the path for a task workspace. -func taskWorkspacePath(root string, epic, issue int) string { - return filepath.Join(root, ".core", "workspace", fmt.Sprintf("p%d", epic), fmt.Sprintf("i%d", issue)) -} - -func runTaskCreate(cmd *cli.Command, args []string) error { - ctx := context.Background() - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace — run from workspace root or a package directory") - } - - wsPath := taskWorkspacePath(root, taskEpic, taskIssue) - - if coreio.Local.IsDir(wsPath) { - return cli.Err("task workspace already exists: %s", wsPath) - } - - branch := taskBranch - if branch == "" { - branch = fmt.Sprintf("issue/%d", taskIssue) - } - - // Determine repos to include - repoNames := taskRepos - if len(repoNames) == 0 { - repoNames, err = registryRepoNames(root) - if err != nil { - return coreerr.E("taskCreate", "failed to load registry", err) - } - } - - if len(repoNames) == 0 { - return cli.Err("no repos specified and no registry found") - } - - // Resolve package paths - config, _ := LoadConfig(root) - pkgDir := "./packages" - if config != nil && config.PackagesDir != "" { - pkgDir = config.PackagesDir - } - if !filepath.IsAbs(pkgDir) { - pkgDir = filepath.Join(root, pkgDir) - } - - if err := coreio.Local.EnsureDir(wsPath); err != nil { - return coreerr.E("taskCreate", "failed to create workspace directory", err) - } - - cli.Print("Creating task workspace: %s\n", cli.ValueStyle.Render(fmt.Sprintf("p%d/i%d", taskEpic, taskIssue))) - cli.Print("Branch: %s\n", cli.ValueStyle.Render(branch)) - cli.Print("Path: %s\n\n", cli.DimStyle.Render(wsPath)) - - var created, skipped int - for _, repoName := range repoNames { - repoPath := filepath.Join(pkgDir, repoName) - if !coreio.Local.IsDir(filepath.Join(repoPath, ".git")) { - cli.Print(" %s %s (not cloned, skipping)\n", cli.DimStyle.Render("·"), repoName) - skipped++ - continue - } - - worktreePath := filepath.Join(wsPath, repoName) - cli.Print(" %s %s... ", cli.DimStyle.Render("·"), repoName) - - if err := createWorktree(ctx, repoPath, worktreePath, branch); err != nil { - cli.Print("%s\n", cli.ErrorStyle.Render("x "+err.Error())) - skipped++ - continue - } - - cli.Print("%s\n", cli.SuccessStyle.Render("ok")) - created++ - } - - cli.Print("\n%s %d worktrees created", cli.SuccessStyle.Render("Done:"), created) - if skipped > 0 { - cli.Print(", %d skipped", skipped) - } - cli.Print("\n") - - return nil -} - -func runTaskRemove(cmd *cli.Command, args []string) error { - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - wsPath := taskWorkspacePath(root, taskEpic, taskIssue) - if !coreio.Local.IsDir(wsPath) { - return cli.Err("task workspace does not exist: p%d/i%d", taskEpic, taskIssue) - } - - if !taskForce { - dirty, reasons := checkWorkspaceSafety(wsPath) - if dirty { - cli.Print("%s Cannot remove workspace p%d/i%d:\n", cli.ErrorStyle.Render("Blocked:"), taskEpic, taskIssue) - for _, r := range reasons { - cli.Print(" %s %s\n", cli.ErrorStyle.Render("·"), r) - } - cli.Print("\nUse --force to override or resolve the issues first.\n") - return coreerr.E("taskRemove", "workspace has unresolved changes", nil) - } - } - - // Remove worktrees first (so git knows they're gone) - entries, err := coreio.Local.List(wsPath) - if err != nil { - return coreerr.E("taskRemove", "failed to list workspace", err) - } - - config, _ := LoadConfig(root) - pkgDir := "./packages" - if config != nil && config.PackagesDir != "" { - pkgDir = config.PackagesDir - } - if !filepath.IsAbs(pkgDir) { - pkgDir = filepath.Join(root, pkgDir) - } - - for _, entry := range entries { - if !entry.IsDir() { - continue - } - worktreePath := filepath.Join(wsPath, entry.Name()) - repoPath := filepath.Join(pkgDir, entry.Name()) - - // Remove worktree from git - if coreio.Local.IsDir(filepath.Join(repoPath, ".git")) { - removeWorktree(repoPath, worktreePath) - } - } - - // Remove the workspace directory - if err := coreio.Local.DeleteAll(wsPath); err != nil { - return coreerr.E("taskRemove", "failed to remove workspace directory", err) - } - - // Clean up empty parent (p{epic}/) if it's now empty - epicDir := filepath.Dir(wsPath) - if entries, err := coreio.Local.List(epicDir); err == nil && len(entries) == 0 { - coreio.Local.DeleteAll(epicDir) - } - - cli.Print("%s Removed workspace p%d/i%d\n", cli.SuccessStyle.Render("Done:"), taskEpic, taskIssue) - return nil -} - -func runTaskList(cmd *cli.Command, args []string) error { - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - wsRoot := filepath.Join(root, ".core", "workspace") - if !coreio.Local.IsDir(wsRoot) { - cli.Println("No task workspaces found.") - return nil - } - - epics, err := coreio.Local.List(wsRoot) - if err != nil { - return coreerr.E("taskList", "failed to list workspaces", err) - } - - found := false - for _, epicEntry := range epics { - if !epicEntry.IsDir() || !strings.HasPrefix(epicEntry.Name(), "p") { - continue - } - epicDir := filepath.Join(wsRoot, epicEntry.Name()) - issues, err := coreio.Local.List(epicDir) - if err != nil { - continue - } - for _, issueEntry := range issues { - if !issueEntry.IsDir() || !strings.HasPrefix(issueEntry.Name(), "i") { - continue - } - found = true - wsPath := filepath.Join(epicDir, issueEntry.Name()) - - // Count worktrees - entries, _ := coreio.Local.List(wsPath) - dirCount := 0 - for _, e := range entries { - if e.IsDir() { - dirCount++ - } - } - - // Check safety - dirty, _ := checkWorkspaceSafety(wsPath) - status := cli.SuccessStyle.Render("clean") - if dirty { - status = cli.ErrorStyle.Render("dirty") - } - - cli.Print(" %s/%s %d repos %s\n", - epicEntry.Name(), issueEntry.Name(), - dirCount, status) - } - } - - if !found { - cli.Println("No task workspaces found.") - } - - return nil -} - -func runTaskStatus(cmd *cli.Command, args []string) error { - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - wsPath := taskWorkspacePath(root, taskEpic, taskIssue) - if !coreio.Local.IsDir(wsPath) { - return cli.Err("task workspace does not exist: p%d/i%d", taskEpic, taskIssue) - } - - cli.Print("Workspace: %s\n", cli.ValueStyle.Render(fmt.Sprintf("p%d/i%d", taskEpic, taskIssue))) - cli.Print("Path: %s\n\n", cli.DimStyle.Render(wsPath)) - - entries, err := coreio.Local.List(wsPath) - if err != nil { - return coreerr.E("taskStatus", "failed to list workspace", err) - } - - for _, entry := range entries { - if !entry.IsDir() { - continue - } - worktreePath := filepath.Join(wsPath, entry.Name()) - - // Get branch - branch := gitOutput(worktreePath, "rev-parse", "--abbrev-ref", "HEAD") - branch = strings.TrimSpace(branch) - - // Get status - status := gitOutput(worktreePath, "status", "--porcelain") - statusLabel := cli.SuccessStyle.Render("clean") - if strings.TrimSpace(status) != "" { - lines := len(strings.Split(strings.TrimSpace(status), "\n")) - statusLabel = cli.ErrorStyle.Render(fmt.Sprintf("%d changes", lines)) - } - - // Get unpushed - unpushed := gitOutput(worktreePath, "log", "--oneline", "@{u}..HEAD") - unpushedLabel := "" - if trimmed := strings.TrimSpace(unpushed); trimmed != "" { - count := len(strings.Split(trimmed, "\n")) - unpushedLabel = cli.WarningStyle.Render(fmt.Sprintf(" %d unpushed", count)) - } - - cli.Print(" %s %s %s%s\n", - cli.RepoStyle.Render(entry.Name()), - cli.DimStyle.Render(branch), - statusLabel, - unpushedLabel) - } - - return nil -} - -// createWorktree adds a git worktree at worktreePath for the given branch. -func createWorktree(ctx context.Context, repoPath, worktreePath, branch string) error { - // Check if branch exists on remote first - cmd := exec.CommandContext(ctx, "git", "worktree", "add", "-b", branch, worktreePath) - cmd.Dir = repoPath - output, err := cmd.CombinedOutput() - if err != nil { - errStr := strings.TrimSpace(string(output)) - // If branch already exists, try without -b - if strings.Contains(errStr, "already exists") { - cmd = exec.CommandContext(ctx, "git", "worktree", "add", worktreePath, branch) - cmd.Dir = repoPath - output, err = cmd.CombinedOutput() - if err != nil { - return coreerr.E("createWorktree", strings.TrimSpace(string(output)), nil) - } - return nil - } - return coreerr.E("createWorktree", errStr, nil) - } - return nil -} - -// removeWorktree removes a git worktree. -func removeWorktree(repoPath, worktreePath string) { - cmd := exec.Command("git", "worktree", "remove", worktreePath) - cmd.Dir = repoPath - _ = cmd.Run() - - // Prune stale worktrees - cmd = exec.Command("git", "worktree", "prune") - cmd.Dir = repoPath - _ = cmd.Run() -} - -// checkWorkspaceSafety checks all worktrees in a workspace for uncommitted/unpushed changes. -func checkWorkspaceSafety(wsPath string) (dirty bool, reasons []string) { - entries, err := coreio.Local.List(wsPath) - if err != nil { - return false, nil - } - - for _, entry := range entries { - if !entry.IsDir() { - continue - } - worktreePath := filepath.Join(wsPath, entry.Name()) - - // Check for uncommitted changes - status := gitOutput(worktreePath, "status", "--porcelain") - if strings.TrimSpace(status) != "" { - dirty = true - reasons = append(reasons, fmt.Sprintf("%s: has uncommitted changes", entry.Name())) - } - - // Check for unpushed commits - unpushed := gitOutput(worktreePath, "log", "--oneline", "@{u}..HEAD") - if strings.TrimSpace(unpushed) != "" { - dirty = true - count := len(strings.Split(strings.TrimSpace(unpushed), "\n")) - reasons = append(reasons, fmt.Sprintf("%s: %d unpushed commits", entry.Name(), count)) - } - } - - return dirty, reasons -} - -// gitOutput runs a git command and returns stdout. -func gitOutput(dir string, args ...string) string { - cmd := exec.Command("git", args...) - cmd.Dir = dir - out, _ := cmd.Output() - return string(out) -} - -// registryRepoNames returns repo names from the workspace registry. -func registryRepoNames(root string) ([]string, error) { - // Try to find repos.yaml - regPath, err := repos.FindRegistry(coreio.Local) - if err != nil { - return nil, err - } - - reg, err := repos.LoadRegistry(coreio.Local, regPath) - if err != nil { - return nil, err - } - - var names []string - for _, repo := range reg.List() { - // Only include cloneable repos - if repo.Clone != nil && !*repo.Clone { - continue - } - // Skip meta repos - if repo.Type == "meta" { - continue - } - names = append(names, repo.Name) - } - - return names, nil -} - -// epicBranchName returns the branch name for an EPIC. -func epicBranchName(epicID int) string { - return "epic/" + strconv.Itoa(epicID) -} diff --git a/cmd/workspace/cmd_task_test.go b/cmd/workspace/cmd_task_test.go deleted file mode 100644 index 6340470..0000000 --- a/cmd/workspace/cmd_task_test.go +++ /dev/null @@ -1,109 +0,0 @@ -package workspace - -import ( - "os" - "os/exec" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func setupTestRepo(t *testing.T, dir, name string) string { - t.Helper() - repoPath := filepath.Join(dir, name) - require.NoError(t, os.MkdirAll(repoPath, 0755)) - - cmds := [][]string{ - {"git", "init"}, - {"git", "config", "user.email", "test@test.com"}, - {"git", "config", "user.name", "Test"}, - {"git", "commit", "--allow-empty", "-m", "initial"}, - } - for _, c := range cmds { - cmd := exec.Command(c[0], c[1:]...) - cmd.Dir = repoPath - out, err := cmd.CombinedOutput() - require.NoError(t, err, "cmd %v failed: %s", c, string(out)) - } - return repoPath -} - -func TestTaskWorkspacePath(t *testing.T) { - path := taskWorkspacePath("/home/user/Code/host-uk", 101, 343) - assert.Equal(t, "/home/user/Code/host-uk/.core/workspace/p101/i343", path) -} - -func TestCreateWorktree_Good(t *testing.T) { - tmp := t.TempDir() - repoPath := setupTestRepo(t, tmp, "test-repo") - worktreePath := filepath.Join(tmp, "workspace", "test-repo") - - err := createWorktree(t.Context(), repoPath, worktreePath, "issue/123") - require.NoError(t, err) - - // Verify worktree exists - assert.DirExists(t, worktreePath) - assert.FileExists(t, filepath.Join(worktreePath, ".git")) - - // Verify branch - branch := gitOutput(worktreePath, "rev-parse", "--abbrev-ref", "HEAD") - assert.Equal(t, "issue/123", trimNL(branch)) -} - -func TestCreateWorktree_BranchExists(t *testing.T) { - tmp := t.TempDir() - repoPath := setupTestRepo(t, tmp, "test-repo") - - // Create branch first - cmd := exec.Command("git", "branch", "issue/456") - cmd.Dir = repoPath - require.NoError(t, cmd.Run()) - - worktreePath := filepath.Join(tmp, "workspace", "test-repo") - err := createWorktree(t.Context(), repoPath, worktreePath, "issue/456") - require.NoError(t, err) - - assert.DirExists(t, worktreePath) -} - -func TestCheckWorkspaceSafety_Clean(t *testing.T) { - tmp := t.TempDir() - wsPath := filepath.Join(tmp, "workspace") - require.NoError(t, os.MkdirAll(wsPath, 0755)) - - repoPath := setupTestRepo(t, tmp, "origin-repo") - worktreePath := filepath.Join(wsPath, "origin-repo") - require.NoError(t, createWorktree(t.Context(), repoPath, worktreePath, "test-branch")) - - dirty, reasons := checkWorkspaceSafety(wsPath) - assert.False(t, dirty) - assert.Empty(t, reasons) -} - -func TestCheckWorkspaceSafety_Dirty(t *testing.T) { - tmp := t.TempDir() - wsPath := filepath.Join(tmp, "workspace") - require.NoError(t, os.MkdirAll(wsPath, 0755)) - - repoPath := setupTestRepo(t, tmp, "origin-repo") - worktreePath := filepath.Join(wsPath, "origin-repo") - require.NoError(t, createWorktree(t.Context(), repoPath, worktreePath, "test-branch")) - - // Create uncommitted file - require.NoError(t, os.WriteFile(filepath.Join(worktreePath, "dirty.txt"), []byte("dirty"), 0644)) - - dirty, reasons := checkWorkspaceSafety(wsPath) - assert.True(t, dirty) - assert.Contains(t, reasons[0], "uncommitted changes") -} - -func TestEpicBranchName(t *testing.T) { - assert.Equal(t, "epic/101", epicBranchName(101)) - assert.Equal(t, "epic/42", epicBranchName(42)) -} - -func trimNL(s string) string { - return s[:len(s)-1] -} diff --git a/cmd/workspace/cmd_workspace.go b/cmd/workspace/cmd_workspace.go deleted file mode 100644 index c897371..0000000 --- a/cmd/workspace/cmd_workspace.go +++ /dev/null @@ -1,90 +0,0 @@ -package workspace - -import ( - "strings" - - "forge.lthn.ai/core/cli/pkg/cli" -) - -// AddWorkspaceCommands registers workspace management commands. -func AddWorkspaceCommands(root *cli.Command) { - wsCmd := &cli.Command{ - Use: "workspace", - Short: "Manage workspace configuration", - RunE: runWorkspaceInfo, - } - - wsCmd.AddCommand(&cli.Command{ - Use: "active [package]", - Short: "Show or set the active package", - RunE: runWorkspaceActive, - }) - - addTaskCommands(wsCmd) - addPrepCommands(wsCmd) - - root.AddCommand(wsCmd) -} - -func runWorkspaceInfo(cmd *cli.Command, args []string) error { - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - config, err := LoadConfig(root) - if err != nil { - return err - } - if config == nil { - return cli.Err("workspace config not found") - } - - cli.Print("Active: %s\n", cli.ValueStyle.Render(config.Active)) - cli.Print("Packages: %s\n", cli.DimStyle.Render(config.PackagesDir)) - if len(config.DefaultOnly) > 0 { - cli.Print("Types: %s\n", cli.DimStyle.Render(strings.Join(config.DefaultOnly, ", "))) - } - - return nil -} - -func runWorkspaceActive(cmd *cli.Command, args []string) error { - root, err := FindWorkspaceRoot() - if err != nil { - return cli.Err("not in a workspace") - } - - config, err := LoadConfig(root) - if err != nil { - return err - } - if config == nil { - config = DefaultConfig() - } - - // If no args, show active - if len(args) == 0 { - if config.Active == "" { - cli.Println("No active package set") - return nil - } - cli.Text(config.Active) - return nil - } - - // Set active - target := args[0] - if target == config.Active { - cli.Print("Active package is already %s\n", cli.ValueStyle.Render(target)) - return nil - } - - config.Active = target - if err := SaveConfig(root, config); err != nil { - return err - } - - cli.Print("Active package set to %s\n", cli.SuccessStyle.Render(target)) - return nil -} diff --git a/cmd/workspace/config.go b/cmd/workspace/config.go deleted file mode 100644 index bc9010f..0000000 --- a/cmd/workspace/config.go +++ /dev/null @@ -1,103 +0,0 @@ -package workspace - -import ( - "os" - "path/filepath" - - coreio "forge.lthn.ai/core/go-io" - coreerr "forge.lthn.ai/core/go-log" - "gopkg.in/yaml.v3" -) - -// WorkspaceConfig holds workspace-level configuration from .core/workspace.yaml. -type WorkspaceConfig struct { - Version int `yaml:"version"` - Active string `yaml:"active"` // Active package name - DefaultOnly []string `yaml:"default_only"` // Default types for setup - PackagesDir string `yaml:"packages_dir"` // Where packages are cloned -} - -// DefaultConfig returns a config with default values. -func DefaultConfig() *WorkspaceConfig { - return &WorkspaceConfig{ - Version: 1, - PackagesDir: "./packages", - } -} - -// LoadConfig tries to load workspace.yaml from the given directory's .core subfolder. -// Returns nil if no config file exists (caller should check for nil). -func LoadConfig(dir string) (*WorkspaceConfig, error) { - path := filepath.Join(dir, ".core", "workspace.yaml") - data, err := coreio.Local.Read(path) - if err != nil { - // If using Local.Read, it returns error on not found. - // We can check if file exists first or handle specific error if exposed. - // Simplest is to check existence first or assume IsNotExist. - // Since we don't have easy IsNotExist check on coreio error returned yet (uses wrapped error), - // let's check IsFile first. - if !coreio.Local.IsFile(path) { - // Try parent directory - parent := filepath.Dir(dir) - if parent != dir { - return LoadConfig(parent) - } - // No workspace.yaml found anywhere - return nil to indicate no config - return nil, nil - } - return nil, coreerr.E("LoadConfig", "failed to read workspace config", err) - } - - config := DefaultConfig() - if err := yaml.Unmarshal([]byte(data), config); err != nil { - return nil, coreerr.E("LoadConfig", "failed to parse workspace config", err) - } - - if config.Version != 1 { - return nil, coreerr.E("LoadConfig", "unsupported workspace config version", nil) - } - - return config, nil -} - -// SaveConfig saves the configuration to the given directory's .core/workspace.yaml. -func SaveConfig(dir string, config *WorkspaceConfig) error { - coreDir := filepath.Join(dir, ".core") - if err := coreio.Local.EnsureDir(coreDir); err != nil { - return coreerr.E("SaveConfig", "failed to create .core directory", err) - } - - path := filepath.Join(coreDir, "workspace.yaml") - data, err := yaml.Marshal(config) - if err != nil { - return coreerr.E("SaveConfig", "failed to marshal workspace config", err) - } - - if err := coreio.Local.Write(path, string(data)); err != nil { - return coreerr.E("SaveConfig", "failed to write workspace config", err) - } - - return nil -} - -// FindWorkspaceRoot searches for the root directory containing .core/workspace.yaml. -func FindWorkspaceRoot() (string, error) { - dir, err := os.Getwd() - if err != nil { - return "", err - } - - for { - if coreio.Local.IsFile(filepath.Join(dir, ".core", "workspace.yaml")) { - return dir, nil - } - - parent := filepath.Dir(dir) - if parent == dir { - break - } - dir = parent - } - - return "", coreerr.E("FindWorkspaceRoot", "not in a workspace", nil) -} diff --git a/pkg/jobrunner/coverage_boost_test.go b/pkg/jobrunner/coverage_boost_test.go deleted file mode 100644 index ea44256..0000000 --- a/pkg/jobrunner/coverage_boost_test.go +++ /dev/null @@ -1,395 +0,0 @@ -package jobrunner - -import ( - "context" - "fmt" - "os" - "path/filepath" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- Journal: NewJournal error path --- - -func TestNewJournal_Bad_EmptyBaseDir(t *testing.T) { - _, err := NewJournal("") - require.Error(t, err) - assert.Contains(t, err.Error(), "base directory is required") -} - -func TestNewJournal_Good(t *testing.T) { - dir := t.TempDir() - j, err := NewJournal(dir) - require.NoError(t, err) - assert.NotNil(t, j) -} - -// --- Journal: sanitizePathComponent additional cases --- - -func TestSanitizePathComponent_Good_ValidNames(t *testing.T) { - tests := []struct { - input string - want string - }{ - {"host-uk", "host-uk"}, - {"core", "core"}, - {"my_repo", "my_repo"}, - {"repo.v2", "repo.v2"}, - {"A123", "A123"}, - } - - for _, tc := range tests { - got, err := sanitizePathComponent(tc.input) - require.NoError(t, err, "input: %q", tc.input) - assert.Equal(t, tc.want, got) - } -} - -func TestSanitizePathComponent_Bad_Invalid(t *testing.T) { - tests := []struct { - name string - input string - }{ - {"empty", ""}, - {"spaces", " "}, - {"dotdot", ".."}, - {"dot", "."}, - {"slash", "foo/bar"}, - {"backslash", `foo\bar`}, - {"special", "org$bad"}, - {"leading-dot", ".hidden"}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - _, err := sanitizePathComponent(tc.input) - assert.Error(t, err, "input: %q", tc.input) - }) - } -} - -// --- Journal: Append with readonly directory --- - -func TestJournal_Append_Bad_ReadonlyDir(t *testing.T) { - if os.Getuid() == 0 { - t.Skip("chmod does not restrict root") - } - // Create a dir that we then make readonly (only works as non-root). - dir := t.TempDir() - readonlyDir := filepath.Join(dir, "readonly") - require.NoError(t, os.MkdirAll(readonlyDir, 0o755)) - require.NoError(t, os.Chmod(readonlyDir, 0o444)) - t.Cleanup(func() { _ = os.Chmod(readonlyDir, 0o755) }) - - j, err := NewJournal(readonlyDir) - require.NoError(t, err) - - signal := &PipelineSignal{ - RepoOwner: "test-owner", - RepoName: "test-repo", - } - result := &ActionResult{ - Action: "test", - Timestamp: time.Now(), - } - - err = j.Append(signal, result) - // Should fail because MkdirAll cannot create subdirectories in readonly dir. - assert.Error(t, err) -} - -// --- Poller: error-returning source --- - -type errorSource struct { - name string -} - -func (e *errorSource) Name() string { return e.name } -func (e *errorSource) Poll(_ context.Context) ([]*PipelineSignal, error) { - return nil, fmt.Errorf("poll error") -} -func (e *errorSource) Report(_ context.Context, _ *ActionResult) error { return nil } - -func TestPoller_RunOnce_Good_SourceError(t *testing.T) { - src := &errorSource{name: "broken-source"} - handler := &mockHandler{name: "test"} - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - }) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) // Poll errors are logged, not returned - - handler.mu.Lock() - defer handler.mu.Unlock() - assert.Empty(t, handler.executed, "handler should not be called when poll fails") -} - -// --- Poller: error-returning handler --- - -type errorHandler struct { - name string -} - -func (e *errorHandler) Name() string { return e.name } -func (e *errorHandler) Match(_ *PipelineSignal) bool { return true } -func (e *errorHandler) Execute(_ context.Context, _ *PipelineSignal) (*ActionResult, error) { - return nil, fmt.Errorf("handler error") -} - -func TestPoller_RunOnce_Good_HandlerError(t *testing.T) { - sig := &PipelineSignal{ - EpicNumber: 1, - ChildNumber: 1, - PRNumber: 1, - RepoOwner: "test", - RepoName: "repo", - } - - src := &mockSource{ - name: "test-source", - signals: []*PipelineSignal{sig}, - } - - handler := &errorHandler{name: "broken-handler"} - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - }) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) // Handler errors are logged, not returned - - // Source should not have received a report (handler errored out). - src.mu.Lock() - defer src.mu.Unlock() - assert.Empty(t, src.reports) -} - -// --- Poller: with Journal integration --- - -func TestPoller_RunOnce_Good_WithJournal(t *testing.T) { - dir := t.TempDir() - journal, err := NewJournal(dir) - require.NoError(t, err) - - sig := &PipelineSignal{ - EpicNumber: 10, - ChildNumber: 3, - PRNumber: 55, - RepoOwner: "host-uk", - RepoName: "core", - PRState: "OPEN", - CheckStatus: "SUCCESS", - Mergeable: "MERGEABLE", - } - - src := &mockSource{ - name: "test-source", - signals: []*PipelineSignal{sig}, - } - - handler := &mockHandler{ - name: "test-handler", - matchFn: func(s *PipelineSignal) bool { - return true - }, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - Journal: journal, - }) - - err = p.RunOnce(context.Background()) - require.NoError(t, err) - - handler.mu.Lock() - require.Len(t, handler.executed, 1) - handler.mu.Unlock() - - // Verify the journal file was written. - date := time.Now().UTC().Format("2006-01-02") - journalPath := filepath.Join(dir, "host-uk", "core", date+".jsonl") - _, statErr := os.Stat(journalPath) - assert.NoError(t, statErr, "journal file should exist at %s", journalPath) -} - -// --- Poller: error-returning Report --- - -type reportErrorSource struct { - name string - signals []*PipelineSignal - mu sync.Mutex -} - -func (r *reportErrorSource) Name() string { return r.name } -func (r *reportErrorSource) Poll(_ context.Context) ([]*PipelineSignal, error) { - r.mu.Lock() - defer r.mu.Unlock() - return r.signals, nil -} -func (r *reportErrorSource) Report(_ context.Context, _ *ActionResult) error { - return fmt.Errorf("report error") -} - -func TestPoller_RunOnce_Good_ReportError(t *testing.T) { - sig := &PipelineSignal{ - EpicNumber: 1, - ChildNumber: 1, - PRNumber: 1, - RepoOwner: "test", - RepoName: "repo", - } - - src := &reportErrorSource{ - name: "report-fail-source", - signals: []*PipelineSignal{sig}, - } - - handler := &mockHandler{ - name: "test-handler", - matchFn: func(s *PipelineSignal) bool { return true }, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - }) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) // Report errors are logged, not returned - - handler.mu.Lock() - defer handler.mu.Unlock() - assert.Len(t, handler.executed, 1, "handler should still execute even though report fails") -} - -// --- Poller: multiple sources and handlers --- - -func TestPoller_RunOnce_Good_MultipleSources(t *testing.T) { - sig1 := &PipelineSignal{ - EpicNumber: 1, ChildNumber: 1, PRNumber: 1, - RepoOwner: "org1", RepoName: "repo1", - } - sig2 := &PipelineSignal{ - EpicNumber: 2, ChildNumber: 2, PRNumber: 2, - RepoOwner: "org2", RepoName: "repo2", - } - - src1 := &mockSource{name: "source-1", signals: []*PipelineSignal{sig1}} - src2 := &mockSource{name: "source-2", signals: []*PipelineSignal{sig2}} - - handler := &mockHandler{ - name: "catch-all", - matchFn: func(s *PipelineSignal) bool { return true }, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src1, src2}, - Handlers: []JobHandler{handler}, - }) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) - - handler.mu.Lock() - defer handler.mu.Unlock() - assert.Len(t, handler.executed, 2) -} - -// --- Poller: Run with immediate cancellation --- - -func TestPoller_Run_Good_ImmediateCancel(t *testing.T) { - src := &mockSource{name: "source", signals: nil} - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - PollInterval: 1 * time.Hour, // long interval - }) - - ctx, cancel := context.WithCancel(context.Background()) - // Cancel after the first RunOnce completes. - go func() { - time.Sleep(50 * time.Millisecond) - cancel() - }() - - err := p.Run(ctx) - assert.ErrorIs(t, err, context.Canceled) - assert.Equal(t, 1, p.Cycle()) // One cycle from the initial RunOnce -} - -// --- Journal: Append with journal error logging --- - -func TestPoller_RunOnce_Good_JournalAppendError(t *testing.T) { - if os.Getuid() == 0 { - t.Skip("chmod does not restrict root") - } - // Use a directory that will cause journal writes to fail. - dir := t.TempDir() - journal, err := NewJournal(dir) - require.NoError(t, err) - - // Make the journal directory read-only to trigger append errors. - require.NoError(t, os.Chmod(dir, 0o444)) - t.Cleanup(func() { _ = os.Chmod(dir, 0o755) }) - - sig := &PipelineSignal{ - EpicNumber: 1, - ChildNumber: 1, - PRNumber: 1, - RepoOwner: "test", - RepoName: "repo", - } - - src := &mockSource{ - name: "test-source", - signals: []*PipelineSignal{sig}, - } - - handler := &mockHandler{ - name: "test-handler", - matchFn: func(s *PipelineSignal) bool { return true }, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - Journal: journal, - }) - - err = p.RunOnce(context.Background()) - // Journal errors are logged, not returned. - require.NoError(t, err) - - handler.mu.Lock() - defer handler.mu.Unlock() - assert.Len(t, handler.executed, 1, "handler should still execute even when journal fails") -} - -// --- Poller: Cycle counter increments --- - -func TestPoller_Cycle_Good_Increments(t *testing.T) { - src := &mockSource{name: "source", signals: nil} - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - }) - - assert.Equal(t, 0, p.Cycle()) - - _ = p.RunOnce(context.Background()) - assert.Equal(t, 1, p.Cycle()) - - _ = p.RunOnce(context.Background()) - assert.Equal(t, 2, p.Cycle()) -} diff --git a/pkg/jobrunner/forgejo/signals.go b/pkg/jobrunner/forgejo/signals.go deleted file mode 100644 index 9d720e4..0000000 --- a/pkg/jobrunner/forgejo/signals.go +++ /dev/null @@ -1,114 +0,0 @@ -package forgejo - -import ( - "regexp" - "strconv" - - forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// epicChildRe matches checklist items: - [ ] #42 or - [x] #42 -var epicChildRe = regexp.MustCompile(`- \[([ x])\] #(\d+)`) - -// parseEpicChildren extracts child issue numbers from an epic body's checklist. -func parseEpicChildren(body string) (unchecked []int, checked []int) { - matches := epicChildRe.FindAllStringSubmatch(body, -1) - for _, m := range matches { - num, err := strconv.Atoi(m[2]) - if err != nil { - continue - } - if m[1] == "x" { - checked = append(checked, num) - } else { - unchecked = append(unchecked, num) - } - } - return unchecked, checked -} - -// linkedPRRe matches "#N" references in PR bodies. -var linkedPRRe = regexp.MustCompile(`#(\d+)`) - -// findLinkedPR finds the first PR whose body references the given issue number. -func findLinkedPR(prs []*forgejosdk.PullRequest, issueNumber int) *forgejosdk.PullRequest { - target := strconv.Itoa(issueNumber) - for _, pr := range prs { - matches := linkedPRRe.FindAllStringSubmatch(pr.Body, -1) - for _, m := range matches { - if m[1] == target { - return pr - } - } - } - return nil -} - -// mapPRState maps Forgejo's PR state and merged flag to a canonical string. -func mapPRState(pr *forgejosdk.PullRequest) string { - if pr.HasMerged { - return "MERGED" - } - switch pr.State { - case forgejosdk.StateOpen: - return "OPEN" - case forgejosdk.StateClosed: - return "CLOSED" - default: - return "CLOSED" - } -} - -// mapMergeable maps Forgejo's boolean Mergeable field to a canonical string. -func mapMergeable(pr *forgejosdk.PullRequest) string { - if pr.HasMerged { - return "UNKNOWN" - } - if pr.Mergeable { - return "MERGEABLE" - } - return "CONFLICTING" -} - -// mapCombinedStatus maps a Forgejo CombinedStatus to SUCCESS/FAILURE/PENDING. -func mapCombinedStatus(cs *forgejosdk.CombinedStatus) string { - if cs == nil || cs.TotalCount == 0 { - return "PENDING" - } - switch cs.State { - case forgejosdk.StatusSuccess: - return "SUCCESS" - case forgejosdk.StatusFailure, forgejosdk.StatusError: - return "FAILURE" - default: - return "PENDING" - } -} - -// buildSignal creates a PipelineSignal from Forgejo API data. -func buildSignal( - owner, repo string, - epicNumber, childNumber int, - pr *forgejosdk.PullRequest, - checkStatus string, -) *jobrunner.PipelineSignal { - sig := &jobrunner.PipelineSignal{ - EpicNumber: epicNumber, - ChildNumber: childNumber, - PRNumber: int(pr.Index), - RepoOwner: owner, - RepoName: repo, - PRState: mapPRState(pr), - IsDraft: false, // SDK v2.2.0 doesn't expose Draft; treat as non-draft - Mergeable: mapMergeable(pr), - CheckStatus: checkStatus, - } - - if pr.Head != nil { - sig.LastCommitSHA = pr.Head.Sha - } - - return sig -} diff --git a/pkg/jobrunner/forgejo/signals_test.go b/pkg/jobrunner/forgejo/signals_test.go deleted file mode 100644 index 4b72535..0000000 --- a/pkg/jobrunner/forgejo/signals_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package forgejo - -import ( - "testing" - - forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" - "github.com/stretchr/testify/assert" -) - -func TestMapPRState_Good_Open(t *testing.T) { - pr := &forgejosdk.PullRequest{State: forgejosdk.StateOpen, HasMerged: false} - assert.Equal(t, "OPEN", mapPRState(pr)) -} - -func TestMapPRState_Good_Merged(t *testing.T) { - pr := &forgejosdk.PullRequest{State: forgejosdk.StateClosed, HasMerged: true} - assert.Equal(t, "MERGED", mapPRState(pr)) -} - -func TestMapPRState_Good_Closed(t *testing.T) { - pr := &forgejosdk.PullRequest{State: forgejosdk.StateClosed, HasMerged: false} - assert.Equal(t, "CLOSED", mapPRState(pr)) -} - -func TestMapPRState_Good_UnknownState(t *testing.T) { - // Any unknown state should default to CLOSED. - pr := &forgejosdk.PullRequest{State: "weird", HasMerged: false} - assert.Equal(t, "CLOSED", mapPRState(pr)) -} - -func TestMapMergeable_Good_Mergeable(t *testing.T) { - pr := &forgejosdk.PullRequest{Mergeable: true, HasMerged: false} - assert.Equal(t, "MERGEABLE", mapMergeable(pr)) -} - -func TestMapMergeable_Good_Conflicting(t *testing.T) { - pr := &forgejosdk.PullRequest{Mergeable: false, HasMerged: false} - assert.Equal(t, "CONFLICTING", mapMergeable(pr)) -} - -func TestMapMergeable_Good_Merged(t *testing.T) { - pr := &forgejosdk.PullRequest{HasMerged: true} - assert.Equal(t, "UNKNOWN", mapMergeable(pr)) -} - -func TestMapCombinedStatus_Good_Success(t *testing.T) { - cs := &forgejosdk.CombinedStatus{ - State: forgejosdk.StatusSuccess, - TotalCount: 1, - } - assert.Equal(t, "SUCCESS", mapCombinedStatus(cs)) -} - -func TestMapCombinedStatus_Good_Failure(t *testing.T) { - cs := &forgejosdk.CombinedStatus{ - State: forgejosdk.StatusFailure, - TotalCount: 1, - } - assert.Equal(t, "FAILURE", mapCombinedStatus(cs)) -} - -func TestMapCombinedStatus_Good_Error(t *testing.T) { - cs := &forgejosdk.CombinedStatus{ - State: forgejosdk.StatusError, - TotalCount: 1, - } - assert.Equal(t, "FAILURE", mapCombinedStatus(cs)) -} - -func TestMapCombinedStatus_Good_Pending(t *testing.T) { - cs := &forgejosdk.CombinedStatus{ - State: forgejosdk.StatusPending, - TotalCount: 1, - } - assert.Equal(t, "PENDING", mapCombinedStatus(cs)) -} - -func TestMapCombinedStatus_Good_Nil(t *testing.T) { - assert.Equal(t, "PENDING", mapCombinedStatus(nil)) -} - -func TestMapCombinedStatus_Good_ZeroCount(t *testing.T) { - cs := &forgejosdk.CombinedStatus{ - State: forgejosdk.StatusSuccess, - TotalCount: 0, - } - assert.Equal(t, "PENDING", mapCombinedStatus(cs)) -} - -func TestParseEpicChildren_Good_Mixed(t *testing.T) { - body := "## Sprint\n- [x] #1\n- [ ] #2\n- [x] #3\n- [ ] #4\nSome text\n" - unchecked, checked := parseEpicChildren(body) - assert.Equal(t, []int{2, 4}, unchecked) - assert.Equal(t, []int{1, 3}, checked) -} - -func TestParseEpicChildren_Good_NoCheckboxes(t *testing.T) { - body := "This is just a normal issue with no checkboxes." - unchecked, checked := parseEpicChildren(body) - assert.Nil(t, unchecked) - assert.Nil(t, checked) -} - -func TestParseEpicChildren_Good_AllChecked(t *testing.T) { - body := "- [x] #10\n- [x] #20\n" - unchecked, checked := parseEpicChildren(body) - assert.Nil(t, unchecked) - assert.Equal(t, []int{10, 20}, checked) -} - -func TestParseEpicChildren_Good_AllUnchecked(t *testing.T) { - body := "- [ ] #5\n- [ ] #6\n" - unchecked, checked := parseEpicChildren(body) - assert.Equal(t, []int{5, 6}, unchecked) - assert.Nil(t, checked) -} - -func TestFindLinkedPR_Good(t *testing.T) { - prs := []*forgejosdk.PullRequest{ - {Index: 10, Body: "Fixes #5"}, - {Index: 11, Body: "Resolves #7"}, - {Index: 12, Body: "Nothing here"}, - } - - pr := findLinkedPR(prs, 7) - assert.NotNil(t, pr) - assert.Equal(t, int64(11), pr.Index) -} - -func TestFindLinkedPR_Good_NotFound(t *testing.T) { - prs := []*forgejosdk.PullRequest{ - {Index: 10, Body: "Fixes #5"}, - } - pr := findLinkedPR(prs, 99) - assert.Nil(t, pr) -} - -func TestFindLinkedPR_Good_Nil(t *testing.T) { - pr := findLinkedPR(nil, 1) - assert.Nil(t, pr) -} - -func TestBuildSignal_Good(t *testing.T) { - pr := &forgejosdk.PullRequest{ - Index: 42, - State: forgejosdk.StateOpen, - Mergeable: true, - Head: &forgejosdk.PRBranchInfo{Sha: "deadbeef"}, - } - - sig := buildSignal("org", "repo", 10, 5, pr, "SUCCESS") - - assert.Equal(t, 10, sig.EpicNumber) - assert.Equal(t, 5, sig.ChildNumber) - assert.Equal(t, 42, sig.PRNumber) - assert.Equal(t, "org", sig.RepoOwner) - assert.Equal(t, "repo", sig.RepoName) - assert.Equal(t, "OPEN", sig.PRState) - assert.Equal(t, "MERGEABLE", sig.Mergeable) - assert.Equal(t, "SUCCESS", sig.CheckStatus) - assert.Equal(t, "deadbeef", sig.LastCommitSHA) - assert.False(t, sig.IsDraft) -} - -func TestBuildSignal_Good_NilHead(t *testing.T) { - pr := &forgejosdk.PullRequest{ - Index: 1, - State: forgejosdk.StateClosed, - HasMerged: true, - } - - sig := buildSignal("org", "repo", 1, 2, pr, "PENDING") - assert.Equal(t, "", sig.LastCommitSHA) - assert.Equal(t, "MERGED", sig.PRState) -} - -func TestSplitRepo_Good(t *testing.T) { - tests := []struct { - input string - owner string - repo string - err bool - }{ - {"host-uk/core", "host-uk", "core", false}, - {"a/b", "a", "b", false}, - {"org/repo-name", "org", "repo-name", false}, - {"invalid", "", "", true}, - {"", "", "", true}, - {"/repo", "", "", true}, - {"owner/", "", "", true}, - } - - for _, tt := range tests { - t.Run(tt.input, func(t *testing.T) { - owner, repo, err := splitRepo(tt.input) - if tt.err { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.owner, owner) - assert.Equal(t, tt.repo, repo) - } - }) - } -} diff --git a/pkg/jobrunner/forgejo/source.go b/pkg/jobrunner/forgejo/source.go deleted file mode 100644 index da0dddd..0000000 --- a/pkg/jobrunner/forgejo/source.go +++ /dev/null @@ -1,173 +0,0 @@ -package forgejo - -import ( - "context" - "fmt" - "strings" - - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" - "forge.lthn.ai/core/go-log" -) - -// Config configures a ForgejoSource. -type Config struct { - Repos []string // "owner/repo" format -} - -// ForgejoSource polls a Forgejo instance for pipeline signals from epic issues. -type ForgejoSource struct { - repos []string - forge *forge.Client -} - -// New creates a ForgejoSource using the given forge client. -func New(cfg Config, client *forge.Client) *ForgejoSource { - return &ForgejoSource{ - repos: cfg.Repos, - forge: client, - } -} - -// Name returns the source identifier. -func (s *ForgejoSource) Name() string { - return "forgejo" -} - -// Poll fetches epics and their linked PRs from all configured repositories, -// returning a PipelineSignal for each unchecked child that has a linked PR. -func (s *ForgejoSource) Poll(ctx context.Context) ([]*jobrunner.PipelineSignal, error) { - var signals []*jobrunner.PipelineSignal - - for _, repoFull := range s.repos { - owner, repo, err := splitRepo(repoFull) - if err != nil { - log.Error("invalid repo format", "repo", repoFull, "err", err) - continue - } - - repoSignals, err := s.pollRepo(ctx, owner, repo) - if err != nil { - log.Error("poll repo failed", "repo", repoFull, "err", err) - continue - } - - signals = append(signals, repoSignals...) - } - - return signals, nil -} - -// Report posts the action result as a comment on the epic issue. -func (s *ForgejoSource) Report(ctx context.Context, result *jobrunner.ActionResult) error { - if result == nil { - return nil - } - - status := "succeeded" - if !result.Success { - status = "failed" - } - - body := fmt.Sprintf("**jobrunner** `%s` %s for #%d (PR #%d)", result.Action, status, result.ChildNumber, result.PRNumber) - if result.Error != "" { - body += fmt.Sprintf("\n\n```\n%s\n```", result.Error) - } - - return s.forge.CreateIssueComment(result.RepoOwner, result.RepoName, int64(result.EpicNumber), body) -} - -// pollRepo fetches epics and PRs for a single repository. -func (s *ForgejoSource) pollRepo(_ context.Context, owner, repo string) ([]*jobrunner.PipelineSignal, error) { - // Fetch epic issues (label=epic, state=open). - issues, err := s.forge.ListIssues(owner, repo, forge.ListIssuesOpts{State: "open"}) - if err != nil { - return nil, log.E("forgejo.pollRepo", "fetch issues", err) - } - - // Filter to epics only. - var epics []epicInfo - for _, issue := range issues { - for _, label := range issue.Labels { - if label.Name == "epic" { - epics = append(epics, epicInfo{ - Number: int(issue.Index), - Body: issue.Body, - }) - break - } - } - } - - if len(epics) == 0 { - return nil, nil - } - - // Fetch all open PRs (and also merged/closed to catch MERGED state). - prs, err := s.forge.ListPullRequests(owner, repo, "all") - if err != nil { - return nil, log.E("forgejo.pollRepo", "fetch PRs", err) - } - - var signals []*jobrunner.PipelineSignal - - for _, epic := range epics { - unchecked, _ := parseEpicChildren(epic.Body) - for _, childNum := range unchecked { - pr := findLinkedPR(prs, childNum) - - if pr == nil { - // No PR yet — check if the child issue is assigned (needs coding). - childIssue, err := s.forge.GetIssue(owner, repo, int64(childNum)) - if err != nil { - log.Error("fetch child issue failed", "repo", owner+"/"+repo, "issue", childNum, "err", err) - continue - } - if len(childIssue.Assignees) > 0 && childIssue.Assignees[0].UserName != "" { - sig := &jobrunner.PipelineSignal{ - EpicNumber: epic.Number, - ChildNumber: childNum, - RepoOwner: owner, - RepoName: repo, - NeedsCoding: true, - Assignee: childIssue.Assignees[0].UserName, - IssueTitle: childIssue.Title, - IssueBody: childIssue.Body, - } - signals = append(signals, sig) - } - continue - } - - // Get combined commit status for the PR's head SHA. - checkStatus := "PENDING" - if pr.Head != nil && pr.Head.Sha != "" { - cs, err := s.forge.GetCombinedStatus(owner, repo, pr.Head.Sha) - if err != nil { - log.Error("fetch combined status failed", "repo", owner+"/"+repo, "sha", pr.Head.Sha, "err", err) - } else { - checkStatus = mapCombinedStatus(cs) - } - } - - sig := buildSignal(owner, repo, epic.Number, childNum, pr, checkStatus) - signals = append(signals, sig) - } - } - - return signals, nil -} - -type epicInfo struct { - Number int - Body string -} - -// splitRepo parses "owner/repo" into its components. -func splitRepo(full string) (string, string, error) { - parts := strings.SplitN(full, "/", 2) - if len(parts) != 2 || parts[0] == "" || parts[1] == "" { - return "", "", log.E("forgejo.splitRepo", fmt.Sprintf("expected owner/repo format, got %q", full), nil) - } - return parts[0], parts[1], nil -} diff --git a/pkg/jobrunner/forgejo/source_extra_test.go b/pkg/jobrunner/forgejo/source_extra_test.go deleted file mode 100644 index c70745f..0000000 --- a/pkg/jobrunner/forgejo/source_extra_test.go +++ /dev/null @@ -1,320 +0,0 @@ -package forgejo - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -func TestForgejoSource_Poll_Good_InvalidRepo(t *testing.T) { - // Invalid repo format should be logged and skipped, not error. - s := New(Config{Repos: []string{"invalid-no-slash"}}, nil) - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - assert.Empty(t, signals) -} - -func TestForgejoSource_Poll_Good_MultipleRepos(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - // Return one epic per repo. - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #2\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 10, - "body": "Fixes #2", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "abc", "ref": "fix", "label": "fix"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "state": "success", - "total_count": 1, - "statuses": []any{}, - }) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org-a/repo-1", "org-b/repo-2"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - assert.Len(t, signals, 2) -} - -func TestForgejoSource_Poll_Good_NeedsCoding(t *testing.T) { - // When a child issue has no linked PR but is assigned, NeedsCoding should be true. - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues/5"): - // Child issue with assignee. - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 5, - "title": "Implement feature", - "body": "Please implement this.", - "state": "open", - "assignees": []map[string]any{{"login": "darbs-claude", "username": "darbs-claude"}}, - }) - - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #5\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - // No PRs linked. - _ = json.NewEncoder(w).Encode([]any{}) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"test-org/test-repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - sig := signals[0] - assert.True(t, sig.NeedsCoding) - assert.Equal(t, "darbs-claude", sig.Assignee) - assert.Equal(t, "Implement feature", sig.IssueTitle) - assert.Equal(t, "Please implement this.", sig.IssueBody) - assert.Equal(t, 5, sig.ChildNumber) -} - -func TestForgejoSource_Poll_Good_MergedPR(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #3\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 20, - "body": "Fixes #3", - "state": "closed", - "mergeable": false, - "merged": true, - "head": map[string]string{"sha": "merged123", "ref": "fix", "label": "fix"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "state": "success", - "total_count": 1, - "statuses": []any{}, - }) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - assert.Equal(t, "MERGED", signals[0].PRState) - assert.Equal(t, "UNKNOWN", signals[0].Mergeable) -} - -func TestForgejoSource_Poll_Good_NoHeadSHA(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #3\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 20, - "body": "Fixes #3", - "state": "open", - "mergeable": true, - "merged": false, - // No head field. - }, - } - _ = json.NewEncoder(w).Encode(prs) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - // Without head SHA, check status stays PENDING. - assert.Equal(t, "PENDING", signals[0].CheckStatus) -} - -func TestForgejoSource_Report_Good_Nil(t *testing.T) { - s := New(Config{}, nil) - err := s.Report(context.Background(), nil) - assert.NoError(t, err) -} - -func TestForgejoSource_Report_Good_Failed(t *testing.T) { - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - capturedBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{}, client) - - result := &jobrunner.ActionResult{ - Action: "dispatch", - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 1, - ChildNumber: 2, - PRNumber: 3, - Success: false, - Error: "transfer failed", - } - - err := s.Report(context.Background(), result) - require.NoError(t, err) - assert.Contains(t, capturedBody, "failed") - assert.Contains(t, capturedBody, "transfer failed") -} - -func TestForgejoSource_Poll_Good_APIErrors(t *testing.T) { - // When the issues API fails, poll should continue with other repos. - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - assert.Empty(t, signals) -} - -func TestForgejoSource_Poll_Good_EmptyRepos(t *testing.T) { - s := New(Config{Repos: []string{}}, nil) - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - assert.Empty(t, signals) -} - -func TestForgejoSource_Poll_Good_NonEpicIssues(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - // Issues without the "epic" label. - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #2\n", - "labels": []map[string]string{{"name": "bug"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - assert.Empty(t, signals, "non-epic issues should not generate signals") -} diff --git a/pkg/jobrunner/forgejo/source_phase3_test.go b/pkg/jobrunner/forgejo/source_phase3_test.go deleted file mode 100644 index a06d5fa..0000000 --- a/pkg/jobrunner/forgejo/source_phase3_test.go +++ /dev/null @@ -1,672 +0,0 @@ -package forgejo - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" - - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// --- Signal parsing and filtering tests --- - -func TestParseEpicChildren_Good_EmptyBody(t *testing.T) { - unchecked, checked := parseEpicChildren("") - assert.Nil(t, unchecked) - assert.Nil(t, checked) -} - -func TestParseEpicChildren_Good_MixedContent(t *testing.T) { - // Checkboxes mixed with regular markdown content. - body := `## Epic: Refactor Auth - -Some description of the epic. - -### Tasks -- [x] #10 — Migrate session store -- [ ] #11 — Update OAuth flow -- [x] #12 — Fix token refresh -- [ ] #13 — Add 2FA support - -### Notes -This is a note, not a checkbox. -- Regular list item -- Another item -` - unchecked, checked := parseEpicChildren(body) - assert.Equal(t, []int{11, 13}, unchecked) - assert.Equal(t, []int{10, 12}, checked) -} - -func TestParseEpicChildren_Good_LargeIssueNumbers(t *testing.T) { - body := "- [ ] #9999\n- [x] #10000\n" - unchecked, checked := parseEpicChildren(body) - assert.Equal(t, []int{9999}, unchecked) - assert.Equal(t, []int{10000}, checked) -} - -func TestParseEpicChildren_Good_ConsecutiveCheckboxes(t *testing.T) { - body := "- [ ] #1\n- [ ] #2\n- [ ] #3\n- [ ] #4\n- [ ] #5\n" - unchecked, checked := parseEpicChildren(body) - assert.Equal(t, []int{1, 2, 3, 4, 5}, unchecked) - assert.Nil(t, checked) -} - -// --- findLinkedPR tests --- - -func TestFindLinkedPR_Good_MultipleReferencesInBody(t *testing.T) { - prs := []*forgejosdk.PullRequest{ - {Index: 10, Body: "Fixes #5 and relates to #7"}, - {Index: 11, Body: "Closes #8"}, - } - - // Should find PR #10 because it references #7. - pr := findLinkedPR(prs, 7) - assert.NotNil(t, pr) - assert.Equal(t, int64(10), pr.Index) - - // Should find PR #10 because it references #5. - pr = findLinkedPR(prs, 5) - assert.NotNil(t, pr) - assert.Equal(t, int64(10), pr.Index) -} - -func TestFindLinkedPR_Good_EmptyBodyPR(t *testing.T) { - prs := []*forgejosdk.PullRequest{ - {Index: 10, Body: ""}, - {Index: 11, Body: "Fixes #7"}, - } - - pr := findLinkedPR(prs, 7) - assert.NotNil(t, pr) - assert.Equal(t, int64(11), pr.Index) -} - -func TestFindLinkedPR_Good_FirstMatchWins(t *testing.T) { - // Both PRs reference #7, first one should win. - prs := []*forgejosdk.PullRequest{ - {Index: 10, Body: "Fixes #7"}, - {Index: 11, Body: "Also fixes #7"}, - } - - pr := findLinkedPR(prs, 7) - assert.NotNil(t, pr) - assert.Equal(t, int64(10), pr.Index) -} - -func TestFindLinkedPR_Good_EmptySlice(t *testing.T) { - prs := []*forgejosdk.PullRequest{} - pr := findLinkedPR(prs, 1) - assert.Nil(t, pr) -} - -// --- mapPRState edge case --- - -func TestMapPRState_Good_MergedOverridesState(t *testing.T) { - // HasMerged=true should return MERGED regardless of State. - pr := &forgejosdk.PullRequest{State: forgejosdk.StateOpen, HasMerged: true} - assert.Equal(t, "MERGED", mapPRState(pr)) -} - -// --- mapCombinedStatus edge cases --- - -func TestMapCombinedStatus_Good_WarningState(t *testing.T) { - // Unknown/warning state should default to PENDING. - cs := &forgejosdk.CombinedStatus{ - State: forgejosdk.StatusWarning, - TotalCount: 1, - } - assert.Equal(t, "PENDING", mapCombinedStatus(cs)) -} - -// --- buildSignal edge cases --- - -func TestBuildSignal_Good_ClosedPR(t *testing.T) { - pr := &forgejosdk.PullRequest{ - Index: 5, - State: forgejosdk.StateClosed, - Mergeable: false, - HasMerged: false, - Head: &forgejosdk.PRBranchInfo{Sha: "abc"}, - } - - sig := buildSignal("org", "repo", 1, 2, pr, "FAILURE") - assert.Equal(t, "CLOSED", sig.PRState) - assert.Equal(t, "CONFLICTING", sig.Mergeable) - assert.Equal(t, "FAILURE", sig.CheckStatus) - assert.Equal(t, "abc", sig.LastCommitSHA) -} - -func TestBuildSignal_Good_MergedPR(t *testing.T) { - pr := &forgejosdk.PullRequest{ - Index: 99, - State: forgejosdk.StateClosed, - Mergeable: false, - HasMerged: true, - Head: &forgejosdk.PRBranchInfo{Sha: "merged123"}, - } - - sig := buildSignal("owner", "repo", 10, 5, pr, "SUCCESS") - assert.Equal(t, "MERGED", sig.PRState) - assert.Equal(t, "UNKNOWN", sig.Mergeable) - assert.Equal(t, 99, sig.PRNumber) - assert.Equal(t, "merged123", sig.LastCommitSHA) -} - -// --- splitRepo edge cases --- - -func TestSplitRepo_Bad_OnlySlash(t *testing.T) { - _, _, err := splitRepo("/") - assert.Error(t, err) -} - -func TestSplitRepo_Bad_MultipleSlashes(t *testing.T) { - // Should take only the first part as owner, rest as repo. - owner, repo, err := splitRepo("a/b/c") - require.NoError(t, err) - assert.Equal(t, "a", owner) - assert.Equal(t, "b/c", repo) -} - -// --- Poll with combined status failure --- - -func TestForgejoSource_Poll_Good_CombinedStatusFailure(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #2\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 10, - "body": "Fixes #2", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "fail123", "ref": "feature", "label": "feature"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - status := map[string]any{ - "state": "failure", - "total_count": 2, - "statuses": []map[string]any{{"status": "failure", "context": "ci"}}, - } - _ = json.NewEncoder(w).Encode(status) - - default: - w.WriteHeader(http.StatusNotFound) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - assert.Equal(t, "FAILURE", signals[0].CheckStatus) - assert.Equal(t, "OPEN", signals[0].PRState) - assert.Equal(t, "MERGEABLE", signals[0].Mergeable) -} - -// --- Poll with combined status error --- - -func TestForgejoSource_Poll_Good_CombinedStatusError(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #3\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 20, - "body": "Fixes #3", - "state": "open", - "mergeable": false, - "merged": false, - "head": map[string]string{"sha": "err123", "ref": "fix", "label": "fix"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - // Combined status endpoint returns 500 — should fall back to PENDING. - case strings.Contains(path, "/status"): - w.WriteHeader(http.StatusInternalServerError) - - default: - w.WriteHeader(http.StatusNotFound) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - // Combined status API error -> falls back to PENDING. - assert.Equal(t, "PENDING", signals[0].CheckStatus) - assert.Equal(t, "CONFLICTING", signals[0].Mergeable) -} - -// --- Poll with child that has no assignee (NeedsCoding path, no assignee) --- - -func TestForgejoSource_Poll_Good_ChildNoAssignee(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues/5"): - // Child issue with no assignee. - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 5, - "title": "Unassigned task", - "body": "No one is working on this.", - "state": "open", - "assignees": []map[string]any{}, - }) - - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #5\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - _ = json.NewEncoder(w).Encode([]any{}) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - // No signal should be emitted when child has no assignee and no PR. - assert.Empty(t, signals) -} - -// --- Poll with child issue fetch failure --- - -func TestForgejoSource_Poll_Good_ChildFetchFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues/5"): - // Child issue fetch fails. - w.WriteHeader(http.StatusInternalServerError) - - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #5\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - _ = json.NewEncoder(w).Encode([]any{}) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - // Child fetch error should be logged and skipped, not returned as error. - assert.Empty(t, signals) -} - -// --- Poll with multiple epics --- - -func TestForgejoSource_Poll_Good_MultipleEpics(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #3\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - { - "number": 2, - "body": "- [ ] #4\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 10, - "body": "Fixes #3", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "aaa", "ref": "f1", "label": "f1"}, - }, - { - "number": 11, - "body": "Fixes #4", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "bbb", "ref": "f2", "label": "f2"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "state": "success", - "total_count": 1, - "statuses": []any{}, - }) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 2) - assert.Equal(t, 1, signals[0].EpicNumber) - assert.Equal(t, 3, signals[0].ChildNumber) - assert.Equal(t, 10, signals[0].PRNumber) - - assert.Equal(t, 2, signals[1].EpicNumber) - assert.Equal(t, 4, signals[1].ChildNumber) - assert.Equal(t, 11, signals[1].PRNumber) -} - -// --- Report with nil result --- - -func TestForgejoSource_Report_Good_NilResult(t *testing.T) { - s := New(Config{}, nil) - err := s.Report(context.Background(), nil) - assert.NoError(t, err) -} - -// --- Report constructs correct comment body --- - -func TestForgejoSource_Report_Good_SuccessFormat(t *testing.T) { - var capturedPath string - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedPath = r.URL.Path - w.Header().Set("Content-Type", "application/json") - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - capturedBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{}, client) - - result := &jobrunner.ActionResult{ - Action: "tick_parent", - RepoOwner: "core", - RepoName: "go-scm", - EpicNumber: 5, - ChildNumber: 10, - PRNumber: 20, - Success: true, - } - - err := s.Report(context.Background(), result) - require.NoError(t, err) - - // Comment should be on the epic issue. - assert.Contains(t, capturedPath, "/issues/5/comments") - assert.Contains(t, capturedBody, "tick_parent") - assert.Contains(t, capturedBody, "succeeded") - assert.Contains(t, capturedBody, "#10") - assert.Contains(t, capturedBody, "PR #20") -} - -func TestForgejoSource_Report_Good_FailureWithError(t *testing.T) { - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - capturedBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{}, client) - - result := &jobrunner.ActionResult{ - Action: "enable_auto_merge", - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 1, - ChildNumber: 2, - PRNumber: 3, - Success: false, - Error: "merge conflict detected", - } - - err := s.Report(context.Background(), result) - require.NoError(t, err) - - assert.Contains(t, capturedBody, "failed") - assert.Contains(t, capturedBody, "merge conflict detected") -} - -// --- Poll filters only epic-labelled issues --- - -func TestForgejoSource_Poll_Good_MixedLabels(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #2\n", - "labels": []map[string]string{{"name": "epic"}, {"name": "priority-high"}}, - "state": "open", - }, - { - "number": 3, - "body": "- [ ] #4\n", - "labels": []map[string]string{{"name": "bug"}}, - "state": "open", - }, - { - "number": 5, - "body": "- [ ] #6\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 10, - "body": "Fixes #2", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "sha1", "ref": "f1", "label": "f1"}, - }, - { - "number": 11, - "body": "Fixes #4", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "sha2", "ref": "f2", "label": "f2"}, - }, - { - "number": 12, - "body": "Fixes #6", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "sha3", "ref": "f3", "label": "f3"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "state": "success", - "total_count": 1, - "statuses": []any{}, - }) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - // Only issues #1 and #5 have the "epic" label. - require.Len(t, signals, 2) - assert.Equal(t, 1, signals[0].EpicNumber) - assert.Equal(t, 2, signals[0].ChildNumber) - assert.Equal(t, 5, signals[1].EpicNumber) - assert.Equal(t, 6, signals[1].ChildNumber) -} - -// --- Poll with PRs error after issues succeed --- - -func TestForgejoSource_Poll_Good_PRsAPIError(t *testing.T) { - callCount := 0 - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - callCount++ - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, - "body": "- [ ] #2\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - w.WriteHeader(http.StatusInternalServerError) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client, err := forge.New(srv.URL, "test-token") - require.NoError(t, err) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - // PR API failure -> repo is skipped, no signals. - assert.Empty(t, signals) -} - -// --- New creates source correctly --- - -func TestForgejoSource_New_Good(t *testing.T) { - s := New(Config{Repos: []string{"a/b", "c/d"}}, nil) - assert.Equal(t, "forgejo", s.Name()) - assert.Equal(t, []string{"a/b", "c/d"}, s.repos) -} diff --git a/pkg/jobrunner/forgejo/source_supplementary_test.go b/pkg/jobrunner/forgejo/source_supplementary_test.go deleted file mode 100644 index 7922dc6..0000000 --- a/pkg/jobrunner/forgejo/source_supplementary_test.go +++ /dev/null @@ -1,409 +0,0 @@ -package forgejo - -import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// --------------------------------------------------------------------------- -// Supplementary Forgejo signal source tests — extends Phase 3 coverage -// --------------------------------------------------------------------------- - -func TestForgejoSource_Poll_Good_MultipleEpicsMultipleChildren(t *testing.T) { - // Two epics, each with multiple unchecked children that have linked PRs. - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 10, - "body": "## Sprint\n- [ ] #11\n- [ ] #12\n- [x] #13\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - { - "number": 20, - "body": "## Sprint 2\n- [ ] #21\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 30, "body": "Fixes #11", "state": "open", - "mergeable": true, "merged": false, - "head": map[string]string{"sha": "aaa111", "ref": "fix-11", "label": "fix-11"}, - }, - { - "number": 31, "body": "Fixes #12", "state": "open", - "mergeable": false, "merged": false, - "head": map[string]string{"sha": "bbb222", "ref": "fix-12", "label": "fix-12"}, - }, - { - "number": 32, "body": "Resolves #21", "state": "open", - "mergeable": true, "merged": false, - "head": map[string]string{"sha": "ccc333", "ref": "fix-21", "label": "fix-21"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "state": "success", "total_count": 1, "statuses": []any{}, - }) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - // Epic 10 has #11 and #12 unchecked; epic 20 has #21 unchecked. Total 3 signals. - require.Len(t, signals, 3, "expected three signals from two epics") - - childNumbers := map[int]bool{} - for _, sig := range signals { - childNumbers[sig.ChildNumber] = true - } - assert.True(t, childNumbers[11]) - assert.True(t, childNumbers[12]) - assert.True(t, childNumbers[21]) -} - -func TestForgejoSource_Poll_Good_CombinedStatusFetchErrorFallsToPending(t *testing.T) { - // When combined status fetch fails, check status should default to PENDING. - var statusFetched atomic.Bool - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, "body": "- [ ] #2\n", - "labels": []map[string]string{{"name": "epic"}}, "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 10, "body": "Fixes #2", "state": "open", - "mergeable": true, "merged": false, - "head": map[string]string{"sha": "sha123", "ref": "fix", "label": "fix"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - statusFetched.Store(true) - w.WriteHeader(http.StatusInternalServerError) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - assert.True(t, statusFetched.Load(), "status endpoint should have been called") - assert.Equal(t, "PENDING", signals[0].CheckStatus, "failed status fetch should default to PENDING") -} - -func TestForgejoSource_Poll_Good_MixedReposFirstFailsSecondSucceeds(t *testing.T) { - // First repo fails (issues endpoint 500), second repo succeeds. - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/repos/bad-org/bad-repo/issues"): - w.WriteHeader(http.StatusInternalServerError) - - case strings.Contains(path, "/repos/good-org/good-repo/issues"): - issues := []map[string]any{ - { - "number": 1, "body": "- [ ] #2\n", - "labels": []map[string]string{{"name": "epic"}}, "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/repos/good-org/good-repo/pulls"): - prs := []map[string]any{ - { - "number": 10, "body": "Fixes #2", "state": "open", - "mergeable": true, "merged": false, - "head": map[string]string{"sha": "abc", "ref": "fix", "label": "fix"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "state": "success", "total_count": 1, "statuses": []any{}, - }) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"bad-org/bad-repo", "good-org/good-repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - require.Len(t, signals, 1, "only the good repo should produce signals") - assert.Equal(t, "good-org", signals[0].RepoOwner) - assert.Equal(t, "good-repo", signals[0].RepoName) -} - -func TestForgejoSource_Report_Good_CommentBodyTable(t *testing.T) { - tests := []struct { - name string - result *jobrunner.ActionResult - wantContains []string - }{ - { - name: "successful action", - result: &jobrunner.ActionResult{ - Action: "enable_auto_merge", RepoOwner: "org", RepoName: "repo", - EpicNumber: 10, ChildNumber: 11, PRNumber: 20, Success: true, - }, - wantContains: []string{"enable_auto_merge", "succeeded", "#11", "PR #20"}, - }, - { - name: "failed action with error", - result: &jobrunner.ActionResult{ - Action: "tick_parent", RepoOwner: "org", RepoName: "repo", - EpicNumber: 10, ChildNumber: 11, PRNumber: 20, - Success: false, Error: "rate limit exceeded", - }, - wantContains: []string{"tick_parent", "failed", "#11", "PR #20", "rate limit exceeded"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - capturedBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{}, client) - - err := s.Report(context.Background(), tt.result) - require.NoError(t, err) - - for _, want := range tt.wantContains { - assert.Contains(t, capturedBody, want) - } - }) - } -} - -func TestForgejoSource_Report_Good_PostsToCorrectEpicIssue(t *testing.T) { - var capturedPath string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method == http.MethodPost { - capturedPath = r.URL.Path - } - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{}, client) - - result := &jobrunner.ActionResult{ - Action: "merge", RepoOwner: "test-org", RepoName: "test-repo", - EpicNumber: 42, ChildNumber: 7, PRNumber: 99, Success: true, - } - - err := s.Report(context.Background(), result) - require.NoError(t, err) - - expected := fmt.Sprintf("/api/v1/repos/%s/%s/issues/%d/comments", result.RepoOwner, result.RepoName, result.EpicNumber) - assert.Equal(t, expected, capturedPath, "comment should be posted on the epic issue") -} - -func TestForgejoSource_Poll_Good_SignalFieldCompleteness(t *testing.T) { - // Verify that all expected signal fields are populated correctly. - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 100, "body": "## Work\n- [ ] #101\n- [x] #102\n", - "labels": []map[string]string{{"name": "epic"}}, "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 200, "body": "Closes #101", "state": "open", - "mergeable": true, "merged": false, - "head": map[string]string{"sha": "deadbeef", "ref": "feature", "label": "feature"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - case strings.Contains(path, "/status"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "state": "success", "total_count": 2, - "statuses": []map[string]any{{"status": "success"}, {"status": "success"}}, - }) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"acme/widgets"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - sig := signals[0] - - assert.Equal(t, 100, sig.EpicNumber) - assert.Equal(t, 101, sig.ChildNumber) - assert.Equal(t, 200, sig.PRNumber) - assert.Equal(t, "acme", sig.RepoOwner) - assert.Equal(t, "widgets", sig.RepoName) - assert.Equal(t, "OPEN", sig.PRState) - assert.Equal(t, "MERGEABLE", sig.Mergeable) - assert.Equal(t, "SUCCESS", sig.CheckStatus) - assert.Equal(t, "deadbeef", sig.LastCommitSHA) - assert.False(t, sig.NeedsCoding) - assert.Equal(t, "acme/widgets", sig.RepoFullName()) -} - -func TestForgejoSource_Poll_Good_AllChildrenCheckedNoSignals(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, "body": "- [x] #2\n- [x] #3\n", - "labels": []map[string]string{{"name": "epic"}}, "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - _ = json.NewEncoder(w).Encode([]any{}) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - assert.Empty(t, signals, "all children checked means no work to do") -} - -func TestForgejoSource_Poll_Good_NeedsCodingSignalFields(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - case strings.Contains(path, "/issues/7"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 7, "title": "Implement authentication", - "body": "Add OAuth2 support.", "state": "open", - "assignees": []map[string]any{{"login": "agent-bot", "username": "agent-bot"}}, - }) - - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 1, "body": "- [ ] #7\n", - "labels": []map[string]string{{"name": "epic"}}, "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - case strings.Contains(path, "/pulls"): - _ = json.NewEncoder(w).Encode([]any{}) - - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"org/repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - sig := signals[0] - assert.True(t, sig.NeedsCoding) - assert.Equal(t, "agent-bot", sig.Assignee) - assert.Equal(t, "Implement authentication", sig.IssueTitle) - assert.Contains(t, sig.IssueBody, "OAuth2 support") - assert.Equal(t, 0, sig.PRNumber, "PRNumber should be zero for NeedsCoding signals") -} diff --git a/pkg/jobrunner/forgejo/source_test.go b/pkg/jobrunner/forgejo/source_test.go deleted file mode 100644 index ce06ce7..0000000 --- a/pkg/jobrunner/forgejo/source_test.go +++ /dev/null @@ -1,177 +0,0 @@ -package forgejo - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// withVersion wraps an HTTP handler to serve the Forgejo /api/v1/version -// endpoint that the SDK calls during NewClient initialization. -func withVersion(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.HasSuffix(r.URL.Path, "/version") { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"version":"9.0.0"}`)) - return - } - next.ServeHTTP(w, r) - }) -} - -func newTestClient(t *testing.T, url string) *forge.Client { - t.Helper() - client, err := forge.New(url, "test-token") - require.NoError(t, err) - return client -} - -func TestForgejoSource_Name(t *testing.T) { - s := New(Config{}, nil) - assert.Equal(t, "forgejo", s.Name()) -} - -func TestForgejoSource_Poll_Good(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - w.Header().Set("Content-Type", "application/json") - - switch { - // List issues — return one epic - case strings.Contains(path, "/issues"): - issues := []map[string]any{ - { - "number": 10, - "body": "## Tasks\n- [ ] #11\n- [x] #12\n", - "labels": []map[string]string{{"name": "epic"}}, - "state": "open", - }, - } - _ = json.NewEncoder(w).Encode(issues) - - // List PRs — return one open PR linked to #11 - case strings.Contains(path, "/pulls"): - prs := []map[string]any{ - { - "number": 20, - "body": "Fixes #11", - "state": "open", - "mergeable": true, - "merged": false, - "head": map[string]string{"sha": "abc123", "ref": "feature", "label": "feature"}, - }, - } - _ = json.NewEncoder(w).Encode(prs) - - // Combined status - case strings.Contains(path, "/status"): - status := map[string]any{ - "state": "success", - "total_count": 1, - "statuses": []map[string]any{{"status": "success", "context": "ci"}}, - } - _ = json.NewEncoder(w).Encode(status) - - default: - w.WriteHeader(http.StatusNotFound) - } - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"test-org/test-repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - - require.Len(t, signals, 1) - sig := signals[0] - assert.Equal(t, 10, sig.EpicNumber) - assert.Equal(t, 11, sig.ChildNumber) - assert.Equal(t, 20, sig.PRNumber) - assert.Equal(t, "OPEN", sig.PRState) - assert.Equal(t, "MERGEABLE", sig.Mergeable) - assert.Equal(t, "SUCCESS", sig.CheckStatus) - assert.Equal(t, "test-org", sig.RepoOwner) - assert.Equal(t, "test-repo", sig.RepoName) - assert.Equal(t, "abc123", sig.LastCommitSHA) -} - -func TestForgejoSource_Poll_NoEpics(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode([]any{}) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{Repos: []string{"test-org/test-repo"}}, client) - - signals, err := s.Poll(context.Background()) - require.NoError(t, err) - assert.Empty(t, signals) -} - -func TestForgejoSource_Report_Good(t *testing.T) { - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - capturedBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - }))) - defer srv.Close() - - client := newTestClient(t, srv.URL) - s := New(Config{}, client) - - result := &jobrunner.ActionResult{ - Action: "enable_auto_merge", - RepoOwner: "test-org", - RepoName: "test-repo", - EpicNumber: 10, - ChildNumber: 11, - PRNumber: 20, - Success: true, - } - - err := s.Report(context.Background(), result) - require.NoError(t, err) - assert.Contains(t, capturedBody, "enable_auto_merge") - assert.Contains(t, capturedBody, "succeeded") -} - -func TestParseEpicChildren(t *testing.T) { - body := "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n- [x] #3\n" - unchecked, checked := parseEpicChildren(body) - assert.Equal(t, []int{7, 8}, unchecked) - assert.Equal(t, []int{1, 3}, checked) -} - -func TestFindLinkedPR(t *testing.T) { - assert.Nil(t, findLinkedPR(nil, 7)) -} - -func TestSplitRepo(t *testing.T) { - owner, repo, err := splitRepo("host-uk/core") - require.NoError(t, err) - assert.Equal(t, "host-uk", owner) - assert.Equal(t, "core", repo) - - _, _, err = splitRepo("invalid") - assert.Error(t, err) - - _, _, err = splitRepo("") - assert.Error(t, err) -} diff --git a/pkg/jobrunner/handlers/completion.go b/pkg/jobrunner/handlers/completion.go deleted file mode 100644 index 0c500bd..0000000 --- a/pkg/jobrunner/handlers/completion.go +++ /dev/null @@ -1,88 +0,0 @@ -package handlers - -import ( - "context" - "fmt" - "time" - - coreerr "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -const ( - ColorAgentComplete = "#0e8a16" // Green -) - -// CompletionHandler manages issue state when an agent finishes work. -type CompletionHandler struct { - forge *forge.Client -} - -// NewCompletionHandler creates a handler for agent completion events. -func NewCompletionHandler(client *forge.Client) *CompletionHandler { - return &CompletionHandler{ - forge: client, - } -} - -// Name returns the handler identifier. -func (h *CompletionHandler) Name() string { - return "completion" -} - -// Match returns true if the signal indicates an agent has finished a task. -func (h *CompletionHandler) Match(signal *jobrunner.PipelineSignal) bool { - return signal.Type == "agent_completion" -} - -// Execute updates the issue labels based on the completion status. -func (h *CompletionHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { - start := time.Now() - - // Remove in-progress label. - if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil { - _ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID) - } - - if signal.Success { - completeLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentComplete, ColorAgentComplete) - if err != nil { - return nil, coreerr.E("completion.Execute", "ensure label "+LabelAgentComplete, err) - } - - if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{completeLabel.ID}); err != nil { - return nil, coreerr.E("completion.Execute", "add completed label", err) - } - - if signal.Message != "" { - _ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), signal.Message) - } - } else { - failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed) - if err != nil { - return nil, coreerr.E("completion.Execute", "ensure label "+LabelAgentFailed, err) - } - - if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID}); err != nil { - return nil, coreerr.E("completion.Execute", "add failed label", err) - } - - msg := "Agent reported failure." - if signal.Error != "" { - msg += fmt.Sprintf("\n\nError: %s", signal.Error) - } - _ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), msg) - } - - return &jobrunner.ActionResult{ - Action: "completion", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - EpicNumber: signal.EpicNumber, - ChildNumber: signal.ChildNumber, - Success: true, - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil -} diff --git a/pkg/jobrunner/handlers/completion_test.go b/pkg/jobrunner/handlers/completion_test.go deleted file mode 100644 index 5190215..0000000 --- a/pkg/jobrunner/handlers/completion_test.go +++ /dev/null @@ -1,291 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -func TestCompletion_Name_Good(t *testing.T) { - h := NewCompletionHandler(nil) - assert.Equal(t, "completion", h.Name()) -} - -func TestCompletion_Match_Good_AgentCompletion(t *testing.T) { - h := NewCompletionHandler(nil) - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - } - assert.True(t, h.Match(sig)) -} - -func TestCompletion_Match_Bad_WrongType(t *testing.T) { - h := NewCompletionHandler(nil) - sig := &jobrunner.PipelineSignal{ - Type: "pr_update", - } - assert.False(t, h.Match(sig)) -} - -func TestCompletion_Match_Bad_EmptyType(t *testing.T) { - h := NewCompletionHandler(nil) - sig := &jobrunner.PipelineSignal{} - assert.False(t, h.Match(sig)) -} - -func TestCompletion_Execute_Good_Success(t *testing.T) { - var labelRemoved bool - var labelAdded bool - var commentPosted bool - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - // GetLabelByName (in-progress) — GET labels to find in-progress. - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - }) - - // RemoveIssueLabel (in-progress). - case r.Method == http.MethodDelete && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/labels/1": - labelRemoved = true - w.WriteHeader(http.StatusNoContent) - - // EnsureLabel (agent-completed) — POST to create. - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"}) - - // AddIssueLabels. - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/labels": - labelAdded = true - _ = json.NewEncoder(w).Encode([]map[string]any{{"id": 2, "name": "agent-completed"}}) - - // CreateIssueComment. - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/comments": - commentPosted = true - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - commentBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]}) - - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "test-org", - RepoName: "test-repo", - ChildNumber: 5, - EpicNumber: 3, - Success: true, - Message: "Task completed successfully", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "completion", result.Action) - assert.Equal(t, "test-org", result.RepoOwner) - assert.Equal(t, "test-repo", result.RepoName) - assert.Equal(t, 3, result.EpicNumber) - assert.Equal(t, 5, result.ChildNumber) - assert.True(t, labelRemoved, "in-progress label should be removed") - assert.True(t, labelAdded, "agent-completed label should be added") - assert.True(t, commentPosted, "comment should be posted") - assert.Contains(t, commentBody, "Task completed successfully") -} - -func TestCompletion_Execute_Good_Failure(t *testing.T) { - var labelAdded bool - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{}) - - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"}) - - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/labels": - labelAdded = true - _ = json.NewEncoder(w).Encode([]map[string]any{{"id": 3, "name": "agent-failed"}}) - - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/test-org/test-repo/issues/5/comments": - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - commentBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]}) - - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "test-org", - RepoName: "test-repo", - ChildNumber: 5, - EpicNumber: 3, - Success: false, - Error: "tests failed", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.True(t, result.Success) // The handler itself succeeded - assert.Equal(t, "completion", result.Action) - assert.True(t, labelAdded, "agent-failed label should be added") - assert.Contains(t, commentBody, "Agent reported failure") - assert.Contains(t, commentBody, "tests failed") -} - -func TestCompletion_Execute_Good_FailureNoError(t *testing.T) { - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/comments": - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - commentBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 1, - Success: false, - Error: "", // No error message. - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) - assert.Contains(t, commentBody, "Agent reported failure") - assert.NotContains(t, commentBody, "Error:") // No error detail. -} - -func TestCompletion_Execute_Good_SuccessNoMessage(t *testing.T) { - var commentPosted bool - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/issues/1/comments": - commentPosted = true - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 1, - Success: true, - Message: "", // No message. - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) - assert.False(t, commentPosted, "no comment should be posted when message is empty") -} - -func TestCompletion_Execute_Bad_EnsureLabelFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - // Return empty so EnsureLabel tries to create. - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - // Label creation fails. - w.WriteHeader(http.StatusInternalServerError) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 1, - Success: true, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "ensure label") -} diff --git a/pkg/jobrunner/handlers/coverage_boost_test.go b/pkg/jobrunner/handlers/coverage_boost_test.go deleted file mode 100644 index b8bd34f..0000000 --- a/pkg/jobrunner/handlers/coverage_boost_test.go +++ /dev/null @@ -1,704 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - agentci "forge.lthn.ai/core/agent/pkg/orchestrator" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// --- Dispatch: Execute with invalid repo name --- - -func TestDispatch_Execute_Bad_InvalidRepoNameSpecialChars(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "valid-org", - RepoName: "repo$bad!", - ChildNumber: 1, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid repo name") -} - -// --- Dispatch: Execute when EnsureLabel fails --- - -func TestDispatch_Execute_Bad_EnsureLabelCreationFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusInternalServerError) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 1, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "ensure label") -} - -// dispatchMockServer creates a standard mock server for dispatch tests. -// It handles all the Forgejo API calls needed for a full dispatch flow. -func dispatchMockServer(t *testing.T) *httptest.Server { - t.Helper() - return httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - // GetLabelByName / list labels - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - {"id": 2, "name": "agent-ready", "color": "#00ff00"}, - }) - - // CreateLabel (shouldn't normally be needed since we return it above) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress", "color": "#1d76db"}) - - // GetIssue (returns issue with no label to trigger the full dispatch flow) - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - w.WriteHeader(http.StatusNotFound) // Issue not found => full dispatch flow - - // AssignIssue - case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - _ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5}) - - // AddIssueLabels - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{{"id": 1, "name": "in-progress"}}) - - // RemoveIssueLabel - case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"): - w.WriteHeader(http.StatusNoContent) - - // CreateIssueComment - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/comments"): - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": "dispatched"}) - - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) -} - -// --- Dispatch: Execute when GetIssue returns 404 (full dispatch path) --- - -func TestDispatch_Execute_Good_GetIssueNotFound(t *testing.T) { - srv := dispatchMockServer(t) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 5, - EpicNumber: 3, - IssueTitle: "Test issue", - IssueBody: "Test body", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.Equal(t, "dispatch", result.Action) -} - -// --- Completion: Execute when AddIssueLabels fails for success case --- - -func TestCompletion_Execute_Bad_AddCompleteLabelFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/repo/labels"): - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"}) - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"): - w.WriteHeader(http.StatusInternalServerError) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 5, - Success: true, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "add completed label") -} - -// --- Completion: Execute when AddIssueLabels fails for failure case --- - -func TestCompletion_Execute_Bad_AddFailLabelFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/repo/labels"): - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"}) - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"): - w.WriteHeader(http.StatusInternalServerError) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 5, - Success: false, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "add failed label") -} - -// --- Completion: Execute with EnsureLabel failure on failure path --- - -func TestCompletion_Execute_Bad_FailedPathEnsureLabelFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{}) - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/labels"): - w.WriteHeader(http.StatusInternalServerError) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - sig := &jobrunner.PipelineSignal{ - Type: "agent_completion", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 1, - Success: false, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "ensure label") -} - -// --- EnableAutoMerge: additional edge case --- - -func TestEnableAutoMerge_Match_Bad_PendingChecks(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - IsDraft: false, - Mergeable: "MERGEABLE", - CheckStatus: "PENDING", - } - assert.False(t, h.Match(sig)) -} - -func TestEnableAutoMerge_Execute_Bad_InternalServerError(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewEnableAutoMergeHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.False(t, result.Success) - assert.Contains(t, result.Error, "merge failed") -} - -// --- PublishDraft: Match with MERGED state --- - -func TestPublishDraft_Match_Bad_MergedState(t *testing.T) { - h := NewPublishDraftHandler(nil) - sig := &jobrunner.PipelineSignal{ - IsDraft: true, - PRState: "MERGED", - CheckStatus: "SUCCESS", - } - assert.False(t, h.Match(sig)) -} - -// --- SendFixCommand: Execute merge conflict message --- - -func TestSendFixCommand_Execute_Good_MergeConflictMessage(t *testing.T) { - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method == http.MethodPost { - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - capturedBody = body["body"] - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewSendFixCommandHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - Mergeable: "CONFLICTING", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) - assert.Contains(t, capturedBody, "fix the merge conflict") -} - -// --- DismissReviews: Execute with stale review that gets dismissed --- - -func TestDismissReviews_Execute_Good_StaleReviewDismissed(t *testing.T) { - var dismissCalled bool - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/reviews") { - reviews := []map[string]any{ - { - "id": 1, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true, - "body": "fix it", "commit_id": "abc123", - }, - } - _ = json.NewEncoder(w).Encode(reviews) - return - } - - if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/dismissals") { - dismissCalled = true - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "state": "DISMISSED"}) - return - } - - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewDismissReviewsHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - PRState: "OPEN", - ThreadsTotal: 1, - ThreadsResolved: 0, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) - assert.True(t, dismissCalled) -} - -// --- TickParent: Execute ticks and closes --- - -func TestTickParent_Execute_Good_TicksCheckboxAndCloses(t *testing.T) { - epicBody := "## Tasks\n- [ ] #7\n- [ ] #8\n" - var editedBody string - var closedIssue bool - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": epicBody, - "title": "Epic", - }) - case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/42"): - var body map[string]any - _ = json.NewDecoder(r.Body).Decode(&body) - if b, ok := body["body"].(string); ok { - editedBody = b - } - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": editedBody, - "title": "Epic", - }) - case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/7"): - closedIssue = true - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 7, - "state": "closed", - }) - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewTickParentHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - PRState: "MERGED", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) - assert.Contains(t, editedBody, "- [x] #7") - assert.True(t, closedIssue) -} - -// --- Dispatch: DualRun mode --- - -func TestDispatch_Execute_Good_DualRunModeDispatch(t *testing.T) { - srv := dispatchMockServer(t) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - spinner := agentci.NewSpinner( - agentci.ClothoConfig{Strategy: "clotho-verified"}, - map[string]agentci.AgentConfig{ - "darbs-claude": { - Host: "localhost", - QueueDir: "/tmp/nonexistent-queue", - Active: true, - Model: "sonnet", - DualRun: true, - }, - }, - ) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 5, - EpicNumber: 3, - IssueTitle: "Test issue", - IssueBody: "Test body", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.Equal(t, "dispatch", result.Action) -} - -// --- TickParent: ChildNumber not found in epic body --- - -func TestTickParent_Execute_Good_ChildNotInBody(t *testing.T) { - epicBody := "## Tasks\n- [ ] #99\n" - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42") { - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": epicBody, - "title": "Epic", - }) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewTickParentHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 42, - ChildNumber: 50, - PRNumber: 100, - PRState: "MERGED", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) -} - -// --- Dispatch: AssignIssue fails (warn, continue) --- - -func TestDispatch_Execute_Good_AssignIssueFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - {"id": 2, "name": "agent-ready", "color": "#00ff00"}, - }) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"}) - // GetIssue returns issue with NO special labels - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - _ = json.NewEncoder(w).Encode(map[string]any{ - "id": 5, "number": 5, "title": "Test Issue", - "labels": []map[string]any{}, - }) - // AssignIssue FAILS - case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(`{"message":"assign failed"}`)) - // AddIssueLabels succeeds - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{{"id": 1, "name": "in-progress"}}) - case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"): - w.WriteHeader(http.StatusNoContent) - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/comments"): - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": "dispatched"}) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 1, - ChildNumber: 5, - PRNumber: 10, - RepoOwner: "org", - RepoName: "repo", - Assignee: "darbs-claude", - IssueTitle: "Test Issue", - IssueBody: "Test body", - } - - // Should not return error because AssignIssue failure is only a warning. - result, err := h.Execute(context.Background(), signal) - // secureTransfer will fail because SSH isn't available, but we exercised the assign-error path. - _ = result - _ = err -} - -// --- Dispatch: AddIssueLabels fails --- - -func TestDispatch_Execute_Bad_AddIssueLabelsError(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - }) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"}) - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - _ = json.NewEncoder(w).Encode(map[string]any{ - "id": 5, "number": 5, "title": "Test Issue", - "labels": []map[string]any{}, - }) - case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - _ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5}) - // AddIssueLabels FAILS - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"): - w.WriteHeader(http.StatusInternalServerError) - _, _ = w.Write([]byte(`{"message":"label add failed"}`)) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 1, - ChildNumber: 5, - PRNumber: 10, - RepoOwner: "org", - RepoName: "repo", - Assignee: "darbs-claude", - IssueTitle: "Test Issue", - IssueBody: "Test body", - } - - _, err := h.Execute(context.Background(), signal) - assert.Error(t, err) - assert.Contains(t, err.Error(), "add in-progress label") -} - -// --- Dispatch: GetIssue returns issue with existing labels not matching --- - -func TestDispatch_Execute_Good_IssueFoundNoSpecialLabels(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - {"id": 2, "name": "agent-ready", "color": "#00ff00"}, - }) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"}) - // GetIssue returns issue with unrelated labels - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - _ = json.NewEncoder(w).Encode(map[string]any{ - "id": 5, "number": 5, "title": "Test Issue", - "labels": []map[string]any{ - {"id": 10, "name": "enhancement"}, - }, - }) - case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - _ = json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5}) - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{{"id": 1, "name": "in-progress"}}) - case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"): - w.WriteHeader(http.StatusNoContent) - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/issues/5/comments"): - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": "dispatched"}) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 1, - ChildNumber: 5, - PRNumber: 10, - RepoOwner: "org", - RepoName: "repo", - Assignee: "darbs-claude", - IssueTitle: "Test Issue", - IssueBody: "Test body", - } - - // Execute will proceed past label check and try SSH (which fails). - result, err := h.Execute(context.Background(), signal) - // Should either succeed (if somehow SSH works) or fail at secureTransfer. - _ = result - _ = err -} diff --git a/pkg/jobrunner/handlers/dispatch.go b/pkg/jobrunner/handlers/dispatch.go deleted file mode 100644 index 46ad44f..0000000 --- a/pkg/jobrunner/handlers/dispatch.go +++ /dev/null @@ -1,290 +0,0 @@ -package handlers - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "path/filepath" - "time" - - agentci "forge.lthn.ai/core/agent/pkg/orchestrator" - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" - coreerr "forge.lthn.ai/core/go-log" -) - -const ( - LabelAgentReady = "agent-ready" - LabelInProgress = "in-progress" - LabelAgentFailed = "agent-failed" - LabelAgentComplete = "agent-completed" - - ColorInProgress = "#1d76db" // Blue - ColorAgentFailed = "#c0392b" // Red -) - -// DispatchTicket is the JSON payload written to the agent's queue. -// The ForgeToken is transferred separately via a .env file with 0600 permissions. -type DispatchTicket struct { - ID string `json:"id"` - RepoOwner string `json:"repo_owner"` - RepoName string `json:"repo_name"` - IssueNumber int `json:"issue_number"` - IssueTitle string `json:"issue_title"` - IssueBody string `json:"issue_body"` - TargetBranch string `json:"target_branch"` - EpicNumber int `json:"epic_number"` - ForgeURL string `json:"forge_url"` - ForgeUser string `json:"forgejo_user"` - Model string `json:"model,omitempty"` - Runner string `json:"runner,omitempty"` - VerifyModel string `json:"verify_model,omitempty"` - DualRun bool `json:"dual_run"` - CreatedAt string `json:"created_at"` -} - -// DispatchHandler dispatches coding work to remote agent machines via SSH. -type DispatchHandler struct { - forge *forge.Client - forgeURL string - token string - spinner *agentci.Spinner -} - -// NewDispatchHandler creates a handler that dispatches tickets to agent machines. -func NewDispatchHandler(client *forge.Client, forgeURL, token string, spinner *agentci.Spinner) *DispatchHandler { - return &DispatchHandler{ - forge: client, - forgeURL: forgeURL, - token: token, - spinner: spinner, - } -} - -// Name returns the handler identifier. -func (h *DispatchHandler) Name() string { - return "dispatch" -} - -// Match returns true for signals where a child issue needs coding (no PR yet) -// and the assignee is a known agent (by config key or Forgejo username). -func (h *DispatchHandler) Match(signal *jobrunner.PipelineSignal) bool { - if !signal.NeedsCoding { - return false - } - _, _, ok := h.spinner.FindByForgejoUser(signal.Assignee) - return ok -} - -// Execute creates a ticket JSON and transfers it securely to the agent's queue directory. -func (h *DispatchHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { - start := time.Now() - - agentName, agent, ok := h.spinner.FindByForgejoUser(signal.Assignee) - if !ok { - return nil, coreerr.E("handlers.Dispatch.Execute", "unknown agent: "+signal.Assignee, nil) - } - - // Sanitize inputs to prevent path traversal. - safeOwner, err := agentci.SanitizePath(signal.RepoOwner) - if err != nil { - return nil, coreerr.E("handlers.Dispatch.Execute", "invalid repo owner", err) - } - safeRepo, err := agentci.SanitizePath(signal.RepoName) - if err != nil { - return nil, coreerr.E("handlers.Dispatch.Execute", "invalid repo name", err) - } - - // Ensure in-progress label exists on repo. - inProgressLabel, err := h.forge.EnsureLabel(safeOwner, safeRepo, LabelInProgress, ColorInProgress) - if err != nil { - return nil, coreerr.E("handlers.Dispatch.Execute", "ensure label "+LabelInProgress, err) - } - - // Check if already in progress to prevent double-dispatch. - issue, err := h.forge.GetIssue(safeOwner, safeRepo, int64(signal.ChildNumber)) - if err == nil { - for _, l := range issue.Labels { - if l.Name == LabelInProgress || l.Name == LabelAgentComplete { - coreerr.Info("issue already processed, skipping", "issue", signal.ChildNumber) - return &jobrunner.ActionResult{ - Action: "dispatch", - Success: true, - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil - } - } - } - - // Assign agent and add in-progress label. - if err := h.forge.AssignIssue(safeOwner, safeRepo, int64(signal.ChildNumber), []string{signal.Assignee}); err != nil { - coreerr.Warn("failed to assign agent, continuing", "err", err) - } - - if err := h.forge.AddIssueLabels(safeOwner, safeRepo, int64(signal.ChildNumber), []int64{inProgressLabel.ID}); err != nil { - return nil, coreerr.E("handlers.Dispatch.Execute", "add in-progress label", err) - } - - // Remove agent-ready label if present. - if readyLabel, err := h.forge.GetLabelByName(safeOwner, safeRepo, LabelAgentReady); err == nil { - _ = h.forge.RemoveIssueLabel(safeOwner, safeRepo, int64(signal.ChildNumber), readyLabel.ID) - } - - // Clotho planning — determine execution mode. - runMode := h.spinner.DeterminePlan(signal, agentName) - verifyModel := "" - if runMode == agentci.ModeDual { - verifyModel = h.spinner.GetVerifierModel(agentName) - } - - // Build ticket. - targetBranch := "new" // TODO: resolve from epic or repo default - ticketID := fmt.Sprintf("%s-%s-%d-%d", safeOwner, safeRepo, signal.ChildNumber, time.Now().Unix()) - - ticket := DispatchTicket{ - ID: ticketID, - RepoOwner: safeOwner, - RepoName: safeRepo, - IssueNumber: signal.ChildNumber, - IssueTitle: signal.IssueTitle, - IssueBody: signal.IssueBody, - TargetBranch: targetBranch, - EpicNumber: signal.EpicNumber, - ForgeURL: h.forgeURL, - ForgeUser: signal.Assignee, - Model: agent.Model, - Runner: agent.Runner, - VerifyModel: verifyModel, - DualRun: runMode == agentci.ModeDual, - CreatedAt: time.Now().UTC().Format(time.RFC3339), - } - - ticketJSON, err := json.MarshalIndent(ticket, "", " ") - if err != nil { - h.failDispatch(signal, "Failed to marshal ticket JSON") - return nil, coreerr.E("handlers.Dispatch.Execute", "marshal ticket", err) - } - - // Check if ticket already exists on agent (dedup). - ticketName := fmt.Sprintf("ticket-%s-%s-%d.json", safeOwner, safeRepo, signal.ChildNumber) - if h.ticketExists(ctx, agent, ticketName) { - coreerr.Info("ticket already queued, skipping", "ticket", ticketName, "agent", signal.Assignee) - return &jobrunner.ActionResult{ - Action: "dispatch", - RepoOwner: safeOwner, - RepoName: safeRepo, - EpicNumber: signal.EpicNumber, - ChildNumber: signal.ChildNumber, - Success: true, - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil - } - - // Transfer ticket JSON. - remoteTicketPath := filepath.Join(agent.QueueDir, ticketName) - if err := h.secureTransfer(ctx, agent, remoteTicketPath, ticketJSON, 0644); err != nil { - h.failDispatch(signal, fmt.Sprintf("Ticket transfer failed: %v", err)) - return &jobrunner.ActionResult{ - Action: "dispatch", - RepoOwner: safeOwner, - RepoName: safeRepo, - EpicNumber: signal.EpicNumber, - ChildNumber: signal.ChildNumber, - Success: false, - Error: fmt.Sprintf("transfer ticket: %v", err), - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil - } - - // Transfer token via separate .env file with 0600 permissions. - envContent := fmt.Sprintf("FORGE_TOKEN=%s\n", h.token) - remoteEnvPath := filepath.Join(agent.QueueDir, fmt.Sprintf(".env.%s", ticketID)) - if err := h.secureTransfer(ctx, agent, remoteEnvPath, []byte(envContent), 0600); err != nil { - // Clean up the ticket if env transfer fails. - _ = h.runRemote(ctx, agent, fmt.Sprintf("rm -f %s", agentci.EscapeShellArg(remoteTicketPath))) - h.failDispatch(signal, fmt.Sprintf("Token transfer failed: %v", err)) - return &jobrunner.ActionResult{ - Action: "dispatch", - RepoOwner: safeOwner, - RepoName: safeRepo, - EpicNumber: signal.EpicNumber, - ChildNumber: signal.ChildNumber, - Success: false, - Error: fmt.Sprintf("transfer token: %v", err), - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil - } - - // Comment on issue. - modeStr := "Standard" - if runMode == agentci.ModeDual { - modeStr = "Clotho Verified (Dual Run)" - } - comment := fmt.Sprintf("Dispatched to **%s** agent queue.\nMode: **%s**", signal.Assignee, modeStr) - _ = h.forge.CreateIssueComment(safeOwner, safeRepo, int64(signal.ChildNumber), comment) - - return &jobrunner.ActionResult{ - Action: "dispatch", - RepoOwner: safeOwner, - RepoName: safeRepo, - EpicNumber: signal.EpicNumber, - ChildNumber: signal.ChildNumber, - Success: true, - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil -} - -// failDispatch handles cleanup when dispatch fails (adds failed label, removes in-progress). -func (h *DispatchHandler) failDispatch(signal *jobrunner.PipelineSignal, reason string) { - if failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed); err == nil { - _ = h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID}) - } - - if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil { - _ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID) - } - - _ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), fmt.Sprintf("Agent dispatch failed: %s", reason)) -} - -// secureTransfer writes data to a remote path via SSH stdin, preventing command injection. -func (h *DispatchHandler) secureTransfer(ctx context.Context, agent agentci.AgentConfig, remotePath string, data []byte, mode int) error { - safeRemotePath := agentci.EscapeShellArg(remotePath) - remoteCmd := fmt.Sprintf("cat > %s && chmod %o %s", safeRemotePath, mode, safeRemotePath) - - cmd := agentci.SecureSSHCommandContext(ctx, agent.Host, remoteCmd) - cmd.Stdin = bytes.NewReader(data) - - output, err := cmd.CombinedOutput() - if err != nil { - return coreerr.E("dispatch.transfer", fmt.Sprintf("ssh to %s failed: %s", agent.Host, string(output)), err) - } - return nil -} - -// runRemote executes a command on the agent via SSH. -func (h *DispatchHandler) runRemote(ctx context.Context, agent agentci.AgentConfig, cmdStr string) error { - cmd := agentci.SecureSSHCommandContext(ctx, agent.Host, cmdStr) - return cmd.Run() -} - -// ticketExists checks if a ticket file already exists in queue, active, or done. -func (h *DispatchHandler) ticketExists(ctx context.Context, agent agentci.AgentConfig, ticketName string) bool { - safeTicket, err := agentci.SanitizePath(ticketName) - if err != nil { - return false - } - qDir := agent.QueueDir - checkCmd := fmt.Sprintf( - "test -f %s/%s || test -f %s/../active/%s || test -f %s/../done/%s", - qDir, safeTicket, qDir, safeTicket, qDir, safeTicket, - ) - cmd := agentci.SecureSSHCommandContext(ctx, agent.Host, checkCmd) - return cmd.Run() == nil -} diff --git a/pkg/jobrunner/handlers/dispatch_test.go b/pkg/jobrunner/handlers/dispatch_test.go deleted file mode 100644 index a742df6..0000000 --- a/pkg/jobrunner/handlers/dispatch_test.go +++ /dev/null @@ -1,327 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - agentci "forge.lthn.ai/core/agent/pkg/orchestrator" - "forge.lthn.ai/core/agent/pkg/jobrunner" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// newTestSpinner creates a Spinner with the given agents for testing. -func newTestSpinner(agents map[string]agentci.AgentConfig) *agentci.Spinner { - return agentci.NewSpinner(agentci.ClothoConfig{Strategy: "direct"}, agents) -} - -// --- Match tests --- - -func TestDispatch_Match_Good_NeedsCoding(t *testing.T) { - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, - }) - h := NewDispatchHandler(nil, "", "", spinner) - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - } - assert.True(t, h.Match(sig)) -} - -func TestDispatch_Match_Good_MultipleAgents(t *testing.T) { - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, - "local-codex": {Host: "localhost", QueueDir: "~/ai-work/queue", Active: true}, - }) - h := NewDispatchHandler(nil, "", "", spinner) - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "local-codex", - } - assert.True(t, h.Match(sig)) -} - -func TestDispatch_Match_Bad_HasPR(t *testing.T) { - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, - }) - h := NewDispatchHandler(nil, "", "", spinner) - sig := &jobrunner.PipelineSignal{ - NeedsCoding: false, - PRNumber: 7, - Assignee: "darbs-claude", - } - assert.False(t, h.Match(sig)) -} - -func TestDispatch_Match_Bad_UnknownAgent(t *testing.T) { - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, - }) - h := NewDispatchHandler(nil, "", "", spinner) - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "unknown-user", - } - assert.False(t, h.Match(sig)) -} - -func TestDispatch_Match_Bad_NotAssigned(t *testing.T) { - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, - }) - h := NewDispatchHandler(nil, "", "", spinner) - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "", - } - assert.False(t, h.Match(sig)) -} - -func TestDispatch_Match_Bad_EmptyAgentMap(t *testing.T) { - spinner := newTestSpinner(map[string]agentci.AgentConfig{}) - h := NewDispatchHandler(nil, "", "", spinner) - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - } - assert.False(t, h.Match(sig)) -} - -// --- Name test --- - -func TestDispatch_Name_Good(t *testing.T) { - spinner := newTestSpinner(nil) - h := NewDispatchHandler(nil, "", "", spinner) - assert.Equal(t, "dispatch", h.Name()) -} - -// --- Execute tests --- - -func TestDispatch_Execute_Bad_UnknownAgent(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "claude@192.168.0.201", QueueDir: "~/ai-work/queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "nonexistent-agent", - RepoOwner: "host-uk", - RepoName: "core", - ChildNumber: 1, - } - - _, err := h.Execute(context.Background(), sig) - require.Error(t, err) - assert.Contains(t, err.Error(), "unknown agent") -} - -func TestDispatch_TicketJSON_Good(t *testing.T) { - ticket := DispatchTicket{ - ID: "host-uk-core-5-1234567890", - RepoOwner: "host-uk", - RepoName: "core", - IssueNumber: 5, - IssueTitle: "Fix the thing", - IssueBody: "Please fix this bug", - TargetBranch: "new", - EpicNumber: 3, - ForgeURL: "https://forge.lthn.ai", - ForgeUser: "darbs-claude", - Model: "sonnet", - Runner: "claude", - DualRun: false, - CreatedAt: "2026-02-09T12:00:00Z", - } - - data, err := json.MarshalIndent(ticket, "", " ") - require.NoError(t, err) - - var decoded map[string]any - err = json.Unmarshal(data, &decoded) - require.NoError(t, err) - - assert.Equal(t, "host-uk-core-5-1234567890", decoded["id"]) - assert.Equal(t, "host-uk", decoded["repo_owner"]) - assert.Equal(t, "core", decoded["repo_name"]) - assert.Equal(t, float64(5), decoded["issue_number"]) - assert.Equal(t, "Fix the thing", decoded["issue_title"]) - assert.Equal(t, "Please fix this bug", decoded["issue_body"]) - assert.Equal(t, "new", decoded["target_branch"]) - assert.Equal(t, float64(3), decoded["epic_number"]) - assert.Equal(t, "https://forge.lthn.ai", decoded["forge_url"]) - assert.Equal(t, "darbs-claude", decoded["forgejo_user"]) - assert.Equal(t, "sonnet", decoded["model"]) - assert.Equal(t, "claude", decoded["runner"]) - // Token should NOT be present in the ticket. - _, hasToken := decoded["forge_token"] - assert.False(t, hasToken, "forge_token must not be in ticket JSON") -} - -func TestDispatch_TicketJSON_Good_DualRun(t *testing.T) { - ticket := DispatchTicket{ - ID: "test-dual", - RepoOwner: "host-uk", - RepoName: "core", - IssueNumber: 1, - ForgeURL: "https://forge.lthn.ai", - Model: "gemini-2.0-flash", - VerifyModel: "gemini-1.5-pro", - DualRun: true, - } - - data, err := json.Marshal(ticket) - require.NoError(t, err) - - var roundtrip DispatchTicket - err = json.Unmarshal(data, &roundtrip) - require.NoError(t, err) - assert.True(t, roundtrip.DualRun) - assert.Equal(t, "gemini-1.5-pro", roundtrip.VerifyModel) -} - -func TestDispatch_TicketJSON_Good_OmitsEmptyModelRunner(t *testing.T) { - ticket := DispatchTicket{ - ID: "test-1", - RepoOwner: "host-uk", - RepoName: "core", - IssueNumber: 1, - TargetBranch: "new", - ForgeURL: "https://forge.lthn.ai", - } - - data, err := json.MarshalIndent(ticket, "", " ") - require.NoError(t, err) - - var decoded map[string]any - err = json.Unmarshal(data, &decoded) - require.NoError(t, err) - - _, hasModel := decoded["model"] - _, hasRunner := decoded["runner"] - assert.False(t, hasModel, "model should be omitted when empty") - assert.False(t, hasRunner, "runner should be omitted when empty") -} - -func TestDispatch_TicketJSON_Good_ModelRunnerVariants(t *testing.T) { - tests := []struct { - name string - model string - runner string - }{ - {"claude-sonnet", "sonnet", "claude"}, - {"claude-opus", "opus", "claude"}, - {"codex-default", "", "codex"}, - {"gemini-default", "", "gemini"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ticket := DispatchTicket{ - ID: "test-" + tt.name, - RepoOwner: "host-uk", - RepoName: "core", - IssueNumber: 1, - TargetBranch: "new", - ForgeURL: "https://forge.lthn.ai", - Model: tt.model, - Runner: tt.runner, - } - - data, err := json.Marshal(ticket) - require.NoError(t, err) - - var roundtrip DispatchTicket - err = json.Unmarshal(data, &roundtrip) - require.NoError(t, err) - assert.Equal(t, tt.model, roundtrip.Model) - assert.Equal(t, tt.runner, roundtrip.Runner) - }) - } -} - -func TestDispatch_Execute_Good_PostsComment(t *testing.T) { - var commentPosted bool - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/host-uk/core/labels": - json.NewEncoder(w).Encode([]any{}) - return - - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/labels": - json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress", "color": "#1d76db"}) - return - - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5": - json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5, "labels": []any{}, "title": "Test"}) - return - - case r.Method == http.MethodPatch && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5": - json.NewEncoder(w).Encode(map[string]any{"id": 5, "number": 5}) - return - - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5/labels": - json.NewEncoder(w).Encode([]any{map[string]any{"id": 1, "name": "in-progress"}}) - return - - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/host-uk/core/issues/5/comments": - commentPosted = true - var body map[string]string - _ = json.NewDecoder(r.Body).Decode(&body) - commentBody = body["body"] - json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]}) - return - } - - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]any{}) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/nonexistent-queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "host-uk", - RepoName: "core", - ChildNumber: 5, - EpicNumber: 3, - IssueTitle: "Test issue", - IssueBody: "Test body", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.Equal(t, "dispatch", result.Action) - assert.Equal(t, "host-uk", result.RepoOwner) - assert.Equal(t, "core", result.RepoName) - assert.Equal(t, 3, result.EpicNumber) - assert.Equal(t, 5, result.ChildNumber) - - if result.Success { - assert.True(t, commentPosted) - assert.Contains(t, commentBody, "darbs-claude") - } -} diff --git a/pkg/jobrunner/handlers/enable_auto_merge.go b/pkg/jobrunner/handlers/enable_auto_merge.go deleted file mode 100644 index 4c05894..0000000 --- a/pkg/jobrunner/handlers/enable_auto_merge.go +++ /dev/null @@ -1,58 +0,0 @@ -package handlers - -import ( - "context" - "fmt" - "time" - - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// EnableAutoMergeHandler merges a PR that is ready using squash strategy. -type EnableAutoMergeHandler struct { - forge *forge.Client -} - -// NewEnableAutoMergeHandler creates a handler that merges ready PRs. -func NewEnableAutoMergeHandler(f *forge.Client) *EnableAutoMergeHandler { - return &EnableAutoMergeHandler{forge: f} -} - -// Name returns the handler identifier. -func (h *EnableAutoMergeHandler) Name() string { - return "enable_auto_merge" -} - -// Match returns true when the PR is open, not a draft, mergeable, checks -// are passing, and there are no unresolved review threads. -func (h *EnableAutoMergeHandler) Match(signal *jobrunner.PipelineSignal) bool { - return signal.PRState == "OPEN" && - !signal.IsDraft && - signal.Mergeable == "MERGEABLE" && - signal.CheckStatus == "SUCCESS" && - !signal.HasUnresolvedThreads() -} - -// Execute merges the pull request with squash strategy. -func (h *EnableAutoMergeHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { - start := time.Now() - - err := h.forge.MergePullRequest(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber), "squash") - - result := &jobrunner.ActionResult{ - Action: "enable_auto_merge", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - PRNumber: signal.PRNumber, - Success: err == nil, - Timestamp: time.Now(), - Duration: time.Since(start), - } - - if err != nil { - result.Error = fmt.Sprintf("merge failed: %v", err) - } - - return result, nil -} diff --git a/pkg/jobrunner/handlers/enable_auto_merge_test.go b/pkg/jobrunner/handlers/enable_auto_merge_test.go deleted file mode 100644 index 55a9e39..0000000 --- a/pkg/jobrunner/handlers/enable_auto_merge_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -func TestEnableAutoMerge_Match_Good(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - IsDraft: false, - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - ThreadsTotal: 0, - ThreadsResolved: 0, - } - assert.True(t, h.Match(sig)) -} - -func TestEnableAutoMerge_Match_Bad_Draft(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - IsDraft: true, - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - ThreadsTotal: 0, - ThreadsResolved: 0, - } - assert.False(t, h.Match(sig)) -} - -func TestEnableAutoMerge_Match_Bad_UnresolvedThreads(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - IsDraft: false, - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - ThreadsTotal: 5, - ThreadsResolved: 3, - } - assert.False(t, h.Match(sig)) -} - -func TestEnableAutoMerge_Execute_Good(t *testing.T) { - var capturedPath string - var capturedMethod string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedMethod = r.Method - capturedPath = r.URL.Path - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - h := NewEnableAutoMergeHandler(client) - sig := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-php", - PRNumber: 55, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "enable_auto_merge", result.Action) - assert.Equal(t, http.MethodPost, capturedMethod) - assert.Equal(t, "/api/v1/repos/host-uk/core-php/pulls/55/merge", capturedPath) -} - -func TestEnableAutoMerge_Execute_Bad_MergeFailed(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusConflict) - _ = json.NewEncoder(w).Encode(map[string]string{"message": "merge conflict"}) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - h := NewEnableAutoMergeHandler(client) - sig := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-php", - PRNumber: 55, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.False(t, result.Success) - assert.Contains(t, result.Error, "merge failed") -} diff --git a/pkg/jobrunner/handlers/handlers_extra_test.go b/pkg/jobrunner/handlers/handlers_extra_test.go deleted file mode 100644 index fba7c94..0000000 --- a/pkg/jobrunner/handlers/handlers_extra_test.go +++ /dev/null @@ -1,583 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - agentci "forge.lthn.ai/core/agent/pkg/orchestrator" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// --- Name tests for all handlers --- - -func TestEnableAutoMerge_Name_Good(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - assert.Equal(t, "enable_auto_merge", h.Name()) -} - -func TestPublishDraft_Name_Good(t *testing.T) { - h := NewPublishDraftHandler(nil) - assert.Equal(t, "publish_draft", h.Name()) -} - -func TestDismissReviews_Name_Good(t *testing.T) { - h := NewDismissReviewsHandler(nil) - assert.Equal(t, "dismiss_reviews", h.Name()) -} - -func TestSendFixCommand_Name_Good(t *testing.T) { - h := NewSendFixCommandHandler(nil) - assert.Equal(t, "send_fix_command", h.Name()) -} - -func TestTickParent_Name_Good(t *testing.T) { - h := NewTickParentHandler(nil) - assert.Equal(t, "tick_parent", h.Name()) -} - -// --- Additional Match tests --- - -func TestEnableAutoMerge_Match_Bad_Closed(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "CLOSED", - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - } - assert.False(t, h.Match(sig)) -} - -func TestEnableAutoMerge_Match_Bad_ChecksFailing(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - Mergeable: "MERGEABLE", - CheckStatus: "FAILURE", - } - assert.False(t, h.Match(sig)) -} - -func TestEnableAutoMerge_Match_Bad_Conflicting(t *testing.T) { - h := NewEnableAutoMergeHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - Mergeable: "CONFLICTING", - CheckStatus: "SUCCESS", - } - assert.False(t, h.Match(sig)) -} - -func TestPublishDraft_Match_Bad_Closed(t *testing.T) { - h := NewPublishDraftHandler(nil) - sig := &jobrunner.PipelineSignal{ - IsDraft: true, - PRState: "CLOSED", - CheckStatus: "SUCCESS", - } - assert.False(t, h.Match(sig)) -} - -func TestDismissReviews_Match_Bad_Closed(t *testing.T) { - h := NewDismissReviewsHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "CLOSED", - ThreadsTotal: 3, - ThreadsResolved: 1, - } - assert.False(t, h.Match(sig)) -} - -func TestDismissReviews_Match_Bad_NoThreads(t *testing.T) { - h := NewDismissReviewsHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - ThreadsTotal: 0, - ThreadsResolved: 0, - } - assert.False(t, h.Match(sig)) -} - -func TestSendFixCommand_Match_Bad_Closed(t *testing.T) { - h := NewSendFixCommandHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "CLOSED", - Mergeable: "CONFLICTING", - } - assert.False(t, h.Match(sig)) -} - -func TestSendFixCommand_Match_Bad_NoIssues(t *testing.T) { - h := NewSendFixCommandHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - } - assert.False(t, h.Match(sig)) -} - -func TestSendFixCommand_Match_Good_ThreadsFailure(t *testing.T) { - h := NewSendFixCommandHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - Mergeable: "MERGEABLE", - CheckStatus: "FAILURE", - ThreadsTotal: 2, - ThreadsResolved: 0, - } - assert.True(t, h.Match(sig)) -} - -func TestTickParent_Match_Bad_Closed(t *testing.T) { - h := NewTickParentHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "CLOSED", - } - assert.False(t, h.Match(sig)) -} - -// --- Additional Execute tests --- - -func TestPublishDraft_Execute_Bad_ServerError(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewPublishDraftHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.False(t, result.Success) - assert.Contains(t, result.Error, "publish draft failed") -} - -func TestSendFixCommand_Execute_Good_Reviews(t *testing.T) { - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method == http.MethodPost { - b, _ := io.ReadAll(r.Body) - capturedBody = string(b) - w.WriteHeader(http.StatusCreated) - _, _ = w.Write([]byte(`{"id":1}`)) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewSendFixCommandHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 5, - PRState: "OPEN", - Mergeable: "MERGEABLE", - CheckStatus: "FAILURE", - ThreadsTotal: 2, - ThreadsResolved: 0, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) - assert.Contains(t, capturedBody, "fix the code reviews") -} - -func TestSendFixCommand_Execute_Bad_CommentFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewSendFixCommandHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - Mergeable: "CONFLICTING", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.False(t, result.Success) - assert.Contains(t, result.Error, "post comment failed") -} - -func TestTickParent_Execute_Good_AlreadyTicked(t *testing.T) { - epicBody := "## Tasks\n- [x] #7\n- [ ] #8\n" - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42") { - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": epicBody, - "title": "Epic", - }) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewTickParentHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - PRState: "MERGED", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) - assert.Equal(t, "tick_parent", result.Action) -} - -func TestTickParent_Execute_Bad_FetchEpicFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewTickParentHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 999, - ChildNumber: 1, - PRState: "MERGED", - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "fetch epic") -} - -func TestTickParent_Execute_Bad_EditEpicFails(t *testing.T) { - epicBody := "## Tasks\n- [ ] #7\n" - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": epicBody, - "title": "Epic", - }) - case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/42"): - w.WriteHeader(http.StatusInternalServerError) - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewTickParentHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - PRState: "MERGED", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.False(t, result.Success) - assert.Contains(t, result.Error, "edit epic failed") -} - -func TestTickParent_Execute_Bad_CloseChildFails(t *testing.T) { - epicBody := "## Tasks\n- [ ] #7\n" - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/42"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": epicBody, - "title": "Epic", - }) - case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/42"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": strings.Replace(epicBody, "- [ ] #7", "- [x] #7", 1), - "title": "Epic", - }) - case r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/issues/7"): - w.WriteHeader(http.StatusInternalServerError) - default: - w.WriteHeader(http.StatusOK) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewTickParentHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - PRState: "MERGED", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.False(t, result.Success) - assert.Contains(t, result.Error, "close child issue failed") -} - -func TestDismissReviews_Execute_Bad_ListFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewDismissReviewsHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "list reviews") -} - -func TestDismissReviews_Execute_Good_NothingToDismiss(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodGet { - // All reviews are either approved or already dismissed. - reviews := []map[string]any{ - { - "id": 1, "state": "APPROVED", "dismissed": false, "stale": false, - "body": "lgtm", "commit_id": "abc123", - }, - { - "id": 2, "state": "REQUEST_CHANGES", "dismissed": true, "stale": true, - "body": "already dismissed", "commit_id": "abc123", - }, - { - "id": 3, "state": "REQUEST_CHANGES", "dismissed": false, "stale": false, - "body": "not stale", "commit_id": "abc123", - }, - } - _ = json.NewEncoder(w).Encode(reviews) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewDismissReviewsHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success, "nothing to dismiss should be success") -} - -func TestDismissReviews_Execute_Bad_DismissFails(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodGet { - reviews := []map[string]any{ - { - "id": 1, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true, - "body": "fix it", "commit_id": "abc123", - }, - } - _ = json.NewEncoder(w).Encode(reviews) - return - } - - // Dismiss fails. - w.WriteHeader(http.StatusForbidden) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewDismissReviewsHandler(client) - - sig := &jobrunner.PipelineSignal{ - RepoOwner: "org", - RepoName: "repo", - PRNumber: 1, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.False(t, result.Success) - assert.Contains(t, result.Error, "failed to dismiss") -} - -// --- Dispatch Execute edge cases --- - -func TestDispatch_Execute_Good_AlreadyInProgress(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - }) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"}) - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - // Issue already has in-progress label. - _ = json.NewEncoder(w).Encode(map[string]any{ - "id": 5, - "number": 5, - "labels": []map[string]any{{"name": "in-progress", "id": 1}}, - "title": "Test", - }) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 5, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success, "already in-progress should be a no-op success") -} - -func TestDispatch_Execute_Good_AlreadyCompleted(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 2, "name": "agent-completed", "color": "#0e8a16"}, - }) - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/org/repo/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "name": "in-progress"}) - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/org/repo/issues/5": - _ = json.NewEncoder(w).Encode(map[string]any{ - "id": 5, - "number": 5, - "labels": []map[string]any{{"name": "agent-completed", "id": 2}}, - "title": "Done", - }) - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "org", - RepoName: "repo", - ChildNumber: 5, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - assert.True(t, result.Success) -} - -func TestDispatch_Execute_Bad_InvalidRepoOwner(t *testing.T) { - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - spinner := newTestSpinner(map[string]agentci.AgentConfig{ - "darbs-claude": {Host: "localhost", QueueDir: "/tmp/queue", Active: true}, - }) - h := NewDispatchHandler(client, srv.URL, "test-token", spinner) - - sig := &jobrunner.PipelineSignal{ - NeedsCoding: true, - Assignee: "darbs-claude", - RepoOwner: "org$bad", - RepoName: "repo", - ChildNumber: 1, - } - - _, err := h.Execute(context.Background(), sig) - assert.Error(t, err) - assert.Contains(t, err.Error(), "invalid repo owner") -} diff --git a/pkg/jobrunner/handlers/integration_test.go b/pkg/jobrunner/handlers/integration_test.go deleted file mode 100644 index c1b6379..0000000 --- a/pkg/jobrunner/handlers/integration_test.go +++ /dev/null @@ -1,824 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// --- Integration: full signal -> handler -> result flow --- -// These tests exercise the complete pipeline: a signal is created, -// matched by a handler, executed against a mock Forgejo server, -// and the result is verified. - -// mockForgejoServer creates a comprehensive mock Forgejo API server -// for integration testing. It supports issues, PRs, labels, comments, -// and tracks all API calls made. -type apiCall struct { - Method string - Path string - Body string -} - -type forgejoMock struct { - epicBody string - calls []apiCall - srv *httptest.Server - closedChild bool - editedBody string - comments []string -} - -func newForgejoMock(t *testing.T, epicBody string) *forgejoMock { - t.Helper() - m := &forgejoMock{ - epicBody: epicBody, - } - - m.srv = httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - bodyBytes, _ := io.ReadAll(r.Body) - m.calls = append(m.calls, apiCall{ - Method: r.Method, - Path: r.URL.Path, - Body: string(bodyBytes), - }) - - w.Header().Set("Content-Type", "application/json") - path := r.URL.Path - - switch { - // GET epic issue. - case r.Method == http.MethodGet && strings.Contains(path, "/issues/") && !strings.Contains(path, "/comments"): - issueNum := path[strings.LastIndex(path, "/")+1:] - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": json.Number(issueNum), - "body": m.epicBody, - "title": "Epic: Phase 3", - "state": "open", - "labels": []map[string]any{{"name": "epic", "id": 1}}, - }) - - // PATCH epic issue (edit body or close child). - case r.Method == http.MethodPatch && strings.Contains(path, "/issues/"): - var body map[string]any - _ = json.Unmarshal(bodyBytes, &body) - - if bodyStr, ok := body["body"].(string); ok { - m.editedBody = bodyStr - } - if state, ok := body["state"].(string); ok && state == "closed" { - m.closedChild = true - } - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 1, - "body": m.editedBody, - "state": "open", - }) - - // POST comment. - case r.Method == http.MethodPost && strings.Contains(path, "/comments"): - var body map[string]string - _ = json.Unmarshal(bodyBytes, &body) - m.comments = append(m.comments, body["body"]) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1, "body": body["body"]}) - - // GET labels. - case r.Method == http.MethodGet && strings.Contains(path, "/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "epic", "color": "#ff0000"}, - {"id": 2, "name": "in-progress", "color": "#1d76db"}, - }) - - // POST labels. - case r.Method == http.MethodPost && strings.HasSuffix(path, "/labels"): - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 10, "name": "new-label"}) - - // POST issue labels. - case r.Method == http.MethodPost && strings.Contains(path, "/issues/") && strings.Contains(path, "/labels"): - _ = json.NewEncoder(w).Encode([]map[string]any{}) - - // DELETE issue label. - case r.Method == http.MethodDelete && strings.Contains(path, "/labels/"): - w.WriteHeader(http.StatusNoContent) - - // POST merge PR. - case r.Method == http.MethodPost && strings.Contains(path, "/merge"): - w.WriteHeader(http.StatusOK) - - // PATCH PR (publish draft). - case r.Method == http.MethodPatch && strings.Contains(path, "/pulls/"): - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{}`)) - - // GET reviews. - case r.Method == http.MethodGet && strings.Contains(path, "/reviews"): - _ = json.NewEncoder(w).Encode([]map[string]any{}) - - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - - return m -} - -func (m *forgejoMock) close() { - m.srv.Close() -} - -func (m *forgejoMock) client(t *testing.T) *forge.Client { - t.Helper() - c, err := forge.New(m.srv.URL, "test-token") - require.NoError(t, err) - return c -} - -// --- TickParent integration: signal -> execute -> verify epic updated --- - -func TestIntegration_TickParent_Good_FullFlow(t *testing.T) { - epicBody := "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n- [x] #3\n" - - mock := newForgejoMock(t, epicBody) - defer mock.close() - - h := NewTickParentHandler(mock.client(t)) - - // Create signal representing a merged PR for child #7. - signal := &jobrunner.PipelineSignal{ - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - RepoOwner: "host-uk", - RepoName: "core-php", - PRState: "MERGED", - CheckStatus: "SUCCESS", - Mergeable: "UNKNOWN", - } - - // Verify the handler matches. - assert.True(t, h.Match(signal)) - - // Execute. - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - // Verify result. - assert.True(t, result.Success) - assert.Equal(t, "tick_parent", result.Action) - assert.Equal(t, "host-uk", result.RepoOwner) - assert.Equal(t, "core-php", result.RepoName) - assert.Equal(t, 99, result.PRNumber) - - // Verify the epic body was updated: #7 should now be checked. - assert.Contains(t, mock.editedBody, "- [x] #7") - // #8 should still be unchecked. - assert.Contains(t, mock.editedBody, "- [ ] #8") - // #1 and #3 should remain checked. - assert.Contains(t, mock.editedBody, "- [x] #1") - assert.Contains(t, mock.editedBody, "- [x] #3") - - // Verify the child issue was closed. - assert.True(t, mock.closedChild) -} - -// --- TickParent integration: epic progress tracking --- - -func TestIntegration_TickParent_Good_TrackEpicProgress(t *testing.T) { - // Start with 4 tasks, 1 checked. - epicBody := "## Tasks\n- [x] #1\n- [ ] #2\n- [ ] #3\n- [ ] #4\n" - - mock := newForgejoMock(t, epicBody) - defer mock.close() - - h := NewTickParentHandler(mock.client(t)) - - // Tick child #2. - signal := &jobrunner.PipelineSignal{ - EpicNumber: 10, - ChildNumber: 2, - PRNumber: 20, - RepoOwner: "org", - RepoName: "repo", - PRState: "MERGED", - } - - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - assert.True(t, result.Success) - - // Verify #2 is now checked. - assert.Contains(t, mock.editedBody, "- [x] #2") - // #3 and #4 should still be unchecked. - assert.Contains(t, mock.editedBody, "- [ ] #3") - assert.Contains(t, mock.editedBody, "- [ ] #4") - - // Count progress: 2 out of 4 now checked. - checked := strings.Count(mock.editedBody, "- [x]") - unchecked := strings.Count(mock.editedBody, "- [ ]") - assert.Equal(t, 2, checked) - assert.Equal(t, 2, unchecked) -} - -// --- EnableAutoMerge integration: full flow --- - -func TestIntegration_EnableAutoMerge_Good_FullFlow(t *testing.T) { - var mergeMethod string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/merge") { - bodyBytes, _ := io.ReadAll(r.Body) - var body map[string]any - _ = json.Unmarshal(bodyBytes, &body) - if do, ok := body["Do"].(string); ok { - mergeMethod = do - } - w.WriteHeader(http.StatusOK) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewEnableAutoMergeHandler(client) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 1, - ChildNumber: 5, - PRNumber: 42, - RepoOwner: "host-uk", - RepoName: "core-tenant", - PRState: "OPEN", - IsDraft: false, - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - } - - // Verify match. - assert.True(t, h.Match(signal)) - - // Execute. - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "enable_auto_merge", result.Action) - assert.Equal(t, "host-uk", result.RepoOwner) - assert.Equal(t, "core-tenant", result.RepoName) - assert.Equal(t, 42, result.PRNumber) - assert.Equal(t, "squash", mergeMethod) -} - -// --- PublishDraft integration: full flow --- - -func TestIntegration_PublishDraft_Good_FullFlow(t *testing.T) { - var patchedDraft bool - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodPatch && strings.Contains(r.URL.Path, "/pulls/") { - bodyBytes, _ := io.ReadAll(r.Body) - if strings.Contains(string(bodyBytes), `"draft":false`) { - patchedDraft = true - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{}`)) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewPublishDraftHandler(client) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 3, - ChildNumber: 8, - PRNumber: 15, - RepoOwner: "org", - RepoName: "repo", - PRState: "OPEN", - IsDraft: true, - CheckStatus: "SUCCESS", - Mergeable: "MERGEABLE", - } - - // Verify match. - assert.True(t, h.Match(signal)) - - // Execute. - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "publish_draft", result.Action) - assert.True(t, patchedDraft) -} - -// --- SendFixCommand integration: conflict message --- - -func TestIntegration_SendFixCommand_Good_ConflictFlow(t *testing.T) { - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments") { - bodyBytes, _ := io.ReadAll(r.Body) - var body map[string]string - _ = json.Unmarshal(bodyBytes, &body) - commentBody = body["body"] - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewSendFixCommandHandler(client) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 1, - ChildNumber: 3, - PRNumber: 10, - RepoOwner: "org", - RepoName: "repo", - PRState: "OPEN", - Mergeable: "CONFLICTING", - CheckStatus: "SUCCESS", - } - - assert.True(t, h.Match(signal)) - - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "send_fix_command", result.Action) - assert.Contains(t, commentBody, "fix the merge conflict") -} - -// --- SendFixCommand integration: code review message --- - -func TestIntegration_SendFixCommand_Good_ReviewFlow(t *testing.T) { - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - if r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments") { - bodyBytes, _ := io.ReadAll(r.Body) - var body map[string]string - _ = json.Unmarshal(bodyBytes, &body) - commentBody = body["body"] - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - return - } - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewSendFixCommandHandler(client) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 1, - ChildNumber: 3, - PRNumber: 10, - RepoOwner: "org", - RepoName: "repo", - PRState: "OPEN", - Mergeable: "MERGEABLE", - CheckStatus: "FAILURE", - ThreadsTotal: 3, - ThreadsResolved: 1, - } - - assert.True(t, h.Match(signal)) - - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Contains(t, commentBody, "fix the code reviews") -} - -// --- Completion integration: success flow --- - -func TestIntegration_Completion_Good_SuccessFlow(t *testing.T) { - var labelAdded bool - var labelRemoved bool - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - // GetLabelByName — GET repo labels. - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/core/go-scm/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - }) - - // RemoveIssueLabel. - case r.Method == http.MethodDelete && strings.Contains(r.URL.Path, "/labels/"): - labelRemoved = true - w.WriteHeader(http.StatusNoContent) - - // EnsureLabel — POST to create repo label. - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 2, "name": "agent-completed", "color": "#0e8a16"}) - - // AddIssueLabels — POST to issue labels. - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/issues/12/labels": - labelAdded = true - _ = json.NewEncoder(w).Encode([]map[string]any{{"id": 2, "name": "agent-completed"}}) - - // CreateIssueComment. - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments"): - bodyBytes, _ := io.ReadAll(r.Body) - var body map[string]string - _ = json.Unmarshal(bodyBytes, &body) - commentBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - signal := &jobrunner.PipelineSignal{ - Type: "agent_completion", - EpicNumber: 5, - ChildNumber: 12, - RepoOwner: "core", - RepoName: "go-scm", - Success: true, - Message: "PR created and tests passing", - } - - assert.True(t, h.Match(signal)) - - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "completion", result.Action) - assert.Equal(t, "core", result.RepoOwner) - assert.Equal(t, "go-scm", result.RepoName) - assert.Equal(t, 5, result.EpicNumber) - assert.Equal(t, 12, result.ChildNumber) - assert.True(t, labelRemoved, "in-progress label should be removed") - assert.True(t, labelAdded, "agent-completed label should be added") - assert.Contains(t, commentBody, "PR created and tests passing") -} - -// --- Full pipeline integration: signal -> match -> execute -> journal --- - -func TestIntegration_FullPipeline_Good_TickParentWithJournal(t *testing.T) { - epicBody := "## Tasks\n- [ ] #7\n- [ ] #8\n" - - mock := newForgejoMock(t, epicBody) - defer mock.close() - - dir := t.TempDir() - journal, err := jobrunner.NewJournal(dir) - require.NoError(t, err) - - client := mock.client(t) - h := NewTickParentHandler(client) - - signal := &jobrunner.PipelineSignal{ - EpicNumber: 10, - ChildNumber: 7, - PRNumber: 55, - RepoOwner: "host-uk", - RepoName: "core-tenant", - PRState: "MERGED", - CheckStatus: "SUCCESS", - Mergeable: "UNKNOWN", - } - - // Verify match. - assert.True(t, h.Match(signal)) - - // Execute. - start := time.Now() - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - assert.True(t, result.Success) - - // Write to journal (simulating what the poller does). - result.EpicNumber = signal.EpicNumber - result.ChildNumber = signal.ChildNumber - result.Cycle = 1 - result.Duration = time.Since(start) - - err = journal.Append(signal, result) - require.NoError(t, err) - - // Verify the journal file exists and contains the entry. - date := time.Now().UTC().Format("2006-01-02") - journalPath := filepath.Join(dir, "host-uk", "core-tenant", date+".jsonl") - - _, statErr := os.Stat(journalPath) - require.NoError(t, statErr) - - f, err := os.Open(journalPath) - require.NoError(t, err) - defer func() { _ = f.Close() }() - - var entry jobrunner.JournalEntry - err = json.NewDecoder(f).Decode(&entry) - require.NoError(t, err) - - assert.Equal(t, "tick_parent", entry.Action) - assert.Equal(t, "host-uk/core-tenant", entry.Repo) - assert.Equal(t, 10, entry.Epic) - assert.Equal(t, 7, entry.Child) - assert.Equal(t, 55, entry.PR) - assert.Equal(t, 1, entry.Cycle) - assert.True(t, entry.Result.Success) - assert.Equal(t, "MERGED", entry.Signals.PRState) - - // Verify the epic was properly updated. - assert.Contains(t, mock.editedBody, "- [x] #7") - assert.Contains(t, mock.editedBody, "- [ ] #8") - assert.True(t, mock.closedChild) -} - -// --- Handler matching priority: first match wins --- - -func TestIntegration_HandlerPriority_Good_FirstMatchWins(t *testing.T) { - // Test that when multiple handlers could match, the first one wins. - // This exercises the poller's findHandler logic. - - // Signal with OPEN, not draft, MERGEABLE, SUCCESS, no threads: - // This matches enable_auto_merge. - signal := &jobrunner.PipelineSignal{ - PRState: "OPEN", - IsDraft: false, - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - ThreadsTotal: 0, - ThreadsResolved: 0, - } - - autoMerge := NewEnableAutoMergeHandler(nil) - publishDraft := NewPublishDraftHandler(nil) - fixCommand := NewSendFixCommandHandler(nil) - - // enable_auto_merge should match. - assert.True(t, autoMerge.Match(signal)) - // publish_draft should NOT match (not a draft). - assert.False(t, publishDraft.Match(signal)) - // send_fix_command should NOT match (mergeable and passing). - assert.False(t, fixCommand.Match(signal)) -} - -// --- Handler matching: draft PR path --- - -func TestIntegration_HandlerPriority_Good_DraftPRPath(t *testing.T) { - signal := &jobrunner.PipelineSignal{ - PRState: "OPEN", - IsDraft: true, - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - ThreadsTotal: 0, - ThreadsResolved: 0, - } - - autoMerge := NewEnableAutoMergeHandler(nil) - publishDraft := NewPublishDraftHandler(nil) - fixCommand := NewSendFixCommandHandler(nil) - - // enable_auto_merge should NOT match (is draft). - assert.False(t, autoMerge.Match(signal)) - // publish_draft should match (draft + open + success). - assert.True(t, publishDraft.Match(signal)) - // send_fix_command should NOT match. - assert.False(t, fixCommand.Match(signal)) -} - -// --- Handler matching: merged PR only matches tick_parent --- - -func TestIntegration_HandlerPriority_Good_MergedPRPath(t *testing.T) { - signal := &jobrunner.PipelineSignal{ - PRState: "MERGED", - IsDraft: false, - Mergeable: "UNKNOWN", - CheckStatus: "SUCCESS", - ThreadsTotal: 0, - ThreadsResolved: 0, - } - - autoMerge := NewEnableAutoMergeHandler(nil) - publishDraft := NewPublishDraftHandler(nil) - fixCommand := NewSendFixCommandHandler(nil) - tickParent := NewTickParentHandler(nil) - - assert.False(t, autoMerge.Match(signal)) - assert.False(t, publishDraft.Match(signal)) - assert.False(t, fixCommand.Match(signal)) - assert.True(t, tickParent.Match(signal)) -} - -// --- Handler matching: conflicting PR matches send_fix_command --- - -func TestIntegration_HandlerPriority_Good_ConflictingPRPath(t *testing.T) { - signal := &jobrunner.PipelineSignal{ - PRState: "OPEN", - IsDraft: false, - Mergeable: "CONFLICTING", - CheckStatus: "SUCCESS", - ThreadsTotal: 0, - ThreadsResolved: 0, - } - - autoMerge := NewEnableAutoMergeHandler(nil) - fixCommand := NewSendFixCommandHandler(nil) - - // enable_auto_merge should NOT match (conflicting). - assert.False(t, autoMerge.Match(signal)) - // send_fix_command should match (conflicting). - assert.True(t, fixCommand.Match(signal)) -} - -// --- Completion integration: failure flow --- - -func TestIntegration_Completion_Good_FailureFlow(t *testing.T) { - var commentBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - // GetLabelByName — GET repo labels. - case r.Method == http.MethodGet && r.URL.Path == "/api/v1/repos/core/go-scm/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{ - {"id": 1, "name": "in-progress", "color": "#1d76db"}, - }) - - // RemoveIssueLabel. - case r.Method == http.MethodDelete: - w.WriteHeader(http.StatusNoContent) - - // EnsureLabel — POST to create repo label. - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/labels": - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 3, "name": "agent-failed", "color": "#c0392b"}) - - // AddIssueLabels — POST to issue labels. - case r.Method == http.MethodPost && r.URL.Path == "/api/v1/repos/core/go-scm/issues/12/labels": - _ = json.NewEncoder(w).Encode([]map[string]any{{"id": 3, "name": "agent-failed"}}) - - // CreateIssueComment. - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments"): - bodyBytes, _ := io.ReadAll(r.Body) - var body map[string]string - _ = json.Unmarshal(bodyBytes, &body) - commentBody = body["body"] - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - h := NewCompletionHandler(client) - - signal := &jobrunner.PipelineSignal{ - Type: "agent_completion", - EpicNumber: 5, - ChildNumber: 12, - RepoOwner: "core", - RepoName: "go-scm", - Success: false, - Error: "tests failed: 3 assertions", - } - - result, err := h.Execute(context.Background(), signal) - require.NoError(t, err) - - assert.True(t, result.Success) // The handler itself succeeded. - assert.Contains(t, commentBody, "Agent reported failure") - assert.Contains(t, commentBody, "tests failed: 3 assertions") -} - -// --- Multiple handlers execute in sequence for different signals --- - -func TestIntegration_MultipleHandlers_Good_DifferentSignals(t *testing.T) { - var commentBodies []string - var mergedPRs []int64 - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch { - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/merge"): - // Extract PR number from path. - parts := strings.Split(r.URL.Path, "/") - for i, p := range parts { - if p == "pulls" && i+1 < len(parts) { - var prNum int64 - _ = json.Unmarshal([]byte(parts[i+1]), &prNum) - mergedPRs = append(mergedPRs, prNum) - } - } - w.WriteHeader(http.StatusOK) - - case r.Method == http.MethodPost && strings.Contains(r.URL.Path, "/comments"): - bodyBytes, _ := io.ReadAll(r.Body) - var body map[string]string - _ = json.Unmarshal(bodyBytes, &body) - commentBodies = append(commentBodies, body["body"]) - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(map[string]any{"id": 1}) - - case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/issues/"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": "## Tasks\n- [ ] #7\n- [ ] #8\n", - "title": "Epic", - }) - - case r.Method == http.MethodPatch: - _ = json.NewEncoder(w).Encode(map[string]any{"number": 1, "body": "", "state": "open"}) - - default: - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]any{}) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - autoMergeHandler := NewEnableAutoMergeHandler(client) - fixCommandHandler := NewSendFixCommandHandler(client) - - // Signal 1: should trigger auto merge. - sig1 := &jobrunner.PipelineSignal{ - PRState: "OPEN", IsDraft: false, Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", PRNumber: 10, - RepoOwner: "org", RepoName: "repo", - } - - // Signal 2: should trigger fix command. - sig2 := &jobrunner.PipelineSignal{ - PRState: "OPEN", Mergeable: "CONFLICTING", - CheckStatus: "SUCCESS", PRNumber: 20, - RepoOwner: "org", RepoName: "repo", - } - - assert.True(t, autoMergeHandler.Match(sig1)) - assert.False(t, autoMergeHandler.Match(sig2)) - - assert.False(t, fixCommandHandler.Match(sig1)) - assert.True(t, fixCommandHandler.Match(sig2)) - - // Execute both. - result1, err := autoMergeHandler.Execute(context.Background(), sig1) - require.NoError(t, err) - assert.True(t, result1.Success) - - result2, err := fixCommandHandler.Execute(context.Background(), sig2) - require.NoError(t, err) - assert.True(t, result2.Success) - - // Verify correct comment was posted for the conflicting PR. - require.Len(t, commentBodies, 1) - assert.Contains(t, commentBodies[0], "fix the merge conflict") -} diff --git a/pkg/jobrunner/handlers/publish_draft.go b/pkg/jobrunner/handlers/publish_draft.go deleted file mode 100644 index b75dc51..0000000 --- a/pkg/jobrunner/handlers/publish_draft.go +++ /dev/null @@ -1,55 +0,0 @@ -package handlers - -import ( - "context" - "fmt" - "time" - - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// PublishDraftHandler marks a draft PR as ready for review once its checks pass. -type PublishDraftHandler struct { - forge *forge.Client -} - -// NewPublishDraftHandler creates a handler that publishes draft PRs. -func NewPublishDraftHandler(f *forge.Client) *PublishDraftHandler { - return &PublishDraftHandler{forge: f} -} - -// Name returns the handler identifier. -func (h *PublishDraftHandler) Name() string { - return "publish_draft" -} - -// Match returns true when the PR is a draft, open, and all checks have passed. -func (h *PublishDraftHandler) Match(signal *jobrunner.PipelineSignal) bool { - return signal.IsDraft && - signal.PRState == "OPEN" && - signal.CheckStatus == "SUCCESS" -} - -// Execute marks the PR as no longer a draft. -func (h *PublishDraftHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { - start := time.Now() - - err := h.forge.SetPRDraft(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber), false) - - result := &jobrunner.ActionResult{ - Action: "publish_draft", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - PRNumber: signal.PRNumber, - Success: err == nil, - Timestamp: time.Now(), - Duration: time.Since(start), - } - - if err != nil { - result.Error = fmt.Sprintf("publish draft failed: %v", err) - } - - return result, nil -} diff --git a/pkg/jobrunner/handlers/publish_draft_test.go b/pkg/jobrunner/handlers/publish_draft_test.go deleted file mode 100644 index 1ecf84f..0000000 --- a/pkg/jobrunner/handlers/publish_draft_test.go +++ /dev/null @@ -1,84 +0,0 @@ -package handlers - -import ( - "context" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -func TestPublishDraft_Match_Good(t *testing.T) { - h := NewPublishDraftHandler(nil) - sig := &jobrunner.PipelineSignal{ - IsDraft: true, - PRState: "OPEN", - CheckStatus: "SUCCESS", - } - assert.True(t, h.Match(sig)) -} - -func TestPublishDraft_Match_Bad_NotDraft(t *testing.T) { - h := NewPublishDraftHandler(nil) - sig := &jobrunner.PipelineSignal{ - IsDraft: false, - PRState: "OPEN", - CheckStatus: "SUCCESS", - } - assert.False(t, h.Match(sig)) -} - -func TestPublishDraft_Match_Bad_ChecksFailing(t *testing.T) { - h := NewPublishDraftHandler(nil) - sig := &jobrunner.PipelineSignal{ - IsDraft: true, - PRState: "OPEN", - CheckStatus: "FAILURE", - } - assert.False(t, h.Match(sig)) -} - -func TestPublishDraft_Execute_Good(t *testing.T) { - var capturedMethod string - var capturedPath string - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedMethod = r.Method - capturedPath = r.URL.Path - b, _ := io.ReadAll(r.Body) - capturedBody = string(b) - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(`{}`)) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - h := NewPublishDraftHandler(client) - sig := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-php", - PRNumber: 42, - IsDraft: true, - PRState: "OPEN", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.Equal(t, http.MethodPatch, capturedMethod) - assert.Equal(t, "/api/v1/repos/host-uk/core-php/pulls/42", capturedPath) - assert.Contains(t, capturedBody, `"draft":false`) - - assert.True(t, result.Success) - assert.Equal(t, "publish_draft", result.Action) - assert.Equal(t, "host-uk", result.RepoOwner) - assert.Equal(t, "core-php", result.RepoName) - assert.Equal(t, 42, result.PRNumber) -} diff --git a/pkg/jobrunner/handlers/resolve_threads.go b/pkg/jobrunner/handlers/resolve_threads.go deleted file mode 100644 index 22c900f..0000000 --- a/pkg/jobrunner/handlers/resolve_threads.go +++ /dev/null @@ -1,80 +0,0 @@ -package handlers - -import ( - "context" - "fmt" - "time" - - forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" - - coreerr "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// DismissReviewsHandler dismisses stale "request changes" reviews on a PR. -// This replaces the GitHub-only ResolveThreadsHandler because Forgejo does -// not have a thread resolution API. -type DismissReviewsHandler struct { - forge *forge.Client -} - -// NewDismissReviewsHandler creates a handler that dismisses stale reviews. -func NewDismissReviewsHandler(f *forge.Client) *DismissReviewsHandler { - return &DismissReviewsHandler{forge: f} -} - -// Name returns the handler identifier. -func (h *DismissReviewsHandler) Name() string { - return "dismiss_reviews" -} - -// Match returns true when the PR is open and has unresolved review threads. -func (h *DismissReviewsHandler) Match(signal *jobrunner.PipelineSignal) bool { - return signal.PRState == "OPEN" && signal.HasUnresolvedThreads() -} - -// Execute dismisses stale "request changes" reviews on the PR. -func (h *DismissReviewsHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { - start := time.Now() - - reviews, err := h.forge.ListPRReviews(signal.RepoOwner, signal.RepoName, int64(signal.PRNumber)) - if err != nil { - return nil, coreerr.E("dismiss_reviews.Execute", "list reviews", err) - } - - var dismissErrors []string - dismissed := 0 - for _, review := range reviews { - if review.State != forgejosdk.ReviewStateRequestChanges || review.Dismissed || !review.Stale { - continue - } - - if err := h.forge.DismissReview( - signal.RepoOwner, signal.RepoName, - int64(signal.PRNumber), review.ID, - "Automatically dismissed: review is stale after new commits", - ); err != nil { - dismissErrors = append(dismissErrors, err.Error()) - } else { - dismissed++ - } - } - - result := &jobrunner.ActionResult{ - Action: "dismiss_reviews", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - PRNumber: signal.PRNumber, - Success: len(dismissErrors) == 0, - Timestamp: time.Now(), - Duration: time.Since(start), - } - - if len(dismissErrors) > 0 { - result.Error = fmt.Sprintf("failed to dismiss %d review(s): %s", - len(dismissErrors), dismissErrors[0]) - } - - return result, nil -} diff --git a/pkg/jobrunner/handlers/resolve_threads_test.go b/pkg/jobrunner/handlers/resolve_threads_test.go deleted file mode 100644 index d5d16e8..0000000 --- a/pkg/jobrunner/handlers/resolve_threads_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -func TestDismissReviews_Match_Good(t *testing.T) { - h := NewDismissReviewsHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - ThreadsTotal: 4, - ThreadsResolved: 2, - } - assert.True(t, h.Match(sig)) -} - -func TestDismissReviews_Match_Bad_AllResolved(t *testing.T) { - h := NewDismissReviewsHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - ThreadsTotal: 3, - ThreadsResolved: 3, - } - assert.False(t, h.Match(sig)) -} - -func TestDismissReviews_Execute_Good(t *testing.T) { - callCount := 0 - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - callCount++ - w.Header().Set("Content-Type", "application/json") - - // ListPullReviews (GET) - if r.Method == http.MethodGet { - reviews := []map[string]any{ - { - "id": 1, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true, - "body": "fix this", "commit_id": "abc123", - }, - { - "id": 2, "state": "APPROVED", "dismissed": false, "stale": false, - "body": "looks good", "commit_id": "abc123", - }, - { - "id": 3, "state": "REQUEST_CHANGES", "dismissed": false, "stale": true, - "body": "needs work", "commit_id": "abc123", - }, - } - _ = json.NewEncoder(w).Encode(reviews) - return - } - - // DismissPullReview (POST to dismissals endpoint) - w.WriteHeader(http.StatusOK) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - h := NewDismissReviewsHandler(client) - sig := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-admin", - PRNumber: 33, - PRState: "OPEN", - ThreadsTotal: 3, - ThreadsResolved: 1, - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "dismiss_reviews", result.Action) - assert.Equal(t, "host-uk", result.RepoOwner) - assert.Equal(t, "core-admin", result.RepoName) - assert.Equal(t, 33, result.PRNumber) - - // 1 list + 2 dismiss (reviews #1 and #3 are stale REQUEST_CHANGES) - assert.Equal(t, 3, callCount) -} diff --git a/pkg/jobrunner/handlers/send_fix_command.go b/pkg/jobrunner/handlers/send_fix_command.go deleted file mode 100644 index 465eccd..0000000 --- a/pkg/jobrunner/handlers/send_fix_command.go +++ /dev/null @@ -1,74 +0,0 @@ -package handlers - -import ( - "context" - "fmt" - "time" - - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// SendFixCommandHandler posts a comment on a PR asking for conflict or -// review fixes. -type SendFixCommandHandler struct { - forge *forge.Client -} - -// NewSendFixCommandHandler creates a handler that posts fix commands. -func NewSendFixCommandHandler(f *forge.Client) *SendFixCommandHandler { - return &SendFixCommandHandler{forge: f} -} - -// Name returns the handler identifier. -func (h *SendFixCommandHandler) Name() string { - return "send_fix_command" -} - -// Match returns true when the PR is open and either has merge conflicts or -// has unresolved threads with failing checks. -func (h *SendFixCommandHandler) Match(signal *jobrunner.PipelineSignal) bool { - if signal.PRState != "OPEN" { - return false - } - if signal.Mergeable == "CONFLICTING" { - return true - } - if signal.HasUnresolvedThreads() && signal.CheckStatus == "FAILURE" { - return true - } - return false -} - -// Execute posts a comment on the PR asking for a fix. -func (h *SendFixCommandHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { - start := time.Now() - - var message string - if signal.Mergeable == "CONFLICTING" { - message = "Can you fix the merge conflict?" - } else { - message = "Can you fix the code reviews?" - } - - err := h.forge.CreateIssueComment( - signal.RepoOwner, signal.RepoName, - int64(signal.PRNumber), message, - ) - - result := &jobrunner.ActionResult{ - Action: "send_fix_command", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - PRNumber: signal.PRNumber, - Success: err == nil, - Timestamp: time.Now(), - Duration: time.Since(start), - } - - if err != nil { - result.Error = fmt.Sprintf("post comment failed: %v", err) - } - - return result, nil -} diff --git a/pkg/jobrunner/handlers/send_fix_command_test.go b/pkg/jobrunner/handlers/send_fix_command_test.go deleted file mode 100644 index b2002d0..0000000 --- a/pkg/jobrunner/handlers/send_fix_command_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package handlers - -import ( - "context" - "io" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -func TestSendFixCommand_Match_Good_Conflicting(t *testing.T) { - h := NewSendFixCommandHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - Mergeable: "CONFLICTING", - } - assert.True(t, h.Match(sig)) -} - -func TestSendFixCommand_Match_Good_UnresolvedThreads(t *testing.T) { - h := NewSendFixCommandHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - Mergeable: "MERGEABLE", - CheckStatus: "FAILURE", - ThreadsTotal: 3, - ThreadsResolved: 1, - } - assert.True(t, h.Match(sig)) -} - -func TestSendFixCommand_Match_Bad_Clean(t *testing.T) { - h := NewSendFixCommandHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - ThreadsTotal: 2, - ThreadsResolved: 2, - } - assert.False(t, h.Match(sig)) -} - -func TestSendFixCommand_Execute_Good_Conflict(t *testing.T) { - var capturedMethod string - var capturedPath string - var capturedBody string - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - capturedMethod = r.Method - capturedPath = r.URL.Path - b, _ := io.ReadAll(r.Body) - capturedBody = string(b) - w.WriteHeader(http.StatusCreated) - _, _ = w.Write([]byte(`{"id":1}`)) - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - h := NewSendFixCommandHandler(client) - sig := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-tenant", - PRNumber: 17, - PRState: "OPEN", - Mergeable: "CONFLICTING", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.Equal(t, http.MethodPost, capturedMethod) - assert.Equal(t, "/api/v1/repos/host-uk/core-tenant/issues/17/comments", capturedPath) - assert.Contains(t, capturedBody, "fix the merge conflict") - - assert.True(t, result.Success) - assert.Equal(t, "send_fix_command", result.Action) - assert.Equal(t, "host-uk", result.RepoOwner) - assert.Equal(t, "core-tenant", result.RepoName) - assert.Equal(t, 17, result.PRNumber) -} diff --git a/pkg/jobrunner/handlers/testhelper_test.go b/pkg/jobrunner/handlers/testhelper_test.go deleted file mode 100644 index 277591c..0000000 --- a/pkg/jobrunner/handlers/testhelper_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package handlers - -import ( - "net/http" - "strings" - "testing" - - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/go-scm/forge" -) - -// forgejoVersionResponse is the JSON response for /api/v1/version. -const forgejoVersionResponse = `{"version":"9.0.0"}` - -// withVersion wraps an HTTP handler to also serve the Forgejo version endpoint -// that the SDK calls during NewClient initialization. -func withVersion(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.HasSuffix(r.URL.Path, "/version") { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(forgejoVersionResponse)) - return - } - next.ServeHTTP(w, r) - }) -} - -// newTestForgeClient creates a forge.Client pointing at the given test server URL. -func newTestForgeClient(t *testing.T, url string) *forge.Client { - t.Helper() - client, err := forge.New(url, "test-token") - require.NoError(t, err) - return client -} diff --git a/pkg/jobrunner/handlers/tick_parent.go b/pkg/jobrunner/handlers/tick_parent.go deleted file mode 100644 index d090f54..0000000 --- a/pkg/jobrunner/handlers/tick_parent.go +++ /dev/null @@ -1,101 +0,0 @@ -package handlers - -import ( - "context" - "fmt" - "strings" - "time" - - forgejosdk "codeberg.org/mvdkleijn/forgejo-sdk/forgejo/v2" - - coreerr "forge.lthn.ai/core/go-log" - "forge.lthn.ai/core/go-scm/forge" - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// TickParentHandler ticks a child checkbox in the parent epic issue body -// after the child's PR has been merged. -type TickParentHandler struct { - forge *forge.Client -} - -// NewTickParentHandler creates a handler that ticks parent epic checkboxes. -func NewTickParentHandler(f *forge.Client) *TickParentHandler { - return &TickParentHandler{forge: f} -} - -// Name returns the handler identifier. -func (h *TickParentHandler) Name() string { - return "tick_parent" -} - -// Match returns true when the child PR has been merged. -func (h *TickParentHandler) Match(signal *jobrunner.PipelineSignal) bool { - return signal.PRState == "MERGED" -} - -// Execute fetches the epic body, replaces the unchecked checkbox for the -// child issue with a checked one, updates the epic, and closes the child issue. -func (h *TickParentHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) { - start := time.Now() - - // Fetch the epic issue body. - epic, err := h.forge.GetIssue(signal.RepoOwner, signal.RepoName, int64(signal.EpicNumber)) - if err != nil { - return nil, coreerr.E("tick_parent.Execute", "fetch epic", err) - } - - oldBody := epic.Body - unchecked := fmt.Sprintf("- [ ] #%d", signal.ChildNumber) - checked := fmt.Sprintf("- [x] #%d", signal.ChildNumber) - - if !strings.Contains(oldBody, unchecked) { - // Already ticked or not found -- nothing to do. - return &jobrunner.ActionResult{ - Action: "tick_parent", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - PRNumber: signal.PRNumber, - Success: true, - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil - } - - newBody := strings.Replace(oldBody, unchecked, checked, 1) - - // Update the epic body. - _, err = h.forge.EditIssue(signal.RepoOwner, signal.RepoName, int64(signal.EpicNumber), forgejosdk.EditIssueOption{ - Body: &newBody, - }) - if err != nil { - return &jobrunner.ActionResult{ - Action: "tick_parent", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - PRNumber: signal.PRNumber, - Error: fmt.Sprintf("edit epic failed: %v", err), - Timestamp: time.Now(), - Duration: time.Since(start), - }, nil - } - - // Close the child issue. - err = h.forge.CloseIssue(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber)) - - result := &jobrunner.ActionResult{ - Action: "tick_parent", - RepoOwner: signal.RepoOwner, - RepoName: signal.RepoName, - PRNumber: signal.PRNumber, - Success: err == nil, - Timestamp: time.Now(), - Duration: time.Since(start), - } - - if err != nil { - result.Error = fmt.Sprintf("close child issue failed: %v", err) - } - - return result, nil -} diff --git a/pkg/jobrunner/handlers/tick_parent_test.go b/pkg/jobrunner/handlers/tick_parent_test.go deleted file mode 100644 index 2770bfc..0000000 --- a/pkg/jobrunner/handlers/tick_parent_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package handlers - -import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -func TestTickParent_Match_Good(t *testing.T) { - h := NewTickParentHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "MERGED", - } - assert.True(t, h.Match(sig)) -} - -func TestTickParent_Match_Bad_Open(t *testing.T) { - h := NewTickParentHandler(nil) - sig := &jobrunner.PipelineSignal{ - PRState: "OPEN", - } - assert.False(t, h.Match(sig)) -} - -func TestTickParent_Execute_Good(t *testing.T) { - epicBody := "## Tasks\n- [x] #1\n- [ ] #7\n- [ ] #8\n" - var editBody string - var closeCalled bool - - srv := httptest.NewServer(withVersion(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path - method := r.Method - w.Header().Set("Content-Type", "application/json") - - switch { - // GET issue (fetch epic) - case method == http.MethodGet && strings.Contains(path, "/issues/42"): - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": epicBody, - "title": "Epic", - }) - - // PATCH issue (edit epic body) - case method == http.MethodPatch && strings.Contains(path, "/issues/42"): - b, _ := io.ReadAll(r.Body) - editBody = string(b) - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 42, - "body": editBody, - "title": "Epic", - }) - - // PATCH issue (close child — state: closed) - case method == http.MethodPatch && strings.Contains(path, "/issues/7"): - closeCalled = true - _ = json.NewEncoder(w).Encode(map[string]any{ - "number": 7, - "state": "closed", - }) - - default: - w.WriteHeader(http.StatusNotFound) - } - }))) - defer srv.Close() - - client := newTestForgeClient(t, srv.URL) - - h := NewTickParentHandler(client) - sig := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-php", - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - PRState: "MERGED", - } - - result, err := h.Execute(context.Background(), sig) - require.NoError(t, err) - - assert.True(t, result.Success) - assert.Equal(t, "tick_parent", result.Action) - - // Verify the edit body contains the checked checkbox. - assert.Contains(t, editBody, "- [x] #7") - assert.True(t, closeCalled, "expected child issue to be closed") -} diff --git a/pkg/jobrunner/journal.go b/pkg/jobrunner/journal.go deleted file mode 100644 index 25f162a..0000000 --- a/pkg/jobrunner/journal.go +++ /dev/null @@ -1,203 +0,0 @@ -package jobrunner - -import ( - "bufio" - "encoding/json" - "iter" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - - coreio "forge.lthn.ai/core/go-io" - coreerr "forge.lthn.ai/core/go-log" -) - -// validPathComponent matches safe repo owner/name characters (alphanumeric, hyphen, underscore, dot). -var validPathComponent = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9._-]*$`) - -// JournalEntry is a single line in the JSONL audit log. -type JournalEntry struct { - Timestamp string `json:"ts"` - Epic int `json:"epic"` - Child int `json:"child"` - PR int `json:"pr"` - Repo string `json:"repo"` - Action string `json:"action"` - Signals SignalSnapshot `json:"signals"` - Result ResultSnapshot `json:"result"` - Cycle int `json:"cycle"` -} - -// SignalSnapshot captures the structural state of a PR at the time of action. -type SignalSnapshot struct { - PRState string `json:"pr_state"` - IsDraft bool `json:"is_draft"` - CheckStatus string `json:"check_status"` - Mergeable string `json:"mergeable"` - ThreadsTotal int `json:"threads_total"` - ThreadsResolved int `json:"threads_resolved"` -} - -// ResultSnapshot captures the outcome of an action. -type ResultSnapshot struct { - Success bool `json:"success"` - Error string `json:"error,omitempty"` - DurationMs int64 `json:"duration_ms"` -} - -// Journal writes ActionResult entries to date-partitioned JSONL files. -type Journal struct { - baseDir string - mu sync.Mutex -} - -// NewJournal creates a new Journal rooted at baseDir. -func NewJournal(baseDir string) (*Journal, error) { - if baseDir == "" { - return nil, coreerr.E("journal.NewJournal", "base directory is required", nil) - } - return &Journal{baseDir: baseDir}, nil -} - -// sanitizePathComponent validates a single path component (owner or repo name) -// to prevent path traversal attacks. It rejects "..", empty strings, paths -// containing separators, and any value outside the safe character set. -func sanitizePathComponent(name string) (string, error) { - // Reject empty or whitespace-only values. - if name == "" || strings.TrimSpace(name) == "" { - return "", coreerr.E("journal.sanitizePathComponent", "invalid path component: "+name, nil) - } - - // Reject inputs containing path separators (directory traversal attempt). - if strings.ContainsAny(name, `/\`) { - return "", coreerr.E("journal.sanitizePathComponent", "path component contains directory separator: "+name, nil) - } - - // Use filepath.Clean to normalize (e.g., collapse redundant dots). - clean := filepath.Clean(name) - - // Reject traversal components. - if clean == "." || clean == ".." { - return "", coreerr.E("journal.sanitizePathComponent", "invalid path component: "+name, nil) - } - - // Validate against the safe character set. - if !validPathComponent.MatchString(clean) { - return "", coreerr.E("journal.sanitizePathComponent", "path component contains invalid characters: "+name, nil) - } - - return clean, nil -} - -// ReadEntries returns an iterator over JournalEntry lines in a date-partitioned file. -func (j *Journal) ReadEntries(path string) iter.Seq2[JournalEntry, error] { - return func(yield func(JournalEntry, error) bool) { - f, err := os.Open(path) - if err != nil { - yield(JournalEntry{}, err) - return - } - defer func() { _ = f.Close() }() - - scanner := bufio.NewScanner(f) - for scanner.Scan() { - var entry JournalEntry - if err := json.Unmarshal(scanner.Bytes(), &entry); err != nil { - if !yield(JournalEntry{}, err) { - return - } - continue - } - if !yield(entry, nil) { - return - } - } - if err := scanner.Err(); err != nil { - yield(JournalEntry{}, err) - } - } -} - -// Append writes a journal entry for the given signal and result. -func (j *Journal) Append(signal *PipelineSignal, result *ActionResult) error { - if signal == nil { - return coreerr.E("journal.Append", "signal is required", nil) - } - if result == nil { - return coreerr.E("journal.Append", "result is required", nil) - } - - entry := JournalEntry{ - Timestamp: result.Timestamp.UTC().Format("2006-01-02T15:04:05Z"), - Epic: signal.EpicNumber, - Child: signal.ChildNumber, - PR: signal.PRNumber, - Repo: signal.RepoFullName(), - Action: result.Action, - Signals: SignalSnapshot{ - PRState: signal.PRState, - IsDraft: signal.IsDraft, - CheckStatus: signal.CheckStatus, - Mergeable: signal.Mergeable, - ThreadsTotal: signal.ThreadsTotal, - ThreadsResolved: signal.ThreadsResolved, - }, - Result: ResultSnapshot{ - Success: result.Success, - Error: result.Error, - DurationMs: result.Duration.Milliseconds(), - }, - Cycle: result.Cycle, - } - - data, err := json.Marshal(entry) - if err != nil { - return coreerr.E("journal.Append", "marshal entry", err) - } - data = append(data, '\n') - - // Sanitize path components to prevent path traversal (CVE: issue #46). - owner, err := sanitizePathComponent(signal.RepoOwner) - if err != nil { - return coreerr.E("journal.Append", "invalid repo owner", err) - } - repo, err := sanitizePathComponent(signal.RepoName) - if err != nil { - return coreerr.E("journal.Append", "invalid repo name", err) - } - - date := result.Timestamp.UTC().Format("2006-01-02") - dir := filepath.Join(j.baseDir, owner, repo) - - // Resolve to absolute path and verify it stays within baseDir. - absBase, err := filepath.Abs(j.baseDir) - if err != nil { - return coreerr.E("journal.Append", "resolve base directory", err) - } - absDir, err := filepath.Abs(dir) - if err != nil { - return coreerr.E("journal.Append", "resolve journal directory", err) - } - if !strings.HasPrefix(absDir, absBase+string(filepath.Separator)) { - return coreerr.E("journal.Append", "path escapes base directory: "+absDir, nil) - } - - j.mu.Lock() - defer j.mu.Unlock() - - if err := coreio.Local.EnsureDir(dir); err != nil { - return coreerr.E("journal.Append", "create directory", err) - } - - path := filepath.Join(dir, date+".jsonl") - f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return coreerr.E("journal.Append", "open file", err) - } - defer func() { _ = f.Close() }() - - _, err = f.Write(data) - return err -} diff --git a/pkg/jobrunner/journal_replay_test.go b/pkg/jobrunner/journal_replay_test.go deleted file mode 100644 index 3617366..0000000 --- a/pkg/jobrunner/journal_replay_test.go +++ /dev/null @@ -1,540 +0,0 @@ -package jobrunner - -import ( - "bufio" - "encoding/json" - "os" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// readJournalEntries reads all JSONL entries from a given file path. -func readJournalEntries(t *testing.T, path string) []JournalEntry { - t.Helper() - f, err := os.Open(path) - require.NoError(t, err) - defer func() { _ = f.Close() }() - - var entries []JournalEntry - scanner := bufio.NewScanner(f) - for scanner.Scan() { - var entry JournalEntry - err := json.Unmarshal(scanner.Bytes(), &entry) - require.NoError(t, err) - entries = append(entries, entry) - } - require.NoError(t, scanner.Err()) - return entries -} - -// readAllJournalFiles reads all .jsonl files recursively under a base directory. -func readAllJournalFiles(t *testing.T, baseDir string) []JournalEntry { - t.Helper() - var all []JournalEntry - err := filepath.Walk(baseDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - if filepath.Ext(path) == ".jsonl" { - entries := readJournalEntries(t, path) - all = append(all, entries...) - } - return nil - }) - require.NoError(t, err) - return all -} - -// --- Journal replay: write multiple entries, read back, verify round-trip --- - -func TestJournal_Replay_Good_WriteAndReadBack(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - baseTime := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC) - - // Write 5 entries with different actions, times, and repos. - entries := []struct { - signal *PipelineSignal - result *ActionResult - }{ - { - signal: &PipelineSignal{ - EpicNumber: 1, ChildNumber: 2, PRNumber: 10, - RepoOwner: "org-a", RepoName: "repo-1", - PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE", - }, - result: &ActionResult{ - Action: "enable_auto_merge", - RepoOwner: "org-a", RepoName: "repo-1", - Success: true, Timestamp: baseTime, Duration: 100 * time.Millisecond, Cycle: 1, - }, - }, - { - signal: &PipelineSignal{ - EpicNumber: 1, ChildNumber: 3, PRNumber: 11, - RepoOwner: "org-a", RepoName: "repo-1", - PRState: "OPEN", CheckStatus: "FAILURE", Mergeable: "CONFLICTING", - }, - result: &ActionResult{ - Action: "send_fix_command", - RepoOwner: "org-a", RepoName: "repo-1", - Success: true, Timestamp: baseTime.Add(5 * time.Minute), Duration: 50 * time.Millisecond, Cycle: 1, - }, - }, - { - signal: &PipelineSignal{ - EpicNumber: 5, ChildNumber: 10, PRNumber: 20, - RepoOwner: "org-b", RepoName: "repo-2", - PRState: "MERGED", CheckStatus: "SUCCESS", Mergeable: "UNKNOWN", - }, - result: &ActionResult{ - Action: "tick_parent", - RepoOwner: "org-b", RepoName: "repo-2", - Success: true, Timestamp: baseTime.Add(10 * time.Minute), Duration: 200 * time.Millisecond, Cycle: 2, - }, - }, - { - signal: &PipelineSignal{ - EpicNumber: 5, ChildNumber: 11, PRNumber: 21, - RepoOwner: "org-b", RepoName: "repo-2", - PRState: "OPEN", CheckStatus: "PENDING", Mergeable: "MERGEABLE", - IsDraft: true, - }, - result: &ActionResult{ - Action: "publish_draft", - RepoOwner: "org-b", RepoName: "repo-2", - Success: false, Error: "API error", Timestamp: baseTime.Add(15 * time.Minute), - Duration: 300 * time.Millisecond, Cycle: 2, - }, - }, - { - signal: &PipelineSignal{ - EpicNumber: 1, ChildNumber: 4, PRNumber: 12, - RepoOwner: "org-a", RepoName: "repo-1", - PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE", - ThreadsTotal: 3, ThreadsResolved: 1, - }, - result: &ActionResult{ - Action: "dismiss_reviews", - RepoOwner: "org-a", RepoName: "repo-1", - Success: true, Timestamp: baseTime.Add(20 * time.Minute), Duration: 150 * time.Millisecond, Cycle: 3, - }, - }, - } - - for _, e := range entries { - err := j.Append(e.signal, e.result) - require.NoError(t, err) - } - - // Read back all entries. - all := readAllJournalFiles(t, dir) - require.Len(t, all, 5) - - // Build a map by action for flexible lookup (filepath.Walk order is by path, not insertion). - byAction := make(map[string][]JournalEntry) - for _, e := range all { - byAction[e.Action] = append(byAction[e.Action], e) - } - - // Verify enable_auto_merge entry (org-a/repo-1). - require.Len(t, byAction["enable_auto_merge"], 1) - eam := byAction["enable_auto_merge"][0] - assert.Equal(t, "org-a/repo-1", eam.Repo) - assert.Equal(t, 1, eam.Epic) - assert.Equal(t, 2, eam.Child) - assert.Equal(t, 10, eam.PR) - assert.Equal(t, 1, eam.Cycle) - assert.True(t, eam.Result.Success) - assert.Equal(t, int64(100), eam.Result.DurationMs) - - // Verify publish_draft (failed entry has error). - require.Len(t, byAction["publish_draft"], 1) - pd := byAction["publish_draft"][0] - assert.Equal(t, "publish_draft", pd.Action) - assert.False(t, pd.Result.Success) - assert.Equal(t, "API error", pd.Result.Error) - - // Verify signal snapshot preserves state. - assert.True(t, pd.Signals.IsDraft) - assert.Equal(t, "PENDING", pd.Signals.CheckStatus) - - // Verify dismiss_reviews has thread counts preserved. - require.Len(t, byAction["dismiss_reviews"], 1) - dr := byAction["dismiss_reviews"][0] - assert.Equal(t, 3, dr.Signals.ThreadsTotal) - assert.Equal(t, 1, dr.Signals.ThreadsResolved) -} - -// --- Journal replay: filter by action --- - -func TestJournal_Replay_Good_FilterByAction(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC) - - actions := []string{"enable_auto_merge", "tick_parent", "send_fix_command", "tick_parent", "publish_draft"} - for i, action := range actions { - signal := &PipelineSignal{ - EpicNumber: 1, ChildNumber: i + 1, PRNumber: 10 + i, - RepoOwner: "org", RepoName: "repo", - PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE", - } - result := &ActionResult{ - Action: action, - RepoOwner: "org", RepoName: "repo", - Success: true, - Timestamp: ts.Add(time.Duration(i) * time.Minute), - Duration: 100 * time.Millisecond, - Cycle: i + 1, - } - require.NoError(t, j.Append(signal, result)) - } - - all := readAllJournalFiles(t, dir) - require.Len(t, all, 5) - - // Filter by action=tick_parent. - var tickParentEntries []JournalEntry - for _, e := range all { - if e.Action == "tick_parent" { - tickParentEntries = append(tickParentEntries, e) - } - } - - assert.Len(t, tickParentEntries, 2) - assert.Equal(t, 2, tickParentEntries[0].Child) - assert.Equal(t, 4, tickParentEntries[1].Child) -} - -// --- Journal replay: filter by repo --- - -func TestJournal_Replay_Good_FilterByRepo(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC) - - repos := []struct { - owner string - name string - }{ - {"host-uk", "core-php"}, - {"host-uk", "core-tenant"}, - {"host-uk", "core-php"}, - {"lethean", "go-scm"}, - {"host-uk", "core-tenant"}, - } - - for i, r := range repos { - signal := &PipelineSignal{ - EpicNumber: 1, ChildNumber: i + 1, PRNumber: 10 + i, - RepoOwner: r.owner, RepoName: r.name, - PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE", - } - result := &ActionResult{ - Action: "tick_parent", - RepoOwner: r.owner, RepoName: r.name, - Success: true, - Timestamp: ts.Add(time.Duration(i) * time.Minute), - Duration: 50 * time.Millisecond, - Cycle: i + 1, - } - require.NoError(t, j.Append(signal, result)) - } - - // Read entries for host-uk/core-php. - phpPath := filepath.Join(dir, "host-uk", "core-php", "2026-02-10.jsonl") - phpEntries := readJournalEntries(t, phpPath) - assert.Len(t, phpEntries, 2) - for _, e := range phpEntries { - assert.Equal(t, "host-uk/core-php", e.Repo) - } - - // Read entries for host-uk/core-tenant. - tenantPath := filepath.Join(dir, "host-uk", "core-tenant", "2026-02-10.jsonl") - tenantEntries := readJournalEntries(t, tenantPath) - assert.Len(t, tenantEntries, 2) - for _, e := range tenantEntries { - assert.Equal(t, "host-uk/core-tenant", e.Repo) - } - - // Read entries for lethean/go-scm. - scmPath := filepath.Join(dir, "lethean", "go-scm", "2026-02-10.jsonl") - scmEntries := readJournalEntries(t, scmPath) - assert.Len(t, scmEntries, 1) - assert.Equal(t, "lethean/go-scm", scmEntries[0].Repo) -} - -// --- Journal replay: filter by time range (date partitioning) --- - -func TestJournal_Replay_Good_FilterByTimeRange(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - // Write entries across three different days. - dates := []time.Time{ - time.Date(2026, 2, 8, 9, 0, 0, 0, time.UTC), - time.Date(2026, 2, 9, 10, 0, 0, 0, time.UTC), - time.Date(2026, 2, 9, 14, 0, 0, 0, time.UTC), - time.Date(2026, 2, 10, 8, 0, 0, 0, time.UTC), - time.Date(2026, 2, 10, 16, 0, 0, 0, time.UTC), - } - - for i, ts := range dates { - signal := &PipelineSignal{ - EpicNumber: 1, ChildNumber: i + 1, PRNumber: 10 + i, - RepoOwner: "org", RepoName: "repo", - PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE", - } - result := &ActionResult{ - Action: "merge", - RepoOwner: "org", RepoName: "repo", - Success: true, - Timestamp: ts, - Duration: 100 * time.Millisecond, - Cycle: i + 1, - } - require.NoError(t, j.Append(signal, result)) - } - - // Verify each date file has the correct number of entries. - day8Path := filepath.Join(dir, "org", "repo", "2026-02-08.jsonl") - day8Entries := readJournalEntries(t, day8Path) - assert.Len(t, day8Entries, 1) - assert.Equal(t, "2026-02-08T09:00:00Z", day8Entries[0].Timestamp) - - day9Path := filepath.Join(dir, "org", "repo", "2026-02-09.jsonl") - day9Entries := readJournalEntries(t, day9Path) - assert.Len(t, day9Entries, 2) - assert.Equal(t, "2026-02-09T10:00:00Z", day9Entries[0].Timestamp) - assert.Equal(t, "2026-02-09T14:00:00Z", day9Entries[1].Timestamp) - - day10Path := filepath.Join(dir, "org", "repo", "2026-02-10.jsonl") - day10Entries := readJournalEntries(t, day10Path) - assert.Len(t, day10Entries, 2) - - // Simulate a time range query: get entries for Feb 9 only. - // In a real system, you'd list files matching the date range. - // Here we verify the date partitioning is correct. - rangeStart := time.Date(2026, 2, 9, 0, 0, 0, 0, time.UTC) - rangeEnd := time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC) // exclusive - - var filtered []JournalEntry - all := readAllJournalFiles(t, dir) - for _, e := range all { - ts, err := time.Parse("2006-01-02T15:04:05Z", e.Timestamp) - require.NoError(t, err) - if !ts.Before(rangeStart) && ts.Before(rangeEnd) { - filtered = append(filtered, e) - } - } - - assert.Len(t, filtered, 2) - assert.Equal(t, 2, filtered[0].Child) - assert.Equal(t, 3, filtered[1].Child) -} - -// --- Journal replay: combined filter (action + repo + time) --- - -func TestJournal_Replay_Good_CombinedFilter(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts1 := time.Date(2026, 2, 10, 10, 0, 0, 0, time.UTC) - ts2 := time.Date(2026, 2, 10, 11, 0, 0, 0, time.UTC) - ts3 := time.Date(2026, 2, 11, 9, 0, 0, 0, time.UTC) - - testData := []struct { - owner string - name string - action string - ts time.Time - }{ - {"org", "repo-a", "tick_parent", ts1}, - {"org", "repo-a", "enable_auto_merge", ts1}, - {"org", "repo-b", "tick_parent", ts2}, - {"org", "repo-a", "tick_parent", ts3}, - {"org", "repo-b", "send_fix_command", ts3}, - } - - for i, td := range testData { - signal := &PipelineSignal{ - EpicNumber: 1, ChildNumber: i + 1, PRNumber: 100 + i, - RepoOwner: td.owner, RepoName: td.name, - PRState: "MERGED", CheckStatus: "SUCCESS", Mergeable: "UNKNOWN", - } - result := &ActionResult{ - Action: td.action, - RepoOwner: td.owner, RepoName: td.name, - Success: true, - Timestamp: td.ts, - Duration: 50 * time.Millisecond, - Cycle: i + 1, - } - require.NoError(t, j.Append(signal, result)) - } - - // Filter: action=tick_parent AND repo=org/repo-a. - repoAPath := filepath.Join(dir, "org", "repo-a") - var repoAEntries []JournalEntry - err = filepath.Walk(repoAPath, func(path string, info os.FileInfo, walkErr error) error { - if walkErr != nil { - return walkErr - } - if filepath.Ext(path) == ".jsonl" { - entries := readJournalEntries(t, path) - repoAEntries = append(repoAEntries, entries...) - } - return nil - }) - require.NoError(t, err) - - var tickParentRepoA []JournalEntry - for _, e := range repoAEntries { - if e.Action == "tick_parent" && e.Repo == "org/repo-a" { - tickParentRepoA = append(tickParentRepoA, e) - } - } - - assert.Len(t, tickParentRepoA, 2) - assert.Equal(t, 1, tickParentRepoA[0].Child) - assert.Equal(t, 4, tickParentRepoA[1].Child) -} - -// --- Journal replay: empty journal returns no entries --- - -func TestJournal_Replay_Good_EmptyJournal(t *testing.T) { - dir := t.TempDir() - - all := readAllJournalFiles(t, dir) - assert.Empty(t, all) -} - -// --- Journal replay: single entry round-trip preserves all fields --- - -func TestJournal_Replay_Good_FullFieldRoundTrip(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts := time.Date(2026, 2, 15, 14, 30, 45, 0, time.UTC) - - signal := &PipelineSignal{ - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - RepoOwner: "host-uk", - RepoName: "core-admin", - PRState: "OPEN", - IsDraft: true, - Mergeable: "CONFLICTING", - CheckStatus: "FAILURE", - ThreadsTotal: 5, - ThreadsResolved: 2, - } - - result := &ActionResult{ - Action: "send_fix_command", - RepoOwner: "host-uk", - RepoName: "core-admin", - Success: false, - Error: "comment API returned 503", - Timestamp: ts, - Duration: 1500 * time.Millisecond, - Cycle: 7, - } - - require.NoError(t, j.Append(signal, result)) - - path := filepath.Join(dir, "host-uk", "core-admin", "2026-02-15.jsonl") - entries := readJournalEntries(t, path) - require.Len(t, entries, 1) - - e := entries[0] - assert.Equal(t, "2026-02-15T14:30:45Z", e.Timestamp) - assert.Equal(t, 42, e.Epic) - assert.Equal(t, 7, e.Child) - assert.Equal(t, 99, e.PR) - assert.Equal(t, "host-uk/core-admin", e.Repo) - assert.Equal(t, "send_fix_command", e.Action) - assert.Equal(t, 7, e.Cycle) - - // Signal snapshot. - assert.Equal(t, "OPEN", e.Signals.PRState) - assert.True(t, e.Signals.IsDraft) - assert.Equal(t, "CONFLICTING", e.Signals.Mergeable) - assert.Equal(t, "FAILURE", e.Signals.CheckStatus) - assert.Equal(t, 5, e.Signals.ThreadsTotal) - assert.Equal(t, 2, e.Signals.ThreadsResolved) - - // Result snapshot. - assert.False(t, e.Result.Success) - assert.Equal(t, "comment API returned 503", e.Result.Error) - assert.Equal(t, int64(1500), e.Result.DurationMs) -} - -// --- Journal replay: concurrent writes produce valid JSONL --- - -func TestJournal_Replay_Good_ConcurrentWrites(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts := time.Date(2026, 2, 10, 12, 0, 0, 0, time.UTC) - - // Write 20 entries concurrently. - done := make(chan struct{}, 20) - for i := range 20 { - go func(idx int) { - signal := &PipelineSignal{ - EpicNumber: 1, ChildNumber: idx, PRNumber: idx, - RepoOwner: "org", RepoName: "repo", - PRState: "OPEN", CheckStatus: "SUCCESS", Mergeable: "MERGEABLE", - } - result := &ActionResult{ - Action: "test", - RepoOwner: "org", RepoName: "repo", - Success: true, - Timestamp: ts, - Duration: 10 * time.Millisecond, - Cycle: idx, - } - _ = j.Append(signal, result) - done <- struct{}{} - }(i) - } - - for range 20 { - <-done - } - - // All entries should be parseable and present. - path := filepath.Join(dir, "org", "repo", "2026-02-10.jsonl") - entries := readJournalEntries(t, path) - assert.Len(t, entries, 20) - - // Each entry should have valid JSON (no corruption from concurrent writes). - for _, e := range entries { - assert.NotEmpty(t, e.Action) - assert.Equal(t, "org/repo", e.Repo) - } -} diff --git a/pkg/jobrunner/journal_test.go b/pkg/jobrunner/journal_test.go deleted file mode 100644 index a17a88b..0000000 --- a/pkg/jobrunner/journal_test.go +++ /dev/null @@ -1,263 +0,0 @@ -package jobrunner - -import ( - "bufio" - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestJournal_Append_Good(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts := time.Date(2026, 2, 5, 14, 30, 0, 0, time.UTC) - - signal := &PipelineSignal{ - EpicNumber: 10, - ChildNumber: 3, - PRNumber: 55, - RepoOwner: "host-uk", - RepoName: "core-tenant", - PRState: "OPEN", - IsDraft: false, - Mergeable: "MERGEABLE", - CheckStatus: "SUCCESS", - ThreadsTotal: 2, - ThreadsResolved: 1, - LastCommitSHA: "abc123", - LastCommitAt: ts, - LastReviewAt: ts, - } - - result := &ActionResult{ - Action: "merge", - RepoOwner: "host-uk", - RepoName: "core-tenant", - EpicNumber: 10, - ChildNumber: 3, - PRNumber: 55, - Success: true, - Timestamp: ts, - Duration: 1200 * time.Millisecond, - Cycle: 1, - } - - err = j.Append(signal, result) - require.NoError(t, err) - - // Read the file back. - expectedPath := filepath.Join(dir, "host-uk", "core-tenant", "2026-02-05.jsonl") - f, err := os.Open(expectedPath) - require.NoError(t, err) - defer func() { _ = f.Close() }() - - scanner := bufio.NewScanner(f) - require.True(t, scanner.Scan(), "expected at least one line in JSONL file") - - var entry JournalEntry - err = json.Unmarshal(scanner.Bytes(), &entry) - require.NoError(t, err) - - assert.Equal(t, "2026-02-05T14:30:00Z", entry.Timestamp) - assert.Equal(t, 10, entry.Epic) - assert.Equal(t, 3, entry.Child) - assert.Equal(t, 55, entry.PR) - assert.Equal(t, "host-uk/core-tenant", entry.Repo) - assert.Equal(t, "merge", entry.Action) - assert.Equal(t, 1, entry.Cycle) - - // Verify signal snapshot. - assert.Equal(t, "OPEN", entry.Signals.PRState) - assert.Equal(t, false, entry.Signals.IsDraft) - assert.Equal(t, "SUCCESS", entry.Signals.CheckStatus) - assert.Equal(t, "MERGEABLE", entry.Signals.Mergeable) - assert.Equal(t, 2, entry.Signals.ThreadsTotal) - assert.Equal(t, 1, entry.Signals.ThreadsResolved) - - // Verify result snapshot. - assert.Equal(t, true, entry.Result.Success) - assert.Equal(t, "", entry.Result.Error) - assert.Equal(t, int64(1200), entry.Result.DurationMs) - - // Append a second entry and verify two lines exist. - result2 := &ActionResult{ - Action: "comment", - RepoOwner: "host-uk", - RepoName: "core-tenant", - Success: false, - Error: "rate limited", - Timestamp: ts, - Duration: 50 * time.Millisecond, - Cycle: 2, - } - err = j.Append(signal, result2) - require.NoError(t, err) - - data, err := os.ReadFile(expectedPath) - require.NoError(t, err) - - lines := 0 - sc := bufio.NewScanner(strings.NewReader(string(data))) - for sc.Scan() { - lines++ - } - assert.Equal(t, 2, lines, "expected two JSONL lines after two appends") -} - -func TestJournal_Append_Bad_PathTraversal(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts := time.Now() - - tests := []struct { - name string - repoOwner string - repoName string - wantErr string - }{ - { - name: "dotdot owner", - repoOwner: "..", - repoName: "core", - wantErr: "invalid repo owner", - }, - { - name: "dotdot repo", - repoOwner: "host-uk", - repoName: "../../etc/cron.d", - wantErr: "invalid repo name", - }, - { - name: "slash in owner", - repoOwner: "../etc", - repoName: "core", - wantErr: "invalid repo owner", - }, - { - name: "absolute path in repo", - repoOwner: "host-uk", - repoName: "/etc/passwd", - wantErr: "invalid repo name", - }, - { - name: "empty owner", - repoOwner: "", - repoName: "core", - wantErr: "invalid repo owner", - }, - { - name: "empty repo", - repoOwner: "host-uk", - repoName: "", - wantErr: "invalid repo name", - }, - { - name: "dot only owner", - repoOwner: ".", - repoName: "core", - wantErr: "invalid repo owner", - }, - { - name: "spaces only owner", - repoOwner: " ", - repoName: "core", - wantErr: "invalid repo owner", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - signal := &PipelineSignal{ - RepoOwner: tc.repoOwner, - RepoName: tc.repoName, - } - result := &ActionResult{ - Action: "merge", - Timestamp: ts, - } - - err := j.Append(signal, result) - require.Error(t, err) - assert.Contains(t, err.Error(), tc.wantErr) - }) - } -} - -func TestJournal_Append_Good_ValidNames(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - ts := time.Date(2026, 2, 5, 14, 30, 0, 0, time.UTC) - - // Verify valid names with dots, hyphens, underscores all work. - validNames := []struct { - owner string - repo string - }{ - {"host-uk", "core"}, - {"my_org", "my_repo"}, - {"org.name", "repo.v2"}, - {"a", "b"}, - {"Org-123", "Repo_456.go"}, - } - - for _, vn := range validNames { - signal := &PipelineSignal{ - RepoOwner: vn.owner, - RepoName: vn.repo, - } - result := &ActionResult{ - Action: "test", - Timestamp: ts, - } - - err := j.Append(signal, result) - assert.NoError(t, err, "expected valid name pair %s/%s to succeed", vn.owner, vn.repo) - } -} - -func TestJournal_Append_Bad_NilSignal(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - result := &ActionResult{ - Action: "merge", - Timestamp: time.Now(), - } - - err = j.Append(nil, result) - require.Error(t, err) - assert.Contains(t, err.Error(), "signal is required") -} - -func TestJournal_Append_Bad_NilResult(t *testing.T) { - dir := t.TempDir() - - j, err := NewJournal(dir) - require.NoError(t, err) - - signal := &PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-php", - } - - err = j.Append(signal, nil) - require.Error(t, err) - assert.Contains(t, err.Error(), "result is required") -} diff --git a/pkg/jobrunner/poller.go b/pkg/jobrunner/poller.go deleted file mode 100644 index 58abec6..0000000 --- a/pkg/jobrunner/poller.go +++ /dev/null @@ -1,224 +0,0 @@ -package jobrunner - -import ( - "context" - "iter" - "sync" - "time" - - "forge.lthn.ai/core/go-log" -) - -// PollerConfig configures a Poller. -type PollerConfig struct { - Sources []JobSource - Handlers []JobHandler - Journal *Journal - PollInterval time.Duration - DryRun bool -} - -// Poller discovers signals from sources and dispatches them to handlers. -type Poller struct { - mu sync.RWMutex - sources []JobSource - handlers []JobHandler - journal *Journal - interval time.Duration - dryRun bool - cycle int -} - -// NewPoller creates a Poller from the given config. -func NewPoller(cfg PollerConfig) *Poller { - interval := cfg.PollInterval - if interval <= 0 { - interval = 60 * time.Second - } - - return &Poller{ - sources: cfg.Sources, - handlers: cfg.Handlers, - journal: cfg.Journal, - interval: interval, - dryRun: cfg.DryRun, - } -} - -// Cycle returns the number of completed poll-dispatch cycles. -func (p *Poller) Cycle() int { - p.mu.RLock() - defer p.mu.RUnlock() - return p.cycle -} - -// Sources returns an iterator over the poller's sources. -func (p *Poller) Sources() iter.Seq[JobSource] { - return func(yield func(JobSource) bool) { - p.mu.RLock() - sources := make([]JobSource, len(p.sources)) - copy(sources, p.sources) - p.mu.RUnlock() - - for _, s := range sources { - if !yield(s) { - return - } - } - } -} - -// Handlers returns an iterator over the poller's handlers. -func (p *Poller) Handlers() iter.Seq[JobHandler] { - return func(yield func(JobHandler) bool) { - p.mu.RLock() - handlers := make([]JobHandler, len(p.handlers)) - copy(handlers, p.handlers) - p.mu.RUnlock() - - for _, h := range handlers { - if !yield(h) { - return - } - } - } -} - -// DryRun returns whether dry-run mode is enabled. -func (p *Poller) DryRun() bool { - p.mu.RLock() - defer p.mu.RUnlock() - return p.dryRun -} - -// SetDryRun enables or disables dry-run mode. -func (p *Poller) SetDryRun(v bool) { - p.mu.Lock() - p.dryRun = v - p.mu.Unlock() -} - -// AddSource appends a source to the poller. -func (p *Poller) AddSource(s JobSource) { - p.mu.Lock() - p.sources = append(p.sources, s) - p.mu.Unlock() -} - -// AddHandler appends a handler to the poller. -func (p *Poller) AddHandler(h JobHandler) { - p.mu.Lock() - p.handlers = append(p.handlers, h) - p.mu.Unlock() -} - -// Run starts a blocking poll-dispatch loop. It runs one cycle immediately, -// then repeats on each tick of the configured interval until the context -// is cancelled. -func (p *Poller) Run(ctx context.Context) error { - if err := p.RunOnce(ctx); err != nil { - return err - } - - ticker := time.NewTicker(p.interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - if err := p.RunOnce(ctx); err != nil { - return err - } - } - } -} - -// RunOnce performs a single poll-dispatch cycle: iterate sources, poll each, -// find the first matching handler for each signal, and execute it. -func (p *Poller) RunOnce(ctx context.Context) error { - p.mu.Lock() - p.cycle++ - cycle := p.cycle - dryRun := p.dryRun - p.mu.Unlock() - - log.Info("poller cycle starting", "cycle", cycle) - - for src := range p.Sources() { - signals, err := src.Poll(ctx) - if err != nil { - log.Error("poll failed", "source", src.Name(), "err", err) - continue - } - - log.Info("polled source", "source", src.Name(), "signals", len(signals)) - - for _, sig := range signals { - handler := p.findHandler(p.Handlers(), sig) - if handler == nil { - log.Debug("no matching handler", "epic", sig.EpicNumber, "child", sig.ChildNumber) - continue - } - - if dryRun { - log.Info("dry-run: would execute", - "handler", handler.Name(), - "epic", sig.EpicNumber, - "child", sig.ChildNumber, - "pr", sig.PRNumber, - ) - continue - } - - start := time.Now() - result, err := handler.Execute(ctx, sig) - elapsed := time.Since(start) - - if err != nil { - log.Error("handler execution failed", - "handler", handler.Name(), - "epic", sig.EpicNumber, - "child", sig.ChildNumber, - "err", err, - ) - continue - } - - result.Cycle = cycle - result.EpicNumber = sig.EpicNumber - result.ChildNumber = sig.ChildNumber - result.Duration = elapsed - - if p.journal != nil { - if jErr := p.journal.Append(sig, result); jErr != nil { - log.Error("journal append failed", "err", jErr) - } - } - - if rErr := src.Report(ctx, result); rErr != nil { - log.Error("source report failed", "source", src.Name(), "err", rErr) - } - - log.Info("handler executed", - "handler", handler.Name(), - "action", result.Action, - "success", result.Success, - "duration", elapsed, - ) - } - } - - return nil -} - -// findHandler returns the first handler that matches the signal, or nil. -func (p *Poller) findHandler(handlers iter.Seq[JobHandler], sig *PipelineSignal) JobHandler { - for h := range handlers { - if h.Match(sig) { - return h - } - } - return nil -} diff --git a/pkg/jobrunner/poller_test.go b/pkg/jobrunner/poller_test.go deleted file mode 100644 index 1d3a908..0000000 --- a/pkg/jobrunner/poller_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package jobrunner - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- Mock source --- - -type mockSource struct { - name string - signals []*PipelineSignal - reports []*ActionResult - mu sync.Mutex -} - -func (m *mockSource) Name() string { return m.name } - -func (m *mockSource) Poll(_ context.Context) ([]*PipelineSignal, error) { - m.mu.Lock() - defer m.mu.Unlock() - return m.signals, nil -} - -func (m *mockSource) Report(_ context.Context, result *ActionResult) error { - m.mu.Lock() - defer m.mu.Unlock() - m.reports = append(m.reports, result) - return nil -} - -// --- Mock handler --- - -type mockHandler struct { - name string - matchFn func(*PipelineSignal) bool - executed []*PipelineSignal - mu sync.Mutex -} - -func (m *mockHandler) Name() string { return m.name } - -func (m *mockHandler) Match(sig *PipelineSignal) bool { - if m.matchFn != nil { - return m.matchFn(sig) - } - return true -} - -func (m *mockHandler) Execute(_ context.Context, sig *PipelineSignal) (*ActionResult, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.executed = append(m.executed, sig) - return &ActionResult{ - Action: m.name, - RepoOwner: sig.RepoOwner, - RepoName: sig.RepoName, - PRNumber: sig.PRNumber, - Success: true, - Timestamp: time.Now(), - }, nil -} - -func TestPoller_RunOnce_Good(t *testing.T) { - sig := &PipelineSignal{ - EpicNumber: 1, - ChildNumber: 2, - PRNumber: 10, - RepoOwner: "host-uk", - RepoName: "core-php", - PRState: "OPEN", - CheckStatus: "SUCCESS", - Mergeable: "MERGEABLE", - } - - src := &mockSource{ - name: "test-source", - signals: []*PipelineSignal{sig}, - } - - handler := &mockHandler{ - name: "test-handler", - matchFn: func(s *PipelineSignal) bool { - return s.PRNumber == 10 - }, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - }) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) - - // Handler should have been called with our signal. - handler.mu.Lock() - defer handler.mu.Unlock() - require.Len(t, handler.executed, 1) - assert.Equal(t, 10, handler.executed[0].PRNumber) - - // Source should have received a report. - src.mu.Lock() - defer src.mu.Unlock() - require.Len(t, src.reports, 1) - assert.Equal(t, "test-handler", src.reports[0].Action) - assert.True(t, src.reports[0].Success) - assert.Equal(t, 1, src.reports[0].Cycle) - assert.Equal(t, 1, src.reports[0].EpicNumber) - assert.Equal(t, 2, src.reports[0].ChildNumber) - - // Cycle counter should have incremented. - assert.Equal(t, 1, p.Cycle()) -} - -func TestPoller_RunOnce_Good_NoSignals(t *testing.T) { - src := &mockSource{ - name: "empty-source", - signals: nil, - } - - handler := &mockHandler{ - name: "unused-handler", - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - }) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) - - // Handler should not have been called. - handler.mu.Lock() - defer handler.mu.Unlock() - assert.Empty(t, handler.executed) - - // Source should not have received reports. - src.mu.Lock() - defer src.mu.Unlock() - assert.Empty(t, src.reports) - - assert.Equal(t, 1, p.Cycle()) -} - -func TestPoller_RunOnce_Good_NoMatchingHandler(t *testing.T) { - sig := &PipelineSignal{ - EpicNumber: 5, - ChildNumber: 8, - PRNumber: 42, - RepoOwner: "host-uk", - RepoName: "core-tenant", - PRState: "OPEN", - } - - src := &mockSource{ - name: "test-source", - signals: []*PipelineSignal{sig}, - } - - handler := &mockHandler{ - name: "picky-handler", - matchFn: func(s *PipelineSignal) bool { - return false // never matches - }, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - }) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) - - // Handler should not have been called. - handler.mu.Lock() - defer handler.mu.Unlock() - assert.Empty(t, handler.executed) - - // Source should not have received reports (no action taken). - src.mu.Lock() - defer src.mu.Unlock() - assert.Empty(t, src.reports) -} - -func TestPoller_RunOnce_Good_DryRun(t *testing.T) { - sig := &PipelineSignal{ - EpicNumber: 1, - ChildNumber: 3, - PRNumber: 20, - RepoOwner: "host-uk", - RepoName: "core-admin", - PRState: "OPEN", - CheckStatus: "SUCCESS", - Mergeable: "MERGEABLE", - } - - src := &mockSource{ - name: "test-source", - signals: []*PipelineSignal{sig}, - } - - handler := &mockHandler{ - name: "merge-handler", - matchFn: func(s *PipelineSignal) bool { - return true - }, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - Handlers: []JobHandler{handler}, - DryRun: true, - }) - - assert.True(t, p.DryRun()) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) - - // Handler should NOT have been called in dry-run mode. - handler.mu.Lock() - defer handler.mu.Unlock() - assert.Empty(t, handler.executed) - - // Source should not have received reports. - src.mu.Lock() - defer src.mu.Unlock() - assert.Empty(t, src.reports) -} - -func TestPoller_SetDryRun_Good(t *testing.T) { - p := NewPoller(PollerConfig{}) - - assert.False(t, p.DryRun()) - p.SetDryRun(true) - assert.True(t, p.DryRun()) - p.SetDryRun(false) - assert.False(t, p.DryRun()) -} - -func TestPoller_AddSourceAndHandler_Good(t *testing.T) { - p := NewPoller(PollerConfig{}) - - sig := &PipelineSignal{ - EpicNumber: 1, - ChildNumber: 1, - PRNumber: 5, - RepoOwner: "host-uk", - RepoName: "core-php", - PRState: "OPEN", - } - - src := &mockSource{ - name: "added-source", - signals: []*PipelineSignal{sig}, - } - - handler := &mockHandler{ - name: "added-handler", - matchFn: func(s *PipelineSignal) bool { return true }, - } - - p.AddSource(src) - p.AddHandler(handler) - - err := p.RunOnce(context.Background()) - require.NoError(t, err) - - handler.mu.Lock() - defer handler.mu.Unlock() - require.Len(t, handler.executed, 1) - assert.Equal(t, 5, handler.executed[0].PRNumber) -} - -func TestPoller_Run_Good(t *testing.T) { - src := &mockSource{ - name: "tick-source", - signals: nil, - } - - p := NewPoller(PollerConfig{ - Sources: []JobSource{src}, - PollInterval: 50 * time.Millisecond, - }) - - ctx, cancel := context.WithTimeout(context.Background(), 180*time.Millisecond) - defer cancel() - - err := p.Run(ctx) - assert.ErrorIs(t, err, context.DeadlineExceeded) - - // Should have completed at least 2 cycles (one immediate + at least one tick). - assert.GreaterOrEqual(t, p.Cycle(), 2) -} - -func TestPoller_DefaultInterval_Good(t *testing.T) { - p := NewPoller(PollerConfig{}) - assert.Equal(t, 60*time.Second, p.interval) -} diff --git a/pkg/jobrunner/types.go b/pkg/jobrunner/types.go deleted file mode 100644 index ce51caf..0000000 --- a/pkg/jobrunner/types.go +++ /dev/null @@ -1,72 +0,0 @@ -package jobrunner - -import ( - "context" - "time" -) - -// PipelineSignal is the structural snapshot of a child issue/PR. -// Carries structural state plus issue title/body for dispatch prompts. -type PipelineSignal struct { - EpicNumber int - ChildNumber int - PRNumber int - RepoOwner string - RepoName string - PRState string // OPEN, MERGED, CLOSED - IsDraft bool - Mergeable string // MERGEABLE, CONFLICTING, UNKNOWN - CheckStatus string // SUCCESS, FAILURE, PENDING - ThreadsTotal int - ThreadsResolved int - LastCommitSHA string - LastCommitAt time.Time - LastReviewAt time.Time - NeedsCoding bool // true if child has no PR (work not started) - Assignee string // issue assignee username (for dispatch) - IssueTitle string // child issue title (for dispatch prompt) - IssueBody string // child issue body (for dispatch prompt) - Type string // signal type (e.g., "agent_completion") - Success bool // agent completion success flag - Error string // agent error message - Message string // agent completion message -} - -// RepoFullName returns "owner/repo". -func (s *PipelineSignal) RepoFullName() string { - return s.RepoOwner + "/" + s.RepoName -} - -// HasUnresolvedThreads returns true if there are unresolved review threads. -func (s *PipelineSignal) HasUnresolvedThreads() bool { - return s.ThreadsTotal > s.ThreadsResolved -} - -// ActionResult carries the outcome of a handler execution. -type ActionResult struct { - Action string `json:"action"` - RepoOwner string `json:"repo_owner"` - RepoName string `json:"repo_name"` - EpicNumber int `json:"epic"` - ChildNumber int `json:"child"` - PRNumber int `json:"pr"` - Success bool `json:"success"` - Error string `json:"error,omitempty"` - Timestamp time.Time `json:"ts"` - Duration time.Duration `json:"duration_ms"` - Cycle int `json:"cycle"` -} - -// JobSource discovers actionable work from an external system. -type JobSource interface { - Name() string - Poll(ctx context.Context) ([]*PipelineSignal, error) - Report(ctx context.Context, result *ActionResult) error -} - -// JobHandler processes a single pipeline signal. -type JobHandler interface { - Name() string - Match(signal *PipelineSignal) bool - Execute(ctx context.Context, signal *PipelineSignal) (*ActionResult, error) -} diff --git a/pkg/jobrunner/types_test.go b/pkg/jobrunner/types_test.go deleted file mode 100644 index c81a840..0000000 --- a/pkg/jobrunner/types_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package jobrunner - -import ( - "encoding/json" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestPipelineSignal_RepoFullName_Good(t *testing.T) { - sig := &PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-php", - } - assert.Equal(t, "host-uk/core-php", sig.RepoFullName()) -} - -func TestPipelineSignal_HasUnresolvedThreads_Good(t *testing.T) { - sig := &PipelineSignal{ - ThreadsTotal: 5, - ThreadsResolved: 3, - } - assert.True(t, sig.HasUnresolvedThreads()) -} - -func TestPipelineSignal_HasUnresolvedThreads_Bad_AllResolved(t *testing.T) { - sig := &PipelineSignal{ - ThreadsTotal: 4, - ThreadsResolved: 4, - } - assert.False(t, sig.HasUnresolvedThreads()) - - // Also verify zero threads is not unresolved. - sigZero := &PipelineSignal{ - ThreadsTotal: 0, - ThreadsResolved: 0, - } - assert.False(t, sigZero.HasUnresolvedThreads()) -} - -func TestActionResult_JSON_Good(t *testing.T) { - ts := time.Date(2026, 2, 5, 12, 0, 0, 0, time.UTC) - result := &ActionResult{ - Action: "merge", - RepoOwner: "host-uk", - RepoName: "core-tenant", - EpicNumber: 42, - ChildNumber: 7, - PRNumber: 99, - Success: true, - Timestamp: ts, - Duration: 1500 * time.Millisecond, - Cycle: 3, - } - - data, err := json.Marshal(result) - require.NoError(t, err) - - var decoded map[string]any - err = json.Unmarshal(data, &decoded) - require.NoError(t, err) - - assert.Equal(t, "merge", decoded["action"]) - assert.Equal(t, "host-uk", decoded["repo_owner"]) - assert.Equal(t, "core-tenant", decoded["repo_name"]) - assert.Equal(t, float64(42), decoded["epic"]) - assert.Equal(t, float64(7), decoded["child"]) - assert.Equal(t, float64(99), decoded["pr"]) - assert.Equal(t, true, decoded["success"]) - assert.Equal(t, float64(3), decoded["cycle"]) - - // Error field should be omitted when empty. - _, hasError := decoded["error"] - assert.False(t, hasError, "error field should be omitted when empty") - - // Verify round-trip with error field present. - resultWithErr := &ActionResult{ - Action: "merge", - RepoOwner: "host-uk", - RepoName: "core-tenant", - Success: false, - Error: "checks failing", - Timestamp: ts, - Duration: 200 * time.Millisecond, - Cycle: 1, - } - data2, err := json.Marshal(resultWithErr) - require.NoError(t, err) - - var decoded2 map[string]any - err = json.Unmarshal(data2, &decoded2) - require.NoError(t, err) - - assert.Equal(t, "checks failing", decoded2["error"]) - assert.Equal(t, false, decoded2["success"]) -} diff --git a/pkg/lifecycle/allowance.go b/pkg/lifecycle/allowance.go deleted file mode 100644 index 8310c55..0000000 --- a/pkg/lifecycle/allowance.go +++ /dev/null @@ -1,335 +0,0 @@ -package lifecycle - -import ( - "iter" - "sync" - "time" -) - -// AllowanceStatus indicates the current state of an agent's quota. -type AllowanceStatus string - -const ( - // AllowanceOK indicates the agent has remaining quota. - AllowanceOK AllowanceStatus = "ok" - // AllowanceWarning indicates the agent is at 80%+ usage. - AllowanceWarning AllowanceStatus = "warning" - // AllowanceExceeded indicates the agent has exceeded its quota. - AllowanceExceeded AllowanceStatus = "exceeded" -) - -// AgentAllowance defines the quota limits for a single agent. -type AgentAllowance struct { - // AgentID is the unique identifier for the agent. - AgentID string `json:"agent_id" yaml:"agent_id"` - // DailyTokenLimit is the maximum tokens (in+out) per 24h. 0 means unlimited. - DailyTokenLimit int64 `json:"daily_token_limit" yaml:"daily_token_limit"` - // DailyJobLimit is the maximum jobs per 24h. 0 means unlimited. - DailyJobLimit int `json:"daily_job_limit" yaml:"daily_job_limit"` - // ConcurrentJobs is the maximum simultaneous jobs. 0 means unlimited. - ConcurrentJobs int `json:"concurrent_jobs" yaml:"concurrent_jobs"` - // MaxJobDuration is the maximum job duration before kill. 0 means unlimited. - MaxJobDuration time.Duration `json:"max_job_duration" yaml:"max_job_duration"` - // ModelAllowlist restricts which models this agent can use. Empty means all. - ModelAllowlist []string `json:"model_allowlist,omitempty" yaml:"model_allowlist"` -} - -// ModelQuota defines global per-model limits across all agents. -type ModelQuota struct { - // Model is the model identifier (e.g. "claude-sonnet-4-5-20250929"). - Model string `json:"model" yaml:"model"` - // DailyTokenBudget is the total tokens across all agents per 24h. - DailyTokenBudget int64 `json:"daily_token_budget" yaml:"daily_token_budget"` - // HourlyRateLimit is the max requests per hour. - // Reserved: stored but not yet enforced in AllowanceService.Check. - // Enforcement requires AllowanceStore.GetHourlyUsage (sliding window). - HourlyRateLimit int `json:"hourly_rate_limit" yaml:"hourly_rate_limit"` - // CostCeiling stops all usage if cumulative cost exceeds this (in cents). - // Reserved: stored but not yet enforced in AllowanceService.Check. - CostCeiling int64 `json:"cost_ceiling" yaml:"cost_ceiling"` -} - -// RepoLimit defines per-repository rate limits. -type RepoLimit struct { - // Repo is the repository identifier (e.g. "owner/repo"). - Repo string `json:"repo" yaml:"repo"` - // MaxDailyPRs is the maximum PRs per day. 0 means unlimited. - MaxDailyPRs int `json:"max_daily_prs" yaml:"max_daily_prs"` - // MaxDailyIssues is the maximum issues per day. 0 means unlimited. - MaxDailyIssues int `json:"max_daily_issues" yaml:"max_daily_issues"` - // CooldownAfterFailure is the wait time after a failure before retrying. - CooldownAfterFailure time.Duration `json:"cooldown_after_failure" yaml:"cooldown_after_failure"` -} - -// UsageRecord tracks an agent's current usage within a quota period. -type UsageRecord struct { - // AgentID is the agent this record belongs to. - AgentID string `json:"agent_id"` - // TokensUsed is the total tokens consumed in the current period. - TokensUsed int64 `json:"tokens_used"` - // JobsStarted is the total jobs started in the current period. - JobsStarted int `json:"jobs_started"` - // ActiveJobs is the number of currently running jobs. - ActiveJobs int `json:"active_jobs"` - // PeriodStart is when the current quota period began. - PeriodStart time.Time `json:"period_start"` -} - -// QuotaCheckResult is the outcome of a pre-dispatch allowance check. -type QuotaCheckResult struct { - // Allowed indicates whether the agent may proceed. - Allowed bool `json:"allowed"` - // Status is the current allowance state. - Status AllowanceStatus `json:"status"` - // Remaining is the number of tokens remaining in the period. - RemainingTokens int64 `json:"remaining_tokens"` - // RemainingJobs is the number of jobs remaining in the period. - RemainingJobs int `json:"remaining_jobs"` - // Reason explains why the check failed (if !Allowed). - Reason string `json:"reason,omitempty"` -} - -// QuotaEvent represents a change in quota usage, used for recovery. -type QuotaEvent string - -const ( - // QuotaEventJobStarted deducts quota when a job begins. - QuotaEventJobStarted QuotaEvent = "job_started" - // QuotaEventJobCompleted deducts nothing (already counted). - QuotaEventJobCompleted QuotaEvent = "job_completed" - // QuotaEventJobFailed returns 50% of token quota. - QuotaEventJobFailed QuotaEvent = "job_failed" - // QuotaEventJobCancelled returns 100% of token quota. - QuotaEventJobCancelled QuotaEvent = "job_cancelled" -) - -// UsageReport is emitted by the agent runner to report token consumption. -type UsageReport struct { - // AgentID is the agent that consumed tokens. - AgentID string `json:"agent_id"` - // JobID identifies the specific job. - JobID string `json:"job_id"` - // Model is the model used. - Model string `json:"model"` - // TokensIn is the number of input tokens consumed. - TokensIn int64 `json:"tokens_in"` - // TokensOut is the number of output tokens consumed. - TokensOut int64 `json:"tokens_out"` - // Event is the type of quota event. - Event QuotaEvent `json:"event"` - // Timestamp is when the usage occurred. - Timestamp time.Time `json:"timestamp"` -} - -// AllowanceStore is the interface for persisting and querying allowance data. -// Implementations may use Redis, SQLite, or any backing store. -type AllowanceStore interface { - // GetAllowance returns the quota limits for an agent. - GetAllowance(agentID string) (*AgentAllowance, error) - // SetAllowance persists quota limits for an agent. - SetAllowance(a *AgentAllowance) error - // Allowances returns an iterator over all agent allowances. - Allowances() iter.Seq[*AgentAllowance] - // GetUsage returns the current usage record for an agent. - GetUsage(agentID string) (*UsageRecord, error) - // Usages returns an iterator over all usage records. - Usages() iter.Seq[*UsageRecord] - // IncrementUsage atomically adds to an agent's usage counters. - IncrementUsage(agentID string, tokens int64, jobs int) error - // DecrementActiveJobs reduces the active job count by 1. - DecrementActiveJobs(agentID string) error - // ReturnTokens adds tokens back to the agent's remaining quota. - ReturnTokens(agentID string, tokens int64) error - // ResetUsage clears usage counters for an agent (daily reset). - ResetUsage(agentID string) error - // GetModelQuota returns global limits for a model. - GetModelQuota(model string) (*ModelQuota, error) - // GetModelUsage returns current token usage for a model. - GetModelUsage(model string) (int64, error) - // IncrementModelUsage atomically adds to a model's usage counter. - IncrementModelUsage(model string, tokens int64) error -} - -// MemoryStore is an in-memory AllowanceStore for testing and single-node use. -type MemoryStore struct { - mu sync.RWMutex - allowances map[string]*AgentAllowance - usage map[string]*UsageRecord - modelQuotas map[string]*ModelQuota - modelUsage map[string]int64 -} - -// NewMemoryStore creates a new in-memory allowance store. -func NewMemoryStore() *MemoryStore { - return &MemoryStore{ - allowances: make(map[string]*AgentAllowance), - usage: make(map[string]*UsageRecord), - modelQuotas: make(map[string]*ModelQuota), - modelUsage: make(map[string]int64), - } -} - -// GetAllowance returns the quota limits for an agent. -func (m *MemoryStore) GetAllowance(agentID string) (*AgentAllowance, error) { - m.mu.RLock() - defer m.mu.RUnlock() - a, ok := m.allowances[agentID] - if !ok { - return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID} - } - cp := *a - return &cp, nil -} - -// SetAllowance persists quota limits for an agent. -func (m *MemoryStore) SetAllowance(a *AgentAllowance) error { - m.mu.Lock() - defer m.mu.Unlock() - cp := *a - m.allowances[a.AgentID] = &cp - return nil -} - -// Allowances returns an iterator over all agent allowances. -func (m *MemoryStore) Allowances() iter.Seq[*AgentAllowance] { - return func(yield func(*AgentAllowance) bool) { - m.mu.RLock() - defer m.mu.RUnlock() - for _, a := range m.allowances { - cp := *a - if !yield(&cp) { - return - } - } - } -} - -// GetUsage returns the current usage record for an agent. -func (m *MemoryStore) GetUsage(agentID string) (*UsageRecord, error) { - m.mu.RLock() - defer m.mu.RUnlock() - u, ok := m.usage[agentID] - if !ok { - return &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - }, nil - } - cp := *u - return &cp, nil -} - -// Usages returns an iterator over all usage records. -func (m *MemoryStore) Usages() iter.Seq[*UsageRecord] { - return func(yield func(*UsageRecord) bool) { - m.mu.RLock() - defer m.mu.RUnlock() - for _, u := range m.usage { - cp := *u - if !yield(&cp) { - return - } - } - } -} - -// IncrementUsage atomically adds to an agent's usage counters. -func (m *MemoryStore) IncrementUsage(agentID string, tokens int64, jobs int) error { - m.mu.Lock() - defer m.mu.Unlock() - u, ok := m.usage[agentID] - if !ok { - u = &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - } - m.usage[agentID] = u - } - u.TokensUsed += tokens - u.JobsStarted += jobs - if jobs > 0 { - u.ActiveJobs += jobs - } - return nil -} - -// DecrementActiveJobs reduces the active job count by 1. -func (m *MemoryStore) DecrementActiveJobs(agentID string) error { - m.mu.Lock() - defer m.mu.Unlock() - u, ok := m.usage[agentID] - if !ok { - return nil - } - if u.ActiveJobs > 0 { - u.ActiveJobs-- - } - return nil -} - -// ReturnTokens adds tokens back to the agent's remaining quota. -func (m *MemoryStore) ReturnTokens(agentID string, tokens int64) error { - m.mu.Lock() - defer m.mu.Unlock() - u, ok := m.usage[agentID] - if !ok { - return nil - } - u.TokensUsed -= tokens - if u.TokensUsed < 0 { - u.TokensUsed = 0 - } - return nil -} - -// ResetUsage clears usage counters for an agent. -func (m *MemoryStore) ResetUsage(agentID string) error { - m.mu.Lock() - defer m.mu.Unlock() - m.usage[agentID] = &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - } - return nil -} - -// GetModelQuota returns global limits for a model. -func (m *MemoryStore) GetModelQuota(model string) (*ModelQuota, error) { - m.mu.RLock() - defer m.mu.RUnlock() - q, ok := m.modelQuotas[model] - if !ok { - return nil, &APIError{Code: 404, Message: "model quota not found: " + model} - } - cp := *q - return &cp, nil -} - -// GetModelUsage returns current token usage for a model. -func (m *MemoryStore) GetModelUsage(model string) (int64, error) { - m.mu.RLock() - defer m.mu.RUnlock() - return m.modelUsage[model], nil -} - -// IncrementModelUsage atomically adds to a model's usage counter. -func (m *MemoryStore) IncrementModelUsage(model string, tokens int64) error { - m.mu.Lock() - defer m.mu.Unlock() - m.modelUsage[model] += tokens - return nil -} - -// SetModelQuota sets global limits for a model (used in testing). -func (m *MemoryStore) SetModelQuota(q *ModelQuota) { - m.mu.Lock() - defer m.mu.Unlock() - cp := *q - m.modelQuotas[q.Model] = &cp -} - -// startOfDay returns midnight UTC for the given time. -func startOfDay(t time.Time) time.Time { - y, mo, d := t.Date() - return time.Date(y, mo, d, 0, 0, 0, 0, time.UTC) -} diff --git a/pkg/lifecycle/allowance_edge_test.go b/pkg/lifecycle/allowance_edge_test.go deleted file mode 100644 index 8374d02..0000000 --- a/pkg/lifecycle/allowance_edge_test.go +++ /dev/null @@ -1,662 +0,0 @@ -package lifecycle - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- Allowance exhaustion edge cases --- - -func TestAllowanceExhaustion_ExactlyAtTokenLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - DailyTokenLimit: 10000, - }) - // Use exactly the limit. - _ = store.IncrementUsage("edge-agent", 10000, 0) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.False(t, result.Allowed, "should be denied at exactly the limit") - assert.Equal(t, AllowanceExceeded, result.Status) - assert.Equal(t, int64(0), result.RemainingTokens) - assert.Contains(t, result.Reason, "daily token limit exceeded") -} - -func TestAllowanceExhaustion_OneOverTokenLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - DailyTokenLimit: 10000, - }) - _ = store.IncrementUsage("edge-agent", 10001, 0) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - assert.Equal(t, AllowanceExceeded, result.Status) - assert.True(t, result.RemainingTokens < 0, "remaining should be negative") -} - -func TestAllowanceExhaustion_OneUnderTokenLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - DailyTokenLimit: 10000, - }) - _ = store.IncrementUsage("edge-agent", 9999, 0) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed, "should be allowed with 1 token remaining") - assert.Equal(t, AllowanceWarning, result.Status, "99.99% usage should be warning") - assert.Equal(t, int64(1), result.RemainingTokens) -} - -func TestAllowanceExhaustion_ZeroAllowance(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - // DailyTokenLimit=0 means unlimited. - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "unlimited-agent", - DailyTokenLimit: 0, - DailyJobLimit: 0, - ConcurrentJobs: 0, - }) - _ = store.IncrementUsage("unlimited-agent", 999999999, 999) - - result, err := svc.Check("unlimited-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed, "unlimited agent should always be allowed") - assert.Equal(t, AllowanceOK, result.Status) - assert.Equal(t, int64(-1), result.RemainingTokens, "unlimited should show -1") - assert.Equal(t, -1, result.RemainingJobs, "unlimited should show -1") -} - -func TestAllowanceExhaustion_ExactlyAtJobLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - DailyJobLimit: 5, - }) - _ = store.IncrementUsage("edge-agent", 0, 5) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.False(t, result.Allowed, "should be denied at exactly the job limit") - assert.Equal(t, AllowanceExceeded, result.Status) - assert.Equal(t, 0, result.RemainingJobs) - assert.Contains(t, result.Reason, "daily job limit exceeded") -} - -func TestAllowanceExhaustion_OneUnderJobLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - DailyJobLimit: 5, - }) - _ = store.IncrementUsage("edge-agent", 0, 4) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed, "should be allowed with 1 job remaining") - assert.Equal(t, 1, result.RemainingJobs) -} - -func TestAllowanceExhaustion_ConcurrentJobsExactlyAtLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - ConcurrentJobs: 2, - }) - // Start 2 concurrent jobs. - _ = store.IncrementUsage("edge-agent", 0, 2) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.False(t, result.Allowed, "should be denied at concurrent limit") - assert.Contains(t, result.Reason, "concurrent job limit reached") -} - -func TestAllowanceExhaustion_ConcurrentJobsOneUnderLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - ConcurrentJobs: 3, - }) - _ = store.IncrementUsage("edge-agent", 0, 2) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed, "should be allowed with 1 concurrent slot remaining") -} - -func TestAllowanceExhaustion_ConcurrentJobsFreedByCompletion(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "edge-agent", - ConcurrentJobs: 1, - }) - - // Start a job - fills the slot. - _ = svc.RecordUsage(UsageReport{ - AgentID: "edge-agent", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - - result, err := svc.Check("edge-agent", "") - require.NoError(t, err) - assert.False(t, result.Allowed, "should be denied, 1 active job") - - // Complete the job - frees the slot. - _ = svc.RecordUsage(UsageReport{ - AgentID: "edge-agent", - JobID: "job-1", - TokensIn: 100, - TokensOut: 50, - Event: QuotaEventJobCompleted, - }) - - result, err = svc.Check("edge-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed, "should be allowed after job completes") -} - -func TestAllowanceExhaustion_TokenWarningThreshold(t *testing.T) { - tests := []struct { - name string - limit int64 - used int64 - expectedStatus AllowanceStatus - expectedAllow bool - }{ - { - name: "79% usage is OK", - limit: 10000, - used: 7900, - expectedStatus: AllowanceOK, - expectedAllow: true, - }, - { - name: "80% usage is warning", - limit: 10000, - used: 8000, - expectedStatus: AllowanceWarning, - expectedAllow: true, - }, - { - name: "90% usage is warning", - limit: 10000, - used: 9000, - expectedStatus: AllowanceWarning, - expectedAllow: true, - }, - { - name: "99% usage is warning", - limit: 10000, - used: 9999, - expectedStatus: AllowanceWarning, - expectedAllow: true, - }, - { - name: "100% usage is exceeded", - limit: 10000, - used: 10000, - expectedStatus: AllowanceExceeded, - expectedAllow: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "threshold-agent", - DailyTokenLimit: tt.limit, - }) - _ = store.IncrementUsage("threshold-agent", tt.used, 0) - - result, err := svc.Check("threshold-agent", "") - require.NoError(t, err) - assert.Equal(t, tt.expectedAllow, result.Allowed) - assert.Equal(t, tt.expectedStatus, result.Status) - }) - } -} - -func TestAllowanceExhaustion_ResetRestoresCapacity(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "reset-agent", - DailyTokenLimit: 10000, - DailyJobLimit: 5, - }) - - // Exhaust all limits. - _ = store.IncrementUsage("reset-agent", 10000, 5) - - result, err := svc.Check("reset-agent", "") - require.NoError(t, err) - assert.False(t, result.Allowed, "should be denied when exhausted") - - // Reset the agent (simulates midnight reset). - err = svc.ResetAgent("reset-agent") - require.NoError(t, err) - - result, err = svc.Check("reset-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed, "should be allowed after reset") - assert.Equal(t, int64(10000), result.RemainingTokens) - assert.Equal(t, 5, result.RemainingJobs) -} - -func TestAllowanceExhaustion_GlobalModelBudgetExactlyAtLimit(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "model-edge-agent", - }) - store.SetModelQuota(&ModelQuota{ - Model: "claude-opus-4-6", - DailyTokenBudget: 50000, - }) - _ = store.IncrementModelUsage("claude-opus-4-6", 50000) - - result, err := svc.Check("model-edge-agent", "claude-opus-4-6") - require.NoError(t, err) - assert.False(t, result.Allowed, "should be denied at exact model budget") - assert.Contains(t, result.Reason, "global model token budget exceeded") -} - -func TestAllowanceExhaustion_GlobalModelBudgetOneUnder(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "model-edge-agent", - }) - store.SetModelQuota(&ModelQuota{ - Model: "claude-opus-4-6", - DailyTokenBudget: 50000, - }) - _ = store.IncrementModelUsage("claude-opus-4-6", 49999) - - result, err := svc.Check("model-edge-agent", "claude-opus-4-6") - require.NoError(t, err) - assert.True(t, result.Allowed, "should be allowed with 1 token remaining in model budget") -} - -func TestAllowanceExhaustion_FailedJobWithZeroTokens(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = svc.RecordUsage(UsageReport{ - AgentID: "zero-token-agent", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - - // Job fails but consumed zero tokens. - err := svc.RecordUsage(UsageReport{ - AgentID: "zero-token-agent", - JobID: "job-1", - Model: "claude-sonnet", - TokensIn: 0, - TokensOut: 0, - Event: QuotaEventJobFailed, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("zero-token-agent") - assert.Equal(t, int64(0), usage.TokensUsed, "no tokens should be charged") - assert.Equal(t, 0, usage.ActiveJobs) - - // Model usage should be zero too (50% of 0 = 0). - modelUsage, _ := store.GetModelUsage("claude-sonnet") - assert.Equal(t, int64(0), modelUsage) -} - -func TestAllowanceExhaustion_CancelledJobWithZeroTokens(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = svc.RecordUsage(UsageReport{ - AgentID: "zero-token-agent", - JobID: "job-2", - Event: QuotaEventJobStarted, - }) - - // Job cancelled with zero tokens. - err := svc.RecordUsage(UsageReport{ - AgentID: "zero-token-agent", - JobID: "job-2", - TokensIn: 0, - TokensOut: 0, - Event: QuotaEventJobCancelled, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("zero-token-agent") - assert.Equal(t, int64(0), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) -} - -func TestAllowanceExhaustion_CompletedJobWithNoModel(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = svc.RecordUsage(UsageReport{ - AgentID: "no-model-agent", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - - // Complete with empty model -- should skip model-level usage recording. - err := svc.RecordUsage(UsageReport{ - AgentID: "no-model-agent", - JobID: "job-1", - Model: "", - TokensIn: 500, - TokensOut: 200, - Event: QuotaEventJobCompleted, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("no-model-agent") - assert.Equal(t, int64(700), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) -} - -func TestAllowanceExhaustion_FailedJobWithNoModel(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = svc.RecordUsage(UsageReport{ - AgentID: "no-model-fail-agent", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - - // Fail with empty model. - err := svc.RecordUsage(UsageReport{ - AgentID: "no-model-fail-agent", - JobID: "job-1", - Model: "", - TokensIn: 600, - TokensOut: 400, - Event: QuotaEventJobFailed, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("no-model-fail-agent") - // 1000 tokens used, 500 returned = 500 net. - assert.Equal(t, int64(500), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) -} - -func TestAllowanceExhaustion_MultipleChecksWithIncrementalUsage(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "incremental-agent", - DailyTokenLimit: 1000, - }) - - // First check: fresh agent. - result, err := svc.Check("incremental-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceOK, result.Status) - assert.Equal(t, int64(1000), result.RemainingTokens) - - // Use 500 tokens. - _ = store.IncrementUsage("incremental-agent", 500, 0) - - result, err = svc.Check("incremental-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceOK, result.Status) - assert.Equal(t, int64(500), result.RemainingTokens) - - // Use another 300 tokens (total 800, at 80% threshold). - _ = store.IncrementUsage("incremental-agent", 300, 0) - - result, err = svc.Check("incremental-agent", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceWarning, result.Status) - assert.Equal(t, int64(200), result.RemainingTokens) - - // Use remaining 200 tokens (total 1000, at 100%). - _ = store.IncrementUsage("incremental-agent", 200, 0) - - result, err = svc.Check("incremental-agent", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - assert.Equal(t, AllowanceExceeded, result.Status) - assert.Equal(t, int64(0), result.RemainingTokens) -} - -// --- MemoryStore additional edge cases --- - -func TestMemoryStore_GetUsage_NewAgentReturnsDefaults(t *testing.T) { - store := NewMemoryStore() - - usage, err := store.GetUsage("brand-new-agent") - require.NoError(t, err) - assert.Equal(t, "brand-new-agent", usage.AgentID) - assert.Equal(t, int64(0), usage.TokensUsed) - assert.Equal(t, 0, usage.JobsStarted) - assert.Equal(t, 0, usage.ActiveJobs) - assert.Equal(t, startOfDay(time.Now().UTC()), usage.PeriodStart) -} - -func TestMemoryStore_ReturnTokens_NonexistentAgent(t *testing.T) { - store := NewMemoryStore() - - // ReturnTokens on a nonexistent agent should be a no-op. - err := store.ReturnTokens("ghost-agent", 5000) - require.NoError(t, err) -} - -func TestMemoryStore_DecrementActiveJobs_NonexistentAgent(t *testing.T) { - store := NewMemoryStore() - - // DecrementActiveJobs on a nonexistent agent should be a no-op. - err := store.DecrementActiveJobs("ghost-agent") - require.NoError(t, err) -} - -func TestMemoryStore_GetModelQuota_NotFound(t *testing.T) { - store := NewMemoryStore() - - _, err := store.GetModelQuota("nonexistent-model") - require.Error(t, err) - assert.Contains(t, err.Error(), "model quota not found") -} - -func TestMemoryStore_GetModelUsage_NewModelReturnsZero(t *testing.T) { - store := NewMemoryStore() - - usage, err := store.GetModelUsage("brand-new-model") - require.NoError(t, err) - assert.Equal(t, int64(0), usage) -} - -func TestMemoryStore_SetAllowance_Overwrite(t *testing.T) { - store := NewMemoryStore() - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "overwrite-agent", - DailyTokenLimit: 5000, - }) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "overwrite-agent", - DailyTokenLimit: 9000, - }) - - a, err := store.GetAllowance("overwrite-agent") - require.NoError(t, err) - assert.Equal(t, int64(9000), a.DailyTokenLimit, "should have overwritten the old allowance") -} - -func TestMemoryStore_SetAllowance_IsolatesOriginal(t *testing.T) { - store := NewMemoryStore() - - original := &AgentAllowance{ - AgentID: "isolated-agent", - DailyTokenLimit: 5000, - } - _ = store.SetAllowance(original) - - // Mutate the original. - original.DailyTokenLimit = 99999 - - got, err := store.GetAllowance("isolated-agent") - require.NoError(t, err) - assert.Equal(t, int64(5000), got.DailyTokenLimit, "store should hold a copy, not the original") -} - -func TestMemoryStore_GetAllowance_IsolatesReturn(t *testing.T) { - store := NewMemoryStore() - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "isolated-agent", - DailyTokenLimit: 5000, - }) - - got1, _ := store.GetAllowance("isolated-agent") - got1.DailyTokenLimit = 99999 - - got2, _ := store.GetAllowance("isolated-agent") - assert.Equal(t, int64(5000), got2.DailyTokenLimit, "returned value should be a copy") -} - -func TestMemoryStore_IncrementUsage_MultipleIncrements(t *testing.T) { - store := NewMemoryStore() - - _ = store.IncrementUsage("multi-agent", 100, 1) - _ = store.IncrementUsage("multi-agent", 200, 1) - _ = store.IncrementUsage("multi-agent", 300, 0) - - usage, err := store.GetUsage("multi-agent") - require.NoError(t, err) - assert.Equal(t, int64(600), usage.TokensUsed) - assert.Equal(t, 2, usage.JobsStarted) - assert.Equal(t, 2, usage.ActiveJobs) -} - -func TestMemoryStore_IncrementUsage_ZeroJobsDoesNotIncrementActive(t *testing.T) { - store := NewMemoryStore() - - _ = store.IncrementUsage("token-only-agent", 5000, 0) - - usage, err := store.GetUsage("token-only-agent") - require.NoError(t, err) - assert.Equal(t, int64(5000), usage.TokensUsed) - assert.Equal(t, 0, usage.JobsStarted) - assert.Equal(t, 0, usage.ActiveJobs, "zero jobs should not increment active count") -} - -// --- AllowanceService Check priority ordering --- - -func TestAllowanceServiceCheck_ModelAllowlistCheckedFirst(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - // Agent is over token limit AND using a disallowed model. - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "order-agent", - DailyTokenLimit: 1000, - ModelAllowlist: []string{"claude-haiku"}, - }) - _ = store.IncrementUsage("order-agent", 2000, 0) - - result, err := svc.Check("order-agent", "claude-opus-4-6") - require.NoError(t, err) - assert.False(t, result.Allowed) - // Model allowlist is checked first in the code, so it should be the reason. - assert.Contains(t, result.Reason, "model not in allowlist") -} - -func TestAllowanceServiceCheck_EmptyModelAllowlistPermitsAll(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "any-model-agent", - ModelAllowlist: nil, - }) - - result, err := svc.Check("any-model-agent", "any-model-at-all") - require.NoError(t, err) - assert.True(t, result.Allowed) -} - -// --- QuotaEvent constants --- - -func TestQuotaEvent_Values(t *testing.T) { - assert.Equal(t, QuotaEvent("job_started"), QuotaEventJobStarted) - assert.Equal(t, QuotaEvent("job_completed"), QuotaEventJobCompleted) - assert.Equal(t, QuotaEvent("job_failed"), QuotaEventJobFailed) - assert.Equal(t, QuotaEvent("job_cancelled"), QuotaEventJobCancelled) -} - -func TestAllowanceExhaustion_FailedJobWithOddTokenCount(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = svc.RecordUsage(UsageReport{ - AgentID: "odd-agent", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - - // Odd total: 7 tokens. 50% return = 3 (integer division). - err := svc.RecordUsage(UsageReport{ - AgentID: "odd-agent", - JobID: "job-1", - Model: "claude-sonnet", - TokensIn: 4, - TokensOut: 3, - Event: QuotaEventJobFailed, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("odd-agent") - // 7 charged - 3 returned = 4 net. - assert.Equal(t, int64(4), usage.TokensUsed) - - // Model gets 7 - 3 = 4. - modelUsage, _ := store.GetModelUsage("claude-sonnet") - assert.Equal(t, int64(4), modelUsage) -} diff --git a/pkg/lifecycle/allowance_error_test.go b/pkg/lifecycle/allowance_error_test.go deleted file mode 100644 index 63eca7c..0000000 --- a/pkg/lifecycle/allowance_error_test.go +++ /dev/null @@ -1,272 +0,0 @@ -package lifecycle - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// errorStore is a mock AllowanceStore that returns errors for specific operations. -type errorStore struct { - *MemoryStore - failIncrementUsage bool - failDecrementActive bool - failReturnTokens bool - failIncrementModel bool - failGetAllowance bool - failGetUsage bool -} - -func newErrorStore() *errorStore { - return &errorStore{MemoryStore: NewMemoryStore()} -} - -func (e *errorStore) GetAllowance(agentID string) (*AgentAllowance, error) { - if e.failGetAllowance { - return nil, errors.New("simulated GetAllowance error") - } - return e.MemoryStore.GetAllowance(agentID) -} - -func (e *errorStore) GetUsage(agentID string) (*UsageRecord, error) { - if e.failGetUsage { - return nil, errors.New("simulated GetUsage error") - } - return e.MemoryStore.GetUsage(agentID) -} - -func (e *errorStore) IncrementUsage(agentID string, tokens int64, jobs int) error { - if e.failIncrementUsage { - return errors.New("simulated IncrementUsage error") - } - return e.MemoryStore.IncrementUsage(agentID, tokens, jobs) -} - -func (e *errorStore) DecrementActiveJobs(agentID string) error { - if e.failDecrementActive { - return errors.New("simulated DecrementActiveJobs error") - } - return e.MemoryStore.DecrementActiveJobs(agentID) -} - -func (e *errorStore) ReturnTokens(agentID string, tokens int64) error { - if e.failReturnTokens { - return errors.New("simulated ReturnTokens error") - } - return e.MemoryStore.ReturnTokens(agentID, tokens) -} - -func (e *errorStore) IncrementModelUsage(model string, tokens int64) error { - if e.failIncrementModel { - return errors.New("simulated IncrementModelUsage error") - } - return e.MemoryStore.IncrementModelUsage(model, tokens) -} - -// --- RecordUsage error path tests --- - -func TestRecordUsage_Bad_JobStarted_IncrementFails(t *testing.T) { - store := newErrorStore() - store.failIncrementUsage = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - Event: QuotaEventJobStarted, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to increment job count") -} - -func TestRecordUsage_Bad_JobCompleted_IncrementFails(t *testing.T) { - store := newErrorStore() - store.failIncrementUsage = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - TokensIn: 100, - TokensOut: 50, - Event: QuotaEventJobCompleted, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to record token usage") -} - -func TestRecordUsage_Bad_JobCompleted_DecrementFails(t *testing.T) { - store := newErrorStore() - store.failDecrementActive = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - TokensIn: 100, - TokensOut: 50, - Event: QuotaEventJobCompleted, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to decrement active jobs") -} - -func TestRecordUsage_Bad_JobCompleted_ModelUsageFails(t *testing.T) { - store := newErrorStore() - store.failIncrementModel = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - Model: "claude-sonnet", - TokensIn: 100, - TokensOut: 50, - Event: QuotaEventJobCompleted, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to record model usage") -} - -func TestRecordUsage_Bad_JobFailed_IncrementFails(t *testing.T) { - store := newErrorStore() - store.failIncrementUsage = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - TokensIn: 100, - TokensOut: 100, - Event: QuotaEventJobFailed, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to record token usage") -} - -func TestRecordUsage_Bad_JobFailed_DecrementFails(t *testing.T) { - store := newErrorStore() - store.failDecrementActive = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - TokensIn: 100, - TokensOut: 100, - Event: QuotaEventJobFailed, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to decrement active jobs") -} - -func TestRecordUsage_Bad_JobFailed_ReturnTokensFails(t *testing.T) { - store := newErrorStore() - store.failReturnTokens = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - TokensIn: 100, - TokensOut: 100, - Event: QuotaEventJobFailed, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to return tokens") -} - -func TestRecordUsage_Bad_JobFailed_ModelUsageFails(t *testing.T) { - store := newErrorStore() - store.failIncrementModel = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - Model: "claude-sonnet", - TokensIn: 100, - TokensOut: 100, - Event: QuotaEventJobFailed, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to record model usage") -} - -func TestRecordUsage_Bad_JobCancelled_DecrementFails(t *testing.T) { - store := newErrorStore() - store.failDecrementActive = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - TokensIn: 100, - TokensOut: 100, - Event: QuotaEventJobCancelled, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to decrement active jobs") -} - -func TestRecordUsage_Bad_JobCancelled_ReturnTokensFails(t *testing.T) { - store := newErrorStore() - store.failReturnTokens = true - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - TokensIn: 100, - TokensOut: 100, - Event: QuotaEventJobCancelled, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to return tokens") -} - -// --- Check error path tests --- - -func TestCheck_Bad_GetAllowanceFails(t *testing.T) { - store := newErrorStore() - store.failGetAllowance = true - svc := NewAllowanceService(store) - - _, err := svc.Check("agent-1", "") - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get allowance") -} - -func TestCheck_Bad_GetUsageFails(t *testing.T) { - store := newErrorStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - }) - store.failGetUsage = true - - _, err := svc.Check("agent-1", "") - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to get usage") -} - -// --- ResetAgent error path --- - -func TestResetAgent_Bad_ResetFails(t *testing.T) { - // MemoryStore.ResetUsage never fails, but we can test the service - // layer still returns nil for the happy path (already tested). - // For a true error test, we'd need a mock, but the MemoryStore - // never errors on ResetUsage. This confirms the pattern. - store := NewMemoryStore() - svc := NewAllowanceService(store) - - err := svc.ResetAgent("nonexistent-agent") - require.NoError(t, err, "resetting a nonexistent agent should succeed") -} - -// --- RecordUsage with unknown event type --- - -func TestRecordUsage_Good_UnknownEvent(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - // Unknown event should be a no-op (falls through the switch). - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - Event: QuotaEvent("unknown_event"), - }) - require.NoError(t, err, "unknown event should not error") -} diff --git a/pkg/lifecycle/allowance_redis.go b/pkg/lifecycle/allowance_redis.go deleted file mode 100644 index cbdd119..0000000 --- a/pkg/lifecycle/allowance_redis.go +++ /dev/null @@ -1,409 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "errors" - "iter" - "time" - - "github.com/redis/go-redis/v9" -) - -// RedisStore implements AllowanceStore using Redis as the backing store. -// It provides persistent, network-accessible storage suitable for multi-node -// deployments where agents share quota state. -type RedisStore struct { - client *redis.Client - prefix string -} - -// Allowances returns an iterator over all agent allowances. -func (r *RedisStore) Allowances() iter.Seq[*AgentAllowance] { - return func(yield func(*AgentAllowance) bool) { - ctx := context.Background() - pattern := r.prefix + ":allowance:*" - iter := r.client.Scan(ctx, 0, pattern, 100).Iterator() - for iter.Next(ctx) { - val, err := r.client.Get(ctx, iter.Val()).Result() - if err != nil { - continue - } - var aj allowanceJSON - if err := json.Unmarshal([]byte(val), &aj); err != nil { - continue - } - if !yield(aj.toAgentAllowance()) { - return - } - } - } -} - -// Usages returns an iterator over all usage records. -func (r *RedisStore) Usages() iter.Seq[*UsageRecord] { - return func(yield func(*UsageRecord) bool) { - ctx := context.Background() - pattern := r.prefix + ":usage:*" - iter := r.client.Scan(ctx, 0, pattern, 100).Iterator() - for iter.Next(ctx) { - val, err := r.client.Get(ctx, iter.Val()).Result() - if err != nil { - continue - } - var u UsageRecord - if err := json.Unmarshal([]byte(val), &u); err != nil { - continue - } - if !yield(&u) { - return - } - } - } -} - -// redisConfig holds the configuration for a RedisStore. -type redisConfig struct { - password string - db int - prefix string -} - -// RedisOption is a functional option for configuring a RedisStore. -type RedisOption func(*redisConfig) - -// WithRedisPassword sets the password for authenticating with Redis. -func WithRedisPassword(pw string) RedisOption { - return func(c *redisConfig) { - c.password = pw - } -} - -// WithRedisDB selects the Redis database number. -func WithRedisDB(db int) RedisOption { - return func(c *redisConfig) { - c.db = db - } -} - -// WithRedisPrefix sets the key prefix for all Redis keys. Default: "agentic". -func WithRedisPrefix(prefix string) RedisOption { - return func(c *redisConfig) { - c.prefix = prefix - } -} - -// NewRedisStore creates a new Redis-backed allowance store connecting to the -// given address (host:port). It pings the server to verify connectivity. -func NewRedisStore(addr string, opts ...RedisOption) (*RedisStore, error) { - cfg := &redisConfig{ - prefix: "agentic", - } - for _, opt := range opts { - opt(cfg) - } - - client := redis.NewClient(&redis.Options{ - Addr: addr, - Password: cfg.password, - DB: cfg.db, - }) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := client.Ping(ctx).Err(); err != nil { - _ = client.Close() - return nil, &APIError{Code: 500, Message: "failed to connect to Redis: " + err.Error()} - } - - return &RedisStore{ - client: client, - prefix: cfg.prefix, - }, nil -} - -// Close releases the underlying Redis connection. -func (r *RedisStore) Close() error { - return r.client.Close() -} - -// --- key helpers --- - -func (r *RedisStore) allowanceKey(agentID string) string { - return r.prefix + ":allowance:" + agentID -} - -func (r *RedisStore) usageKey(agentID string) string { - return r.prefix + ":usage:" + agentID -} - -func (r *RedisStore) modelQuotaKey(model string) string { - return r.prefix + ":model_quota:" + model -} - -func (r *RedisStore) modelUsageKey(model string) string { - return r.prefix + ":model_usage:" + model -} - -// --- AllowanceStore interface --- - -// GetAllowance returns the quota limits for an agent. -func (r *RedisStore) GetAllowance(agentID string) (*AgentAllowance, error) { - ctx := context.Background() - val, err := r.client.Get(ctx, r.allowanceKey(agentID)).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID} - } - return nil, &APIError{Code: 500, Message: "failed to get allowance: " + err.Error()} - } - var aj allowanceJSON - if err := json.Unmarshal([]byte(val), &aj); err != nil { - return nil, &APIError{Code: 500, Message: "failed to unmarshal allowance: " + err.Error()} - } - return aj.toAgentAllowance(), nil -} - -// SetAllowance persists quota limits for an agent. -func (r *RedisStore) SetAllowance(a *AgentAllowance) error { - ctx := context.Background() - aj := newAllowanceJSON(a) - data, err := json.Marshal(aj) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal allowance: " + err.Error()} - } - if err := r.client.Set(ctx, r.allowanceKey(a.AgentID), data, 0).Err(); err != nil { - return &APIError{Code: 500, Message: "failed to set allowance: " + err.Error()} - } - return nil -} - -// GetUsage returns the current usage record for an agent. -func (r *RedisStore) GetUsage(agentID string) (*UsageRecord, error) { - ctx := context.Background() - val, err := r.client.Get(ctx, r.usageKey(agentID)).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - return &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - }, nil - } - return nil, &APIError{Code: 500, Message: "failed to get usage: " + err.Error()} - } - var u UsageRecord - if err := json.Unmarshal([]byte(val), &u); err != nil { - return nil, &APIError{Code: 500, Message: "failed to unmarshal usage: " + err.Error()} - } - return &u, nil -} - -// incrementUsageLua atomically reads, increments, and writes back a usage record. -// KEYS[1] = usage key -// ARGV[1] = tokens to add (int64) -// ARGV[2] = jobs to add (int) -// ARGV[3] = agent ID (for creating a new record) -// ARGV[4] = period start ISO string (for creating a new record) -var incrementUsageLua = redis.NewScript(` -local val = redis.call('GET', KEYS[1]) -local u -if val then - u = cjson.decode(val) -else - u = { - agent_id = ARGV[3], - tokens_used = 0, - jobs_started = 0, - active_jobs = 0, - period_start = ARGV[4] - } -end -u.tokens_used = u.tokens_used + tonumber(ARGV[1]) -u.jobs_started = u.jobs_started + tonumber(ARGV[2]) -if tonumber(ARGV[2]) > 0 then - u.active_jobs = u.active_jobs + tonumber(ARGV[2]) -end -redis.call('SET', KEYS[1], cjson.encode(u)) -return 'OK' -`) - -// IncrementUsage atomically adds to an agent's usage counters. -func (r *RedisStore) IncrementUsage(agentID string, tokens int64, jobs int) error { - ctx := context.Background() - periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339) - err := incrementUsageLua.Run(ctx, r.client, - []string{r.usageKey(agentID)}, - tokens, jobs, agentID, periodStart, - ).Err() - if err != nil { - return &APIError{Code: 500, Message: "failed to increment usage: " + err.Error()} - } - return nil -} - -// decrementActiveJobsLua atomically decrements the active jobs counter, flooring at zero. -// KEYS[1] = usage key -// ARGV[1] = agent ID -// ARGV[2] = period start ISO string -var decrementActiveJobsLua = redis.NewScript(` -local val = redis.call('GET', KEYS[1]) -if not val then - return 'OK' -end -local u = cjson.decode(val) -if u.active_jobs and u.active_jobs > 0 then - u.active_jobs = u.active_jobs - 1 -end -redis.call('SET', KEYS[1], cjson.encode(u)) -return 'OK' -`) - -// DecrementActiveJobs reduces the active job count by 1. -func (r *RedisStore) DecrementActiveJobs(agentID string) error { - ctx := context.Background() - periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339) - err := decrementActiveJobsLua.Run(ctx, r.client, - []string{r.usageKey(agentID)}, - agentID, periodStart, - ).Err() - if err != nil { - return &APIError{Code: 500, Message: "failed to decrement active jobs: " + err.Error()} - } - return nil -} - -// returnTokensLua atomically subtracts tokens from usage, flooring at zero. -// KEYS[1] = usage key -// ARGV[1] = tokens to return (int64) -// ARGV[2] = agent ID -// ARGV[3] = period start ISO string -var returnTokensLua = redis.NewScript(` -local val = redis.call('GET', KEYS[1]) -if not val then - return 'OK' -end -local u = cjson.decode(val) -u.tokens_used = u.tokens_used - tonumber(ARGV[1]) -if u.tokens_used < 0 then - u.tokens_used = 0 -end -redis.call('SET', KEYS[1], cjson.encode(u)) -return 'OK' -`) - -// ReturnTokens adds tokens back to the agent's remaining quota. -func (r *RedisStore) ReturnTokens(agentID string, tokens int64) error { - ctx := context.Background() - periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339) - err := returnTokensLua.Run(ctx, r.client, - []string{r.usageKey(agentID)}, - tokens, agentID, periodStart, - ).Err() - if err != nil { - return &APIError{Code: 500, Message: "failed to return tokens: " + err.Error()} - } - return nil -} - -// ResetUsage clears usage counters for an agent (daily reset). -func (r *RedisStore) ResetUsage(agentID string) error { - ctx := context.Background() - u := &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - } - data, err := json.Marshal(u) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal usage: " + err.Error()} - } - if err := r.client.Set(ctx, r.usageKey(agentID), data, 0).Err(); err != nil { - return &APIError{Code: 500, Message: "failed to reset usage: " + err.Error()} - } - return nil -} - -// GetModelQuota returns global limits for a model. -func (r *RedisStore) GetModelQuota(model string) (*ModelQuota, error) { - ctx := context.Background() - val, err := r.client.Get(ctx, r.modelQuotaKey(model)).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - return nil, &APIError{Code: 404, Message: "model quota not found: " + model} - } - return nil, &APIError{Code: 500, Message: "failed to get model quota: " + err.Error()} - } - var q ModelQuota - if err := json.Unmarshal([]byte(val), &q); err != nil { - return nil, &APIError{Code: 500, Message: "failed to unmarshal model quota: " + err.Error()} - } - return &q, nil -} - -// GetModelUsage returns current token usage for a model. -func (r *RedisStore) GetModelUsage(model string) (int64, error) { - ctx := context.Background() - val, err := r.client.Get(ctx, r.modelUsageKey(model)).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - return 0, nil - } - return 0, &APIError{Code: 500, Message: "failed to get model usage: " + err.Error()} - } - var tokens int64 - if err := json.Unmarshal([]byte(val), &tokens); err != nil { - return 0, &APIError{Code: 500, Message: "failed to unmarshal model usage: " + err.Error()} - } - return tokens, nil -} - -// incrementModelUsageLua atomically increments the model usage counter. -// KEYS[1] = model usage key -// ARGV[1] = tokens to add -var incrementModelUsageLua = redis.NewScript(` -local val = redis.call('GET', KEYS[1]) -local current = 0 -if val then - current = tonumber(val) -end -current = current + tonumber(ARGV[1]) -redis.call('SET', KEYS[1], tostring(current)) -return current -`) - -// IncrementModelUsage atomically adds to a model's usage counter. -func (r *RedisStore) IncrementModelUsage(model string, tokens int64) error { - ctx := context.Background() - err := incrementModelUsageLua.Run(ctx, r.client, - []string{r.modelUsageKey(model)}, - tokens, - ).Err() - if err != nil { - return &APIError{Code: 500, Message: "failed to increment model usage: " + err.Error()} - } - return nil -} - -// SetModelQuota persists global limits for a model. -func (r *RedisStore) SetModelQuota(q *ModelQuota) error { - ctx := context.Background() - data, err := json.Marshal(q) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal model quota: " + err.Error()} - } - if err := r.client.Set(ctx, r.modelQuotaKey(q.Model), data, 0).Err(); err != nil { - return &APIError{Code: 500, Message: "failed to set model quota: " + err.Error()} - } - return nil -} - -// FlushPrefix deletes all keys matching the store's prefix. Useful for testing cleanup. -func (r *RedisStore) FlushPrefix(ctx context.Context) error { - iter := r.client.Scan(ctx, 0, r.prefix+":*", 100).Iterator() - for iter.Next(ctx) { - if err := r.client.Del(ctx, iter.Val()).Err(); err != nil { - return err - } - } - return iter.Err() -} diff --git a/pkg/lifecycle/allowance_redis_test.go b/pkg/lifecycle/allowance_redis_test.go deleted file mode 100644 index 3e026cb..0000000 --- a/pkg/lifecycle/allowance_redis_test.go +++ /dev/null @@ -1,454 +0,0 @@ -package lifecycle - -import ( - "context" - "fmt" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -const testRedisAddr = "10.69.69.87:6379" - -// newTestRedisStore creates a RedisStore with a unique prefix for test isolation. -// Skips the test if Redis is unreachable. -func newTestRedisStore(t *testing.T) *RedisStore { - t.Helper() - prefix := fmt.Sprintf("test_%d", time.Now().UnixNano()) - s, err := NewRedisStore(testRedisAddr, WithRedisPrefix(prefix)) - if err != nil { - t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) - } - t.Cleanup(func() { - ctx := context.Background() - _ = s.FlushPrefix(ctx) - _ = s.Close() - }) - return s -} - -// --- SetAllowance / GetAllowance --- - -func TestRedisStore_SetGetAllowance_Good(t *testing.T) { - s := newTestRedisStore(t) - - a := &AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - DailyJobLimit: 10, - ConcurrentJobs: 2, - MaxJobDuration: 30 * time.Minute, - ModelAllowlist: []string{"claude-sonnet-4-5-20250929"}, - } - - err := s.SetAllowance(a) - require.NoError(t, err) - - got, err := s.GetAllowance("agent-1") - require.NoError(t, err) - assert.Equal(t, a.AgentID, got.AgentID) - assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit) - assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit) - assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs) - assert.Equal(t, a.MaxJobDuration, got.MaxJobDuration) - assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist) -} - -func TestRedisStore_GetAllowance_Bad_NotFound(t *testing.T) { - s := newTestRedisStore(t) - _, err := s.GetAllowance("nonexistent") - require.Error(t, err) - apiErr, ok := err.(*APIError) - require.True(t, ok, "expected *APIError") - assert.Equal(t, 404, apiErr.Code) - assert.Contains(t, err.Error(), "allowance not found") -} - -func TestRedisStore_SetAllowance_Good_Overwrite(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100}) - _ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 200}) - - got, err := s.GetAllowance("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(200), got.DailyTokenLimit) -} - -// --- GetUsage / IncrementUsage --- - -func TestRedisStore_GetUsage_Good_Default(t *testing.T) { - s := newTestRedisStore(t) - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, "agent-1", u.AgentID) - assert.Equal(t, int64(0), u.TokensUsed) - assert.Equal(t, 0, u.JobsStarted) - assert.Equal(t, 0, u.ActiveJobs) -} - -func TestRedisStore_IncrementUsage_Good(t *testing.T) { - s := newTestRedisStore(t) - - err := s.IncrementUsage("agent-1", 5000, 1) - require.NoError(t, err) - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(5000), u.TokensUsed) - assert.Equal(t, 1, u.JobsStarted) - assert.Equal(t, 1, u.ActiveJobs) -} - -func TestRedisStore_IncrementUsage_Good_Accumulates(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.IncrementUsage("agent-1", 1000, 1) - _ = s.IncrementUsage("agent-1", 2000, 1) - _ = s.IncrementUsage("agent-1", 3000, 0) - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(6000), u.TokensUsed) - assert.Equal(t, 2, u.JobsStarted) - assert.Equal(t, 2, u.ActiveJobs) -} - -// --- DecrementActiveJobs --- - -func TestRedisStore_DecrementActiveJobs_Good(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.IncrementUsage("agent-1", 0, 2) - _ = s.DecrementActiveJobs("agent-1") - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, 1, u.ActiveJobs) -} - -func TestRedisStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.DecrementActiveJobs("agent-1") // no usage record yet - _ = s.IncrementUsage("agent-1", 0, 0) - _ = s.DecrementActiveJobs("agent-1") // should stay at 0 - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, 0, u.ActiveJobs) -} - -// --- ReturnTokens --- - -func TestRedisStore_ReturnTokens_Good(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.IncrementUsage("agent-1", 10000, 0) - err := s.ReturnTokens("agent-1", 5000) - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(5000), u.TokensUsed) -} - -func TestRedisStore_ReturnTokens_Good_FloorAtZero(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.IncrementUsage("agent-1", 1000, 0) - _ = s.ReturnTokens("agent-1", 5000) // more than used - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(0), u.TokensUsed) -} - -func TestRedisStore_ReturnTokens_Good_NoRecord(t *testing.T) { - s := newTestRedisStore(t) - - // Return tokens for agent with no usage record -- should be a no-op - err := s.ReturnTokens("agent-1", 500) - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(0), u.TokensUsed) -} - -// --- ResetUsage --- - -func TestRedisStore_ResetUsage_Good(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.IncrementUsage("agent-1", 50000, 5) - - err := s.ResetUsage("agent-1") - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(0), u.TokensUsed) - assert.Equal(t, 0, u.JobsStarted) - assert.Equal(t, 0, u.ActiveJobs) -} - -// --- ModelQuota --- - -func TestRedisStore_GetModelQuota_Bad_NotFound(t *testing.T) { - s := newTestRedisStore(t) - _, err := s.GetModelQuota("nonexistent") - require.Error(t, err) - apiErr, ok := err.(*APIError) - require.True(t, ok, "expected *APIError") - assert.Equal(t, 404, apiErr.Code) - assert.Contains(t, err.Error(), "model quota not found") -} - -func TestRedisStore_SetGetModelQuota_Good(t *testing.T) { - s := newTestRedisStore(t) - - q := &ModelQuota{ - Model: "claude-opus-4-6", - DailyTokenBudget: 500000, - HourlyRateLimit: 100, - CostCeiling: 10000, - } - err := s.SetModelQuota(q) - require.NoError(t, err) - - got, err := s.GetModelQuota("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, q.Model, got.Model) - assert.Equal(t, q.DailyTokenBudget, got.DailyTokenBudget) - assert.Equal(t, q.HourlyRateLimit, got.HourlyRateLimit) - assert.Equal(t, q.CostCeiling, got.CostCeiling) -} - -// --- ModelUsage --- - -func TestRedisStore_ModelUsage_Good(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.IncrementModelUsage("claude-sonnet", 10000) - _ = s.IncrementModelUsage("claude-sonnet", 5000) - - usage, err := s.GetModelUsage("claude-sonnet") - require.NoError(t, err) - assert.Equal(t, int64(15000), usage) -} - -func TestRedisStore_GetModelUsage_Good_Default(t *testing.T) { - s := newTestRedisStore(t) - - usage, err := s.GetModelUsage("unknown-model") - require.NoError(t, err) - assert.Equal(t, int64(0), usage) -} - -// --- Persistence: set, get, verify --- - -func TestRedisStore_Persistence_Good(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - MaxJobDuration: 15 * time.Minute, - }) - _ = s.IncrementUsage("agent-1", 25000, 3) - _ = s.SetModelQuota(&ModelQuota{Model: "claude-opus-4-6", DailyTokenBudget: 500000}) - _ = s.IncrementModelUsage("claude-opus-4-6", 42000) - - // Verify all data persists (same connection, but data is in Redis) - a, err := s.GetAllowance("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(100000), a.DailyTokenLimit) - assert.Equal(t, 15*time.Minute, a.MaxJobDuration) - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(25000), u.TokensUsed) - assert.Equal(t, 3, u.JobsStarted) - assert.Equal(t, 3, u.ActiveJobs) - - q, err := s.GetModelQuota("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, int64(500000), q.DailyTokenBudget) - - mu, err := s.GetModelUsage("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, int64(42000), mu) -} - -// --- Concurrent access --- - -func TestRedisStore_ConcurrentIncrementUsage_Good(t *testing.T) { - s := newTestRedisStore(t) - - const goroutines = 10 - const tokensEach = 1000 - - var wg sync.WaitGroup - wg.Add(goroutines) - for range goroutines { - go func() { - defer wg.Done() - err := s.IncrementUsage("agent-1", tokensEach, 1) - assert.NoError(t, err) - }() - } - wg.Wait() - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(goroutines*tokensEach), u.TokensUsed) - assert.Equal(t, goroutines, u.JobsStarted) - assert.Equal(t, goroutines, u.ActiveJobs) -} - -func TestRedisStore_ConcurrentModelUsage_Good(t *testing.T) { - s := newTestRedisStore(t) - - const goroutines = 10 - const tokensEach int64 = 500 - - var wg sync.WaitGroup - wg.Add(goroutines) - for range goroutines { - go func() { - defer wg.Done() - err := s.IncrementModelUsage("claude-opus-4-6", tokensEach) - assert.NoError(t, err) - }() - } - wg.Wait() - - usage, err := s.GetModelUsage("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, goroutines*tokensEach, usage) -} - -func TestRedisStore_ConcurrentMixed_Good(t *testing.T) { - s := newTestRedisStore(t) - - _ = s.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 1000000, - DailyJobLimit: 100, - ConcurrentJobs: 50, - }) - - const goroutines = 10 - var wg sync.WaitGroup - wg.Add(goroutines * 3) - - // Increment usage - for range goroutines { - go func() { - defer wg.Done() - _ = s.IncrementUsage("agent-1", 100, 1) - }() - } - - // Decrement active jobs - for range goroutines { - go func() { - defer wg.Done() - _ = s.DecrementActiveJobs("agent-1") - }() - } - - // Return tokens - for range goroutines { - go func() { - defer wg.Done() - _ = s.ReturnTokens("agent-1", 10) - }() - } - - wg.Wait() - - // Verify no panics and data is consistent - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.GreaterOrEqual(t, u.TokensUsed, int64(0)) - assert.GreaterOrEqual(t, u.ActiveJobs, 0) -} - -// --- AllowanceService integration via RedisStore --- - -func TestRedisStore_AllowanceServiceCheck_Good(t *testing.T) { - s := newTestRedisStore(t) - svc := NewAllowanceService(s) - - _ = s.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - DailyJobLimit: 10, - ConcurrentJobs: 2, - }) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceOK, result.Status) -} - -func TestRedisStore_AllowanceServiceRecordUsage_Good(t *testing.T) { - s := newTestRedisStore(t) - svc := NewAllowanceService(s) - - _ = s.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - }) - - // Start job - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - require.NoError(t, err) - - // Complete job - err = svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Model: "claude-sonnet", - TokensIn: 1000, - TokensOut: 500, - Event: QuotaEventJobCompleted, - }) - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(1500), u.TokensUsed) - assert.Equal(t, 0, u.ActiveJobs) -} - -// --- Config-based factory with redis backend --- - -func TestNewAllowanceStoreFromConfig_Good_Redis(t *testing.T) { - cfg := AllowanceConfig{ - StoreBackend: "redis", - RedisAddr: testRedisAddr, - } - s, err := NewAllowanceStoreFromConfig(cfg) - if err != nil { - t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) - } - rs, ok := s.(*RedisStore) - assert.True(t, ok, "expected RedisStore") - _ = rs.Close() -} - -// --- Constructor error case --- - -func TestNewRedisStore_Bad_Unreachable(t *testing.T) { - _, err := NewRedisStore("127.0.0.1:1") // almost certainly unreachable - require.Error(t, err) - apiErr, ok := err.(*APIError) - require.True(t, ok, "expected *APIError") - assert.Equal(t, 500, apiErr.Code) - assert.Contains(t, err.Error(), "failed to connect to Redis") -} diff --git a/pkg/lifecycle/allowance_service.go b/pkg/lifecycle/allowance_service.go deleted file mode 100644 index d6940a7..0000000 --- a/pkg/lifecycle/allowance_service.go +++ /dev/null @@ -1,204 +0,0 @@ -package lifecycle - -import ( - "context" - "slices" - "time" - - "forge.lthn.ai/core/go-log" -) - -// AllowanceService enforces agent quota limits. It provides pre-dispatch checks, -// runtime usage recording, and quota recovery for failed/cancelled jobs. -type AllowanceService struct { - store AllowanceStore - events EventEmitter -} - -// NewAllowanceService creates a new AllowanceService with the given store. -func NewAllowanceService(store AllowanceStore) *AllowanceService { - return &AllowanceService{store: store} -} - -// SetEventEmitter attaches an event emitter for quota lifecycle notifications. -func (s *AllowanceService) SetEventEmitter(em EventEmitter) { - s.events = em -} - -// emitEvent is a convenience helper that publishes an event if an emitter is set. -func (s *AllowanceService) emitEvent(eventType EventType, agentID string, payload any) { - if s.events != nil { - _ = s.events.Emit(context.Background(), Event{ - Type: eventType, - AgentID: agentID, - Timestamp: time.Now().UTC(), - Payload: payload, - }) - } -} - -// Check performs a pre-dispatch allowance check for the given agent and model. -// It verifies daily token limits, daily job limits, concurrent job limits, and -// model allowlists. Returns a QuotaCheckResult indicating whether the agent may proceed. -func (s *AllowanceService) Check(agentID, model string) (*QuotaCheckResult, error) { - const op = "AllowanceService.Check" - - allowance, err := s.store.GetAllowance(agentID) - if err != nil { - return nil, log.E(op, "failed to get allowance", err) - } - - usage, err := s.store.GetUsage(agentID) - if err != nil { - return nil, log.E(op, "failed to get usage", err) - } - - result := &QuotaCheckResult{ - Allowed: true, - Status: AllowanceOK, - RemainingTokens: -1, // unlimited - RemainingJobs: -1, // unlimited - } - - // Check model allowlist - if len(allowance.ModelAllowlist) > 0 && model != "" { - if !slices.Contains(allowance.ModelAllowlist, model) { - result.Allowed = false - result.Status = AllowanceExceeded - result.Reason = "model not in allowlist: " + model - s.emitEvent(EventQuotaExceeded, agentID, result.Reason) - return result, nil - } - } - - // Check daily token limit - if allowance.DailyTokenLimit > 0 { - remaining := allowance.DailyTokenLimit - usage.TokensUsed - result.RemainingTokens = remaining - if remaining <= 0 { - result.Allowed = false - result.Status = AllowanceExceeded - result.Reason = "daily token limit exceeded" - s.emitEvent(EventQuotaExceeded, agentID, result.Reason) - return result, nil - } - ratio := float64(usage.TokensUsed) / float64(allowance.DailyTokenLimit) - if ratio >= 0.8 { - result.Status = AllowanceWarning - s.emitEvent(EventQuotaWarning, agentID, ratio) - } - } - - // Check daily job limit - if allowance.DailyJobLimit > 0 { - remaining := allowance.DailyJobLimit - usage.JobsStarted - result.RemainingJobs = remaining - if remaining <= 0 { - result.Allowed = false - result.Status = AllowanceExceeded - result.Reason = "daily job limit exceeded" - s.emitEvent(EventQuotaExceeded, agentID, result.Reason) - return result, nil - } - } - - // Check concurrent jobs - if allowance.ConcurrentJobs > 0 && usage.ActiveJobs >= allowance.ConcurrentJobs { - result.Allowed = false - result.Status = AllowanceExceeded - result.Reason = "concurrent job limit reached" - s.emitEvent(EventQuotaExceeded, agentID, result.Reason) - return result, nil - } - - // Check global model quota - if model != "" { - modelQuota, err := s.store.GetModelQuota(model) - if err == nil && modelQuota.DailyTokenBudget > 0 { - modelUsage, err := s.store.GetModelUsage(model) - if err == nil && modelUsage >= modelQuota.DailyTokenBudget { - result.Allowed = false - result.Status = AllowanceExceeded - result.Reason = "global model token budget exceeded for: " + model - s.emitEvent(EventQuotaExceeded, agentID, result.Reason) - return result, nil - } - } - } - - return result, nil -} - -// RecordUsage processes a usage report, updating counters and handling quota recovery. -func (s *AllowanceService) RecordUsage(report UsageReport) error { - const op = "AllowanceService.RecordUsage" - - totalTokens := report.TokensIn + report.TokensOut - - switch report.Event { - case QuotaEventJobStarted: - if err := s.store.IncrementUsage(report.AgentID, 0, 1); err != nil { - return log.E(op, "failed to increment job count", err) - } - s.emitEvent(EventUsageRecorded, report.AgentID, report) - - case QuotaEventJobCompleted: - if err := s.store.IncrementUsage(report.AgentID, totalTokens, 0); err != nil { - return log.E(op, "failed to record token usage", err) - } - if err := s.store.DecrementActiveJobs(report.AgentID); err != nil { - return log.E(op, "failed to decrement active jobs", err) - } - // Record model-level usage - if report.Model != "" { - if err := s.store.IncrementModelUsage(report.Model, totalTokens); err != nil { - return log.E(op, "failed to record model usage", err) - } - } - s.emitEvent(EventUsageRecorded, report.AgentID, report) - - case QuotaEventJobFailed: - // Record partial usage, return 50% of tokens - if err := s.store.IncrementUsage(report.AgentID, totalTokens, 0); err != nil { - return log.E(op, "failed to record token usage", err) - } - if err := s.store.DecrementActiveJobs(report.AgentID); err != nil { - return log.E(op, "failed to decrement active jobs", err) - } - returnAmount := totalTokens / 2 - if returnAmount > 0 { - if err := s.store.ReturnTokens(report.AgentID, returnAmount); err != nil { - return log.E(op, "failed to return tokens", err) - } - } - // Still record model-level usage (net of return) - if report.Model != "" { - if err := s.store.IncrementModelUsage(report.Model, totalTokens-returnAmount); err != nil { - return log.E(op, "failed to record model usage", err) - } - } - - case QuotaEventJobCancelled: - // Return 100% of tokens - if err := s.store.DecrementActiveJobs(report.AgentID); err != nil { - return log.E(op, "failed to decrement active jobs", err) - } - if totalTokens > 0 { - if err := s.store.ReturnTokens(report.AgentID, totalTokens); err != nil { - return log.E(op, "failed to return tokens", err) - } - } - // No model-level usage for cancelled jobs - } - - return nil -} - -// ResetAgent clears daily usage counters for the given agent (midnight reset). -func (s *AllowanceService) ResetAgent(agentID string) error { - const op = "AllowanceService.ResetAgent" - if err := s.store.ResetUsage(agentID); err != nil { - return log.E(op, "failed to reset usage", err) - } - return nil -} diff --git a/pkg/lifecycle/allowance_sqlite.go b/pkg/lifecycle/allowance_sqlite.go deleted file mode 100644 index b90db5b..0000000 --- a/pkg/lifecycle/allowance_sqlite.go +++ /dev/null @@ -1,333 +0,0 @@ -package lifecycle - -import ( - "encoding/json" - "errors" - "iter" - "sync" - "time" - - "forge.lthn.ai/core/go-store" -) - -// SQLite group names for namespacing data in the KV store. -const ( - groupAllowances = "allowances" - groupUsage = "usage" - groupModelQuota = "model_quotas" - groupModelUsage = "model_usage" -) - -// SQLiteStore implements AllowanceStore using go-store (SQLite KV). -// It provides persistent storage that survives process restarts. -type SQLiteStore struct { - db *store.Store - mu sync.Mutex // serialises read-modify-write operations -} - -// Allowances returns an iterator over all agent allowances. -func (s *SQLiteStore) Allowances() iter.Seq[*AgentAllowance] { - return func(yield func(*AgentAllowance) bool) { - for kv, err := range s.db.All(groupAllowances) { - if err != nil { - continue - } - var a allowanceJSON - if err := json.Unmarshal([]byte(kv.Value), &a); err != nil { - continue - } - if !yield(a.toAgentAllowance()) { - return - } - } - } -} - -// Usages returns an iterator over all usage records. -func (s *SQLiteStore) Usages() iter.Seq[*UsageRecord] { - return func(yield func(*UsageRecord) bool) { - for kv, err := range s.db.All(groupUsage) { - if err != nil { - continue - } - var u UsageRecord - if err := json.Unmarshal([]byte(kv.Value), &u); err != nil { - continue - } - if !yield(&u) { - return - } - } - } -} - -// NewSQLiteStore creates a new SQLite-backed allowance store at the given path. -// Use ":memory:" for tests that do not need persistence. -func NewSQLiteStore(dbPath string) (*SQLiteStore, error) { - db, err := store.New(dbPath) - if err != nil { - return nil, &APIError{Code: 500, Message: "failed to open SQLite store: " + err.Error()} - } - return &SQLiteStore{db: db}, nil -} - -// Close releases the underlying SQLite database. -func (s *SQLiteStore) Close() error { - return s.db.Close() -} - -// GetAllowance returns the quota limits for an agent. -func (s *SQLiteStore) GetAllowance(agentID string) (*AgentAllowance, error) { - val, err := s.db.Get(groupAllowances, agentID) - if err != nil { - if errors.Is(err, store.ErrNotFound) { - return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID} - } - return nil, &APIError{Code: 500, Message: "failed to get allowance: " + err.Error()} - } - var a allowanceJSON - if err := json.Unmarshal([]byte(val), &a); err != nil { - return nil, &APIError{Code: 500, Message: "failed to unmarshal allowance: " + err.Error()} - } - return a.toAgentAllowance(), nil -} - -// SetAllowance persists quota limits for an agent. -func (s *SQLiteStore) SetAllowance(a *AgentAllowance) error { - aj := newAllowanceJSON(a) - data, err := json.Marshal(aj) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal allowance: " + err.Error()} - } - if err := s.db.Set(groupAllowances, a.AgentID, string(data)); err != nil { - return &APIError{Code: 500, Message: "failed to set allowance: " + err.Error()} - } - return nil -} - -// GetUsage returns the current usage record for an agent. -func (s *SQLiteStore) GetUsage(agentID string) (*UsageRecord, error) { - val, err := s.db.Get(groupUsage, agentID) - if err != nil { - if errors.Is(err, store.ErrNotFound) { - return &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - }, nil - } - return nil, &APIError{Code: 500, Message: "failed to get usage: " + err.Error()} - } - var u UsageRecord - if err := json.Unmarshal([]byte(val), &u); err != nil { - return nil, &APIError{Code: 500, Message: "failed to unmarshal usage: " + err.Error()} - } - return &u, nil -} - -// IncrementUsage atomically adds to an agent's usage counters. -func (s *SQLiteStore) IncrementUsage(agentID string, tokens int64, jobs int) error { - s.mu.Lock() - defer s.mu.Unlock() - - u, err := s.getUsageLocked(agentID) - if err != nil { - return err - } - u.TokensUsed += tokens - u.JobsStarted += jobs - if jobs > 0 { - u.ActiveJobs += jobs - } - return s.putUsageLocked(u) -} - -// DecrementActiveJobs reduces the active job count by 1. -func (s *SQLiteStore) DecrementActiveJobs(agentID string) error { - s.mu.Lock() - defer s.mu.Unlock() - - u, err := s.getUsageLocked(agentID) - if err != nil { - return err - } - if u.ActiveJobs > 0 { - u.ActiveJobs-- - } - return s.putUsageLocked(u) -} - -// ReturnTokens adds tokens back to the agent's remaining quota. -func (s *SQLiteStore) ReturnTokens(agentID string, tokens int64) error { - s.mu.Lock() - defer s.mu.Unlock() - - u, err := s.getUsageLocked(agentID) - if err != nil { - return err - } - u.TokensUsed -= tokens - if u.TokensUsed < 0 { - u.TokensUsed = 0 - } - return s.putUsageLocked(u) -} - -// ResetUsage clears usage counters for an agent. -func (s *SQLiteStore) ResetUsage(agentID string) error { - s.mu.Lock() - defer s.mu.Unlock() - - u := &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - } - return s.putUsageLocked(u) -} - -// GetModelQuota returns global limits for a model. -func (s *SQLiteStore) GetModelQuota(model string) (*ModelQuota, error) { - val, err := s.db.Get(groupModelQuota, model) - if err != nil { - if errors.Is(err, store.ErrNotFound) { - return nil, &APIError{Code: 404, Message: "model quota not found: " + model} - } - return nil, &APIError{Code: 500, Message: "failed to get model quota: " + err.Error()} - } - var q ModelQuota - if err := json.Unmarshal([]byte(val), &q); err != nil { - return nil, &APIError{Code: 500, Message: "failed to unmarshal model quota: " + err.Error()} - } - return &q, nil -} - -// GetModelUsage returns current token usage for a model. -func (s *SQLiteStore) GetModelUsage(model string) (int64, error) { - val, err := s.db.Get(groupModelUsage, model) - if err != nil { - if errors.Is(err, store.ErrNotFound) { - return 0, nil - } - return 0, &APIError{Code: 500, Message: "failed to get model usage: " + err.Error()} - } - var tokens int64 - if err := json.Unmarshal([]byte(val), &tokens); err != nil { - return 0, &APIError{Code: 500, Message: "failed to unmarshal model usage: " + err.Error()} - } - return tokens, nil -} - -// IncrementModelUsage atomically adds to a model's usage counter. -func (s *SQLiteStore) IncrementModelUsage(model string, tokens int64) error { - s.mu.Lock() - defer s.mu.Unlock() - - current, err := s.getModelUsageLocked(model) - if err != nil { - return err - } - current += tokens - data, err := json.Marshal(current) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal model usage: " + err.Error()} - } - if err := s.db.Set(groupModelUsage, model, string(data)); err != nil { - return &APIError{Code: 500, Message: "failed to set model usage: " + err.Error()} - } - return nil -} - -// SetModelQuota persists global limits for a model. -func (s *SQLiteStore) SetModelQuota(q *ModelQuota) error { - data, err := json.Marshal(q) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal model quota: " + err.Error()} - } - if err := s.db.Set(groupModelQuota, q.Model, string(data)); err != nil { - return &APIError{Code: 500, Message: "failed to set model quota: " + err.Error()} - } - return nil -} - -// --- internal helpers (must be called with mu held) --- - -// getUsageLocked reads a UsageRecord from the store. Caller must hold s.mu. -func (s *SQLiteStore) getUsageLocked(agentID string) (*UsageRecord, error) { - val, err := s.db.Get(groupUsage, agentID) - if err != nil { - if errors.Is(err, store.ErrNotFound) { - return &UsageRecord{ - AgentID: agentID, - PeriodStart: startOfDay(time.Now().UTC()), - }, nil - } - return nil, &APIError{Code: 500, Message: "failed to get usage: " + err.Error()} - } - var u UsageRecord - if err := json.Unmarshal([]byte(val), &u); err != nil { - return nil, &APIError{Code: 500, Message: "failed to unmarshal usage: " + err.Error()} - } - return &u, nil -} - -// putUsageLocked writes a UsageRecord to the store. Caller must hold s.mu. -func (s *SQLiteStore) putUsageLocked(u *UsageRecord) error { - data, err := json.Marshal(u) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal usage: " + err.Error()} - } - if err := s.db.Set(groupUsage, u.AgentID, string(data)); err != nil { - return &APIError{Code: 500, Message: "failed to set usage: " + err.Error()} - } - return nil -} - -// getModelUsageLocked reads model usage from the store. Caller must hold s.mu. -func (s *SQLiteStore) getModelUsageLocked(model string) (int64, error) { - val, err := s.db.Get(groupModelUsage, model) - if err != nil { - if errors.Is(err, store.ErrNotFound) { - return 0, nil - } - return 0, &APIError{Code: 500, Message: "failed to get model usage: " + err.Error()} - } - var tokens int64 - if err := json.Unmarshal([]byte(val), &tokens); err != nil { - return 0, &APIError{Code: 500, Message: "failed to unmarshal model usage: " + err.Error()} - } - return tokens, nil -} - -// --- JSON serialisation helper for AgentAllowance --- -// time.Duration does not have a stable JSON representation. We serialise it -// as an int64 (nanoseconds) to avoid locale-dependent string parsing. - -type allowanceJSON struct { - AgentID string `json:"agent_id"` - DailyTokenLimit int64 `json:"daily_token_limit"` - DailyJobLimit int `json:"daily_job_limit"` - ConcurrentJobs int `json:"concurrent_jobs"` - MaxJobDurationNs int64 `json:"max_job_duration_ns"` - ModelAllowlist []string `json:"model_allowlist,omitempty"` -} - -func newAllowanceJSON(a *AgentAllowance) *allowanceJSON { - return &allowanceJSON{ - AgentID: a.AgentID, - DailyTokenLimit: a.DailyTokenLimit, - DailyJobLimit: a.DailyJobLimit, - ConcurrentJobs: a.ConcurrentJobs, - MaxJobDurationNs: int64(a.MaxJobDuration), - ModelAllowlist: a.ModelAllowlist, - } -} - -func (aj *allowanceJSON) toAgentAllowance() *AgentAllowance { - return &AgentAllowance{ - AgentID: aj.AgentID, - DailyTokenLimit: aj.DailyTokenLimit, - DailyJobLimit: aj.DailyJobLimit, - ConcurrentJobs: aj.ConcurrentJobs, - MaxJobDuration: time.Duration(aj.MaxJobDurationNs), - ModelAllowlist: aj.ModelAllowlist, - } -} diff --git a/pkg/lifecycle/allowance_sqlite_test.go b/pkg/lifecycle/allowance_sqlite_test.go deleted file mode 100644 index 8873599..0000000 --- a/pkg/lifecycle/allowance_sqlite_test.go +++ /dev/null @@ -1,465 +0,0 @@ -package lifecycle - -import ( - "path/filepath" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// newTestSQLiteStore creates a SQLiteStore in a temp directory. -func newTestSQLiteStore(t *testing.T) *SQLiteStore { - t.Helper() - dbPath := filepath.Join(t.TempDir(), "test.db") - s, err := NewSQLiteStore(dbPath) - require.NoError(t, err) - t.Cleanup(func() { _ = s.Close() }) - return s -} - -// --- SetAllowance / GetAllowance --- - -func TestSQLiteStore_SetGetAllowance_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - a := &AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - DailyJobLimit: 10, - ConcurrentJobs: 2, - MaxJobDuration: 30 * time.Minute, - ModelAllowlist: []string{"claude-sonnet-4-5-20250929"}, - } - - err := s.SetAllowance(a) - require.NoError(t, err) - - got, err := s.GetAllowance("agent-1") - require.NoError(t, err) - assert.Equal(t, a.AgentID, got.AgentID) - assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit) - assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit) - assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs) - assert.Equal(t, a.MaxJobDuration, got.MaxJobDuration) - assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist) -} - -func TestSQLiteStore_GetAllowance_Bad_NotFound(t *testing.T) { - s := newTestSQLiteStore(t) - _, err := s.GetAllowance("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "allowance not found") -} - -func TestSQLiteStore_SetAllowance_Good_Overwrite(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100}) - _ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 200}) - - got, err := s.GetAllowance("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(200), got.DailyTokenLimit) -} - -// --- GetUsage / IncrementUsage --- - -func TestSQLiteStore_GetUsage_Good_Default(t *testing.T) { - s := newTestSQLiteStore(t) - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, "agent-1", u.AgentID) - assert.Equal(t, int64(0), u.TokensUsed) - assert.Equal(t, 0, u.JobsStarted) - assert.Equal(t, 0, u.ActiveJobs) -} - -func TestSQLiteStore_IncrementUsage_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - err := s.IncrementUsage("agent-1", 5000, 1) - require.NoError(t, err) - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(5000), u.TokensUsed) - assert.Equal(t, 1, u.JobsStarted) - assert.Equal(t, 1, u.ActiveJobs) -} - -func TestSQLiteStore_IncrementUsage_Good_Accumulates(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.IncrementUsage("agent-1", 1000, 1) - _ = s.IncrementUsage("agent-1", 2000, 1) - _ = s.IncrementUsage("agent-1", 3000, 0) - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(6000), u.TokensUsed) - assert.Equal(t, 2, u.JobsStarted) - assert.Equal(t, 2, u.ActiveJobs) -} - -// --- DecrementActiveJobs --- - -func TestSQLiteStore_DecrementActiveJobs_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.IncrementUsage("agent-1", 0, 2) - _ = s.DecrementActiveJobs("agent-1") - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, 1, u.ActiveJobs) -} - -func TestSQLiteStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.DecrementActiveJobs("agent-1") // no usage record yet - _ = s.IncrementUsage("agent-1", 0, 0) - _ = s.DecrementActiveJobs("agent-1") // should stay at 0 - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, 0, u.ActiveJobs) -} - -// --- ReturnTokens --- - -func TestSQLiteStore_ReturnTokens_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.IncrementUsage("agent-1", 10000, 0) - err := s.ReturnTokens("agent-1", 5000) - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(5000), u.TokensUsed) -} - -func TestSQLiteStore_ReturnTokens_Good_FloorAtZero(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.IncrementUsage("agent-1", 1000, 0) - _ = s.ReturnTokens("agent-1", 5000) // more than used - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(0), u.TokensUsed) -} - -func TestSQLiteStore_ReturnTokens_Good_NoRecord(t *testing.T) { - s := newTestSQLiteStore(t) - - // Return tokens for agent with no usage record -- should create one at 0 - err := s.ReturnTokens("agent-1", 500) - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(0), u.TokensUsed) -} - -// --- ResetUsage --- - -func TestSQLiteStore_ResetUsage_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.IncrementUsage("agent-1", 50000, 5) - - err := s.ResetUsage("agent-1") - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(0), u.TokensUsed) - assert.Equal(t, 0, u.JobsStarted) - assert.Equal(t, 0, u.ActiveJobs) -} - -// --- ModelQuota --- - -func TestSQLiteStore_GetModelQuota_Bad_NotFound(t *testing.T) { - s := newTestSQLiteStore(t) - _, err := s.GetModelQuota("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "model quota not found") -} - -func TestSQLiteStore_SetGetModelQuota_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - q := &ModelQuota{ - Model: "claude-opus-4-6", - DailyTokenBudget: 500000, - HourlyRateLimit: 100, - CostCeiling: 10000, - } - err := s.SetModelQuota(q) - require.NoError(t, err) - - got, err := s.GetModelQuota("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, q.Model, got.Model) - assert.Equal(t, q.DailyTokenBudget, got.DailyTokenBudget) - assert.Equal(t, q.HourlyRateLimit, got.HourlyRateLimit) - assert.Equal(t, q.CostCeiling, got.CostCeiling) -} - -// --- ModelUsage --- - -func TestSQLiteStore_ModelUsage_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.IncrementModelUsage("claude-sonnet", 10000) - _ = s.IncrementModelUsage("claude-sonnet", 5000) - - usage, err := s.GetModelUsage("claude-sonnet") - require.NoError(t, err) - assert.Equal(t, int64(15000), usage) -} - -func TestSQLiteStore_GetModelUsage_Good_Default(t *testing.T) { - s := newTestSQLiteStore(t) - - usage, err := s.GetModelUsage("unknown-model") - require.NoError(t, err) - assert.Equal(t, int64(0), usage) -} - -// --- Persistence: close and reopen --- - -func TestSQLiteStore_Persistence_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "persist.db") - - // Phase 1: write data - s1, err := NewSQLiteStore(dbPath) - require.NoError(t, err) - - _ = s1.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - MaxJobDuration: 15 * time.Minute, - }) - _ = s1.IncrementUsage("agent-1", 25000, 3) - _ = s1.SetModelQuota(&ModelQuota{Model: "claude-opus-4-6", DailyTokenBudget: 500000}) - _ = s1.IncrementModelUsage("claude-opus-4-6", 42000) - require.NoError(t, s1.Close()) - - // Phase 2: reopen and verify - s2, err := NewSQLiteStore(dbPath) - require.NoError(t, err) - defer func() { _ = s2.Close() }() - - a, err := s2.GetAllowance("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(100000), a.DailyTokenLimit) - assert.Equal(t, 15*time.Minute, a.MaxJobDuration) - - u, err := s2.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(25000), u.TokensUsed) - assert.Equal(t, 3, u.JobsStarted) - assert.Equal(t, 3, u.ActiveJobs) - - q, err := s2.GetModelQuota("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, int64(500000), q.DailyTokenBudget) - - mu, err := s2.GetModelUsage("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, int64(42000), mu) -} - -// --- Concurrent access --- - -func TestSQLiteStore_ConcurrentIncrementUsage_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - const goroutines = 10 - const tokensEach = 1000 - - var wg sync.WaitGroup - wg.Add(goroutines) - for range goroutines { - go func() { - defer wg.Done() - err := s.IncrementUsage("agent-1", tokensEach, 1) - assert.NoError(t, err) - }() - } - wg.Wait() - - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(goroutines*tokensEach), u.TokensUsed) - assert.Equal(t, goroutines, u.JobsStarted) - assert.Equal(t, goroutines, u.ActiveJobs) -} - -func TestSQLiteStore_ConcurrentModelUsage_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - const goroutines = 10 - const tokensEach int64 = 500 - - var wg sync.WaitGroup - wg.Add(goroutines) - for range goroutines { - go func() { - defer wg.Done() - err := s.IncrementModelUsage("claude-opus-4-6", tokensEach) - assert.NoError(t, err) - }() - } - wg.Wait() - - usage, err := s.GetModelUsage("claude-opus-4-6") - require.NoError(t, err) - assert.Equal(t, goroutines*tokensEach, usage) -} - -func TestSQLiteStore_ConcurrentMixed_Good(t *testing.T) { - s := newTestSQLiteStore(t) - - _ = s.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 1000000, - DailyJobLimit: 100, - ConcurrentJobs: 50, - }) - - const goroutines = 10 - var wg sync.WaitGroup - wg.Add(goroutines * 3) - - // Increment usage - for range goroutines { - go func() { - defer wg.Done() - _ = s.IncrementUsage("agent-1", 100, 1) - }() - } - - // Decrement active jobs - for range goroutines { - go func() { - defer wg.Done() - _ = s.DecrementActiveJobs("agent-1") - }() - } - - // Return tokens - for range goroutines { - go func() { - defer wg.Done() - _ = s.ReturnTokens("agent-1", 10) - }() - } - - wg.Wait() - - // Just verify no panics and data is consistent - u, err := s.GetUsage("agent-1") - require.NoError(t, err) - assert.GreaterOrEqual(t, u.TokensUsed, int64(0)) - assert.GreaterOrEqual(t, u.ActiveJobs, 0) -} - -// --- AllowanceService integration via SQLiteStore --- - -func TestSQLiteStore_AllowanceServiceCheck_Good(t *testing.T) { - s := newTestSQLiteStore(t) - svc := NewAllowanceService(s) - - _ = s.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - DailyJobLimit: 10, - ConcurrentJobs: 2, - }) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceOK, result.Status) -} - -func TestSQLiteStore_AllowanceServiceRecordUsage_Good(t *testing.T) { - s := newTestSQLiteStore(t) - svc := NewAllowanceService(s) - - _ = s.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - }) - - // Start job - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - require.NoError(t, err) - - // Complete job - err = svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Model: "claude-sonnet", - TokensIn: 1000, - TokensOut: 500, - Event: QuotaEventJobCompleted, - }) - require.NoError(t, err) - - u, _ := s.GetUsage("agent-1") - assert.Equal(t, int64(1500), u.TokensUsed) - assert.Equal(t, 0, u.ActiveJobs) -} - -// --- Config-based factory --- - -func TestNewAllowanceStoreFromConfig_Good_Memory(t *testing.T) { - cfg := AllowanceConfig{StoreBackend: "memory"} - s, err := NewAllowanceStoreFromConfig(cfg) - require.NoError(t, err) - _, ok := s.(*MemoryStore) - assert.True(t, ok, "expected MemoryStore") -} - -func TestNewAllowanceStoreFromConfig_Good_Default(t *testing.T) { - cfg := AllowanceConfig{} // empty defaults to memory - s, err := NewAllowanceStoreFromConfig(cfg) - require.NoError(t, err) - _, ok := s.(*MemoryStore) - assert.True(t, ok, "expected MemoryStore for empty config") -} - -func TestNewAllowanceStoreFromConfig_Good_SQLite(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "factory.db") - cfg := AllowanceConfig{ - StoreBackend: "sqlite", - StorePath: dbPath, - } - s, err := NewAllowanceStoreFromConfig(cfg) - require.NoError(t, err) - ss, ok := s.(*SQLiteStore) - assert.True(t, ok, "expected SQLiteStore") - _ = ss.Close() -} - -func TestNewAllowanceStoreFromConfig_Bad_UnknownBackend(t *testing.T) { - cfg := AllowanceConfig{StoreBackend: "cassandra"} - _, err := NewAllowanceStoreFromConfig(cfg) - require.Error(t, err) - assert.Contains(t, err.Error(), "unsupported store backend") -} - -// --- NewSQLiteStore error case --- - -func TestNewSQLiteStore_Bad_InvalidPath(t *testing.T) { - _, err := NewSQLiteStore("/nonexistent/deeply/nested/dir/test.db") - require.Error(t, err) -} diff --git a/pkg/lifecycle/allowance_test.go b/pkg/lifecycle/allowance_test.go deleted file mode 100644 index e7225bb..0000000 --- a/pkg/lifecycle/allowance_test.go +++ /dev/null @@ -1,407 +0,0 @@ -package lifecycle - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- MemoryStore tests --- - -func TestMemoryStore_SetGetAllowance_Good(t *testing.T) { - store := NewMemoryStore() - a := &AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - DailyJobLimit: 10, - ConcurrentJobs: 2, - MaxJobDuration: 30 * time.Minute, - ModelAllowlist: []string{"claude-sonnet-4-5-20250929"}, - } - - err := store.SetAllowance(a) - require.NoError(t, err) - - got, err := store.GetAllowance("agent-1") - require.NoError(t, err) - assert.Equal(t, a.AgentID, got.AgentID) - assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit) - assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit) - assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs) - assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist) -} - -func TestMemoryStore_GetAllowance_Bad_NotFound(t *testing.T) { - store := NewMemoryStore() - _, err := store.GetAllowance("nonexistent") - require.Error(t, err) -} - -func TestMemoryStore_IncrementUsage_Good(t *testing.T) { - store := NewMemoryStore() - - err := store.IncrementUsage("agent-1", 5000, 1) - require.NoError(t, err) - - usage, err := store.GetUsage("agent-1") - require.NoError(t, err) - assert.Equal(t, int64(5000), usage.TokensUsed) - assert.Equal(t, 1, usage.JobsStarted) - assert.Equal(t, 1, usage.ActiveJobs) -} - -func TestMemoryStore_DecrementActiveJobs_Good(t *testing.T) { - store := NewMemoryStore() - - _ = store.IncrementUsage("agent-1", 0, 2) - _ = store.DecrementActiveJobs("agent-1") - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, 1, usage.ActiveJobs) -} - -func TestMemoryStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) { - store := NewMemoryStore() - - _ = store.DecrementActiveJobs("agent-1") // no-op, no usage record - _ = store.IncrementUsage("agent-1", 0, 0) - _ = store.DecrementActiveJobs("agent-1") // should stay at 0 - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, 0, usage.ActiveJobs) -} - -func TestMemoryStore_ReturnTokens_Good(t *testing.T) { - store := NewMemoryStore() - - _ = store.IncrementUsage("agent-1", 10000, 0) - err := store.ReturnTokens("agent-1", 5000) - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, int64(5000), usage.TokensUsed) -} - -func TestMemoryStore_ReturnTokens_Good_FloorAtZero(t *testing.T) { - store := NewMemoryStore() - - _ = store.IncrementUsage("agent-1", 1000, 0) - _ = store.ReturnTokens("agent-1", 5000) // more than used - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, int64(0), usage.TokensUsed) -} - -func TestMemoryStore_ResetUsage_Good(t *testing.T) { - store := NewMemoryStore() - - _ = store.IncrementUsage("agent-1", 50000, 5) - err := store.ResetUsage("agent-1") - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, int64(0), usage.TokensUsed) - assert.Equal(t, 0, usage.JobsStarted) - assert.Equal(t, 0, usage.ActiveJobs) -} - -func TestMemoryStore_ModelUsage_Good(t *testing.T) { - store := NewMemoryStore() - - _ = store.IncrementModelUsage("claude-sonnet", 10000) - _ = store.IncrementModelUsage("claude-sonnet", 5000) - - usage, err := store.GetModelUsage("claude-sonnet") - require.NoError(t, err) - assert.Equal(t, int64(15000), usage) -} - -// --- AllowanceService.Check tests --- - -func TestAllowanceServiceCheck_Good(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - DailyJobLimit: 10, - ConcurrentJobs: 2, - }) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceOK, result.Status) - assert.Equal(t, int64(100000), result.RemainingTokens) - assert.Equal(t, 10, result.RemainingJobs) -} - -func TestAllowanceServiceCheck_Good_Warning(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - }) - _ = store.IncrementUsage("agent-1", 85000, 0) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceWarning, result.Status) - assert.Equal(t, int64(15000), result.RemainingTokens) -} - -func TestAllowanceServiceCheck_Bad_TokenLimitExceeded(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100000, - }) - _ = store.IncrementUsage("agent-1", 100001, 0) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - assert.Equal(t, AllowanceExceeded, result.Status) - assert.Contains(t, result.Reason, "daily token limit") -} - -func TestAllowanceServiceCheck_Bad_JobLimitExceeded(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyJobLimit: 5, - }) - _ = store.IncrementUsage("agent-1", 0, 5) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - assert.Contains(t, result.Reason, "daily job limit") -} - -func TestAllowanceServiceCheck_Bad_ConcurrentLimitReached(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - ConcurrentJobs: 1, - }) - _ = store.IncrementUsage("agent-1", 0, 1) // 1 active job - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - assert.Contains(t, result.Reason, "concurrent job limit") -} - -func TestAllowanceServiceCheck_Bad_ModelNotInAllowlist(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - ModelAllowlist: []string{"claude-sonnet-4-5-20250929"}, - }) - - result, err := svc.Check("agent-1", "claude-opus-4-6") - require.NoError(t, err) - assert.False(t, result.Allowed) - assert.Contains(t, result.Reason, "model not in allowlist") -} - -func TestAllowanceServiceCheck_Good_ModelInAllowlist(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - ModelAllowlist: []string{"claude-sonnet-4-5-20250929", "claude-haiku-4-5-20251001"}, - }) - - result, err := svc.Check("agent-1", "claude-sonnet-4-5-20250929") - require.NoError(t, err) - assert.True(t, result.Allowed) -} - -func TestAllowanceServiceCheck_Good_EmptyModelSkipsCheck(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - ModelAllowlist: []string{"claude-sonnet-4-5-20250929"}, - }) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.True(t, result.Allowed) -} - -func TestAllowanceServiceCheck_Bad_GlobalModelBudgetExceeded(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - }) - store.SetModelQuota(&ModelQuota{ - Model: "claude-opus-4-6", - DailyTokenBudget: 500000, - }) - _ = store.IncrementModelUsage("claude-opus-4-6", 500001) - - result, err := svc.Check("agent-1", "claude-opus-4-6") - require.NoError(t, err) - assert.False(t, result.Allowed) - assert.Contains(t, result.Reason, "global model token budget") -} - -func TestAllowanceServiceCheck_Bad_NoAllowance(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _, err := svc.Check("unknown-agent", "") - require.Error(t, err) -} - -// --- AllowanceService.RecordUsage tests --- - -func TestAllowanceServiceRecordUsage_Good_JobStarted(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, 1, usage.JobsStarted) - assert.Equal(t, 1, usage.ActiveJobs) - assert.Equal(t, int64(0), usage.TokensUsed) -} - -func TestAllowanceServiceRecordUsage_Good_JobCompleted(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - // Start a job first - _ = svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Model: "claude-sonnet", - TokensIn: 1000, - TokensOut: 500, - Event: QuotaEventJobCompleted, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, int64(1500), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) - - modelUsage, _ := store.GetModelUsage("claude-sonnet") - assert.Equal(t, int64(1500), modelUsage) -} - -func TestAllowanceServiceRecordUsage_Good_JobFailed_ReturnsHalf(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Event: QuotaEventJobStarted, - }) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Model: "claude-sonnet", - TokensIn: 1000, - TokensOut: 1000, - Event: QuotaEventJobFailed, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - // 2000 tokens used, 1000 returned (50%) = 1000 net - assert.Equal(t, int64(1000), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) - - // Model sees net usage (2000 - 1000 = 1000) - modelUsage, _ := store.GetModelUsage("claude-sonnet") - assert.Equal(t, int64(1000), modelUsage) -} - -func TestAllowanceServiceRecordUsage_Good_JobCancelled_ReturnsAll(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.IncrementUsage("agent-1", 5000, 1) // simulate pre-existing usage - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - TokensIn: 500, - TokensOut: 500, - Event: QuotaEventJobCancelled, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - // 5000 pre-existing - 1000 returned = 4000 - assert.Equal(t, int64(4000), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) -} - -// --- AllowanceService.ResetAgent tests --- - -func TestAllowanceServiceResetAgent_Good(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.IncrementUsage("agent-1", 50000, 5) - - err := svc.ResetAgent("agent-1") - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, int64(0), usage.TokensUsed) - assert.Equal(t, 0, usage.JobsStarted) -} - -// --- startOfDay helper test --- - -func TestStartOfDay_Good(t *testing.T) { - input := time.Date(2026, 2, 10, 15, 30, 45, 0, time.UTC) - expected := time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC) - assert.Equal(t, expected, startOfDay(input)) -} - -// --- AllowanceStatus tests --- - -func TestAllowanceStatus_Good_Values(t *testing.T) { - assert.Equal(t, AllowanceStatus("ok"), AllowanceOK) - assert.Equal(t, AllowanceStatus("warning"), AllowanceWarning) - assert.Equal(t, AllowanceStatus("exceeded"), AllowanceExceeded) -} diff --git a/pkg/lifecycle/brain.go b/pkg/lifecycle/brain.go deleted file mode 100644 index 4ed9de6..0000000 --- a/pkg/lifecycle/brain.go +++ /dev/null @@ -1,215 +0,0 @@ -package lifecycle - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - - "forge.lthn.ai/core/go-log" -) - -// MemoryType represents the classification of a brain memory. -type MemoryType string - -const ( - MemoryFact MemoryType = "fact" - MemoryDecision MemoryType = "decision" - MemoryPattern MemoryType = "pattern" - MemoryContext MemoryType = "context" - MemoryProcedure MemoryType = "procedure" -) - -// Memory represents a single memory entry from the OpenBrain API. -type Memory struct { - ID string `json:"id"` - AgentID string `json:"agent_id,omitempty"` - Type string `json:"type"` - Content string `json:"content"` - Tags []string `json:"tags,omitempty"` - Project string `json:"project,omitempty"` - Confidence float64 `json:"confidence,omitempty"` - SupersedesID string `json:"supersedes_id,omitempty"` - ExpiresAt string `json:"expires_at,omitempty"` - CreatedAt string `json:"created_at,omitempty"` - UpdatedAt string `json:"updated_at,omitempty"` -} - -// RememberRequest is the payload for storing a new memory. -type RememberRequest struct { - Content string `json:"content"` - Type string `json:"type"` - Project string `json:"project,omitempty"` - AgentID string `json:"agent_id,omitempty"` - Tags []string `json:"tags,omitempty"` - Confidence float64 `json:"confidence,omitempty"` - SupersedesID string `json:"supersedes_id,omitempty"` - Source string `json:"source,omitempty"` -} - -// RememberResponse is returned after storing a memory. -type RememberResponse struct { - ID string `json:"id"` - Type string `json:"type"` - Project string `json:"project"` - CreatedAt string `json:"created_at"` -} - -// RecallRequest is the payload for semantic search. -type RecallRequest struct { - Query string `json:"query"` - TopK int `json:"top_k,omitempty"` - Project string `json:"project,omitempty"` - Type string `json:"type,omitempty"` - AgentID string `json:"agent_id,omitempty"` - MinConfidence *float64 `json:"min_confidence,omitempty"` -} - -// RecallResponse is returned from a semantic search. -type RecallResponse struct { - Memories []Memory `json:"memories"` - Scores map[string]float64 `json:"scores"` -} - -// Remember stores a memory via POST /v1/brain/remember. -func (c *Client) Remember(ctx context.Context, req RememberRequest) (*RememberResponse, error) { - const op = "agentic.Client.Remember" - - if req.Content == "" { - return nil, log.E(op, "content is required", nil) - } - if req.Type == "" { - return nil, log.E(op, "type is required", nil) - } - - data, err := json.Marshal(req) - if err != nil { - return nil, log.E(op, "failed to marshal request", err) - } - - endpoint := c.BaseURL + "/v1/brain/remember" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(httpReq) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result RememberResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &result, nil -} - -// Recall performs semantic search via POST /v1/brain/recall. -func (c *Client) Recall(ctx context.Context, req RecallRequest) (*RecallResponse, error) { - const op = "agentic.Client.Recall" - - if req.Query == "" { - return nil, log.E(op, "query is required", nil) - } - - data, err := json.Marshal(req) - if err != nil { - return nil, log.E(op, "failed to marshal request", err) - } - - endpoint := c.BaseURL + "/v1/brain/recall" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(httpReq) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result RecallResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &result, nil -} - -// Forget removes a memory via DELETE /v1/brain/forget/{id}. -func (c *Client) Forget(ctx context.Context, id string) error { - const op = "agentic.Client.Forget" - - if id == "" { - return log.E(op, "memory ID is required", nil) - } - - endpoint := fmt.Sprintf("%s/v1/brain/forget/%s", c.BaseURL, url.PathEscape(id)) - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil) - if err != nil { - return log.E(op, "failed to create request", err) - } - - c.setHeaders(httpReq) - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return log.E(op, "API error", err) - } - - return nil -} - -// EnsureCollection ensures the Qdrant collection exists via POST /v1/brain/collections. -func (c *Client) EnsureCollection(ctx context.Context) error { - const op = "agentic.Client.EnsureCollection" - - endpoint := c.BaseURL + "/v1/brain/collections" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil) - if err != nil { - return log.E(op, "failed to create request", err) - } - - c.setHeaders(httpReq) - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return log.E(op, "API error", err) - } - - return nil -} diff --git a/pkg/lifecycle/brain_test.go b/pkg/lifecycle/brain_test.go deleted file mode 100644 index 94fde90..0000000 --- a/pkg/lifecycle/brain_test.go +++ /dev/null @@ -1,234 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestClient_Remember_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/v1/brain/remember", r.URL.Path) - assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - - var req RememberRequest - err := json.NewDecoder(r.Body).Decode(&req) - require.NoError(t, err) - assert.Equal(t, "Go uses structural typing", req.Content) - assert.Equal(t, "fact", req.Type) - assert.Equal(t, "go-agentic", req.Project) - assert.Equal(t, []string{"go", "typing"}, req.Tags) - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(RememberResponse{ - ID: "mem-abc-123", - Type: "fact", - Project: "go-agentic", - CreatedAt: "2026-03-03T12:00:00+00:00", - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.Remember(context.Background(), RememberRequest{ - Content: "Go uses structural typing", - Type: "fact", - Project: "go-agentic", - Tags: []string{"go", "typing"}, - }) - - require.NoError(t, err) - assert.Equal(t, "mem-abc-123", result.ID) - assert.Equal(t, "fact", result.Type) - assert.Equal(t, "go-agentic", result.Project) -} - -func TestClient_Remember_Bad_EmptyContent(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - result, err := client.Remember(context.Background(), RememberRequest{ - Type: "fact", - }) - - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "content is required") -} - -func TestClient_Remember_Bad_EmptyType(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - result, err := client.Remember(context.Background(), RememberRequest{ - Content: "something", - }) - - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "type is required") -} - -func TestClient_Remember_Bad_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnprocessableEntity) - _ = json.NewEncoder(w).Encode(APIError{Message: "validation failed"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.Remember(context.Background(), RememberRequest{ - Content: "test", - Type: "fact", - }) - - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "validation failed") -} - -func TestClient_Recall_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/v1/brain/recall", r.URL.Path) - - var req RecallRequest - err := json.NewDecoder(r.Body).Decode(&req) - require.NoError(t, err) - assert.Equal(t, "how does typing work in Go", req.Query) - assert.Equal(t, 5, req.TopK) - assert.Equal(t, "go-agentic", req.Project) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(RecallResponse{ - Memories: []Memory{ - { - ID: "mem-abc-123", - Type: "fact", - Content: "Go uses structural typing", - Project: "go-agentic", - Confidence: 0.95, - }, - }, - Scores: map[string]float64{ - "mem-abc-123": 0.87, - }, - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.Recall(context.Background(), RecallRequest{ - Query: "how does typing work in Go", - TopK: 5, - Project: "go-agentic", - }) - - require.NoError(t, err) - assert.Len(t, result.Memories, 1) - assert.Equal(t, "mem-abc-123", result.Memories[0].ID) - assert.Equal(t, "Go uses structural typing", result.Memories[0].Content) - assert.InDelta(t, 0.87, result.Scores["mem-abc-123"], 0.001) -} - -func TestClient_Recall_Good_EmptyResults(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(RecallResponse{ - Memories: []Memory{}, - Scores: map[string]float64{}, - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.Recall(context.Background(), RecallRequest{ - Query: "something obscure", - }) - - require.NoError(t, err) - assert.Empty(t, result.Memories) - assert.Empty(t, result.Scores) -} - -func TestClient_Recall_Bad_EmptyQuery(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - result, err := client.Recall(context.Background(), RecallRequest{}) - - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "query is required") -} - -func TestClient_Forget_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodDelete, r.Method) - assert.Equal(t, "/v1/brain/forget/mem-abc-123", r.URL.Path) - assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]bool{"deleted": true}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.Forget(context.Background(), "mem-abc-123") - - assert.NoError(t, err) -} - -func TestClient_Forget_Bad_EmptyID(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - err := client.Forget(context.Background(), "") - - assert.Error(t, err) - assert.Contains(t, err.Error(), "memory ID is required") -} - -func TestClient_Forget_Bad_NotFound(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - _ = json.NewEncoder(w).Encode(APIError{Message: "memory not found"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.Forget(context.Background(), "nonexistent") - - assert.Error(t, err) - assert.Contains(t, err.Error(), "memory not found") -} - -func TestClient_EnsureCollection_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/v1/brain/collections", r.URL.Path) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.EnsureCollection(context.Background()) - - assert.NoError(t, err) -} - -func TestClient_EnsureCollection_Bad_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Message: "collection setup failed"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.EnsureCollection(context.Background()) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "collection setup failed") -} diff --git a/pkg/lifecycle/client.go b/pkg/lifecycle/client.go deleted file mode 100644 index 9b7b5f3..0000000 --- a/pkg/lifecycle/client.go +++ /dev/null @@ -1,359 +0,0 @@ -package lifecycle - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/url" - "strconv" - "strings" - "time" - - "forge.lthn.ai/core/go-log" -) - -// Client is the API client for the core-agentic service. -type Client struct { - // BaseURL is the base URL of the API server. - BaseURL string - // Token is the authentication token. - Token string - // HTTPClient is the HTTP client used for requests. - HTTPClient *http.Client - // AgentID is the identifier for this agent when claiming tasks. - AgentID string -} - -// NewClient creates a new agentic API client with the given base URL and token. -func NewClient(baseURL, token string) *Client { - return &Client{ - BaseURL: strings.TrimSuffix(baseURL, "/"), - Token: token, - HTTPClient: &http.Client{ - Timeout: 30 * time.Second, - }, - } -} - -// NewClientFromConfig creates a new client from a Config struct. -func NewClientFromConfig(cfg *Config) *Client { - client := NewClient(cfg.BaseURL, cfg.Token) - client.AgentID = cfg.AgentID - return client -} - -// ListTasks retrieves a list of tasks matching the given options. -func (c *Client) ListTasks(ctx context.Context, opts ListOptions) ([]Task, error) { - const op = "agentic.Client.ListTasks" - - // Build query parameters - params := url.Values{} - if opts.Status != "" { - params.Set("status", string(opts.Status)) - } - if opts.Priority != "" { - params.Set("priority", string(opts.Priority)) - } - if opts.Project != "" { - params.Set("project", opts.Project) - } - if opts.ClaimedBy != "" { - params.Set("claimed_by", opts.ClaimedBy) - } - if opts.Limit > 0 { - params.Set("limit", strconv.Itoa(opts.Limit)) - } - if len(opts.Labels) > 0 { - params.Set("labels", strings.Join(opts.Labels, ",")) - } - - endpoint := c.BaseURL + "/api/tasks" - if len(params) > 0 { - endpoint += "?" + params.Encode() - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var tasks []Task - if err := json.NewDecoder(resp.Body).Decode(&tasks); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return tasks, nil -} - -// GetTask retrieves a single task by its ID. -func (c *Client) GetTask(ctx context.Context, id string) (*Task, error) { - const op = "agentic.Client.GetTask" - - if id == "" { - return nil, log.E(op, "task ID is required", nil) - } - - endpoint := fmt.Sprintf("%s/api/tasks/%s", c.BaseURL, url.PathEscape(id)) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var task Task - if err := json.NewDecoder(resp.Body).Decode(&task); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &task, nil -} - -// ClaimTask claims a task for the current agent. -func (c *Client) ClaimTask(ctx context.Context, id string) (*Task, error) { - const op = "agentic.Client.ClaimTask" - - if id == "" { - return nil, log.E(op, "task ID is required", nil) - } - - endpoint := fmt.Sprintf("%s/api/tasks/%s/claim", c.BaseURL, url.PathEscape(id)) - - // Include agent ID in the claim request if available - var body io.Reader - if c.AgentID != "" { - data, _ := json.Marshal(map[string]string{"agent_id": c.AgentID}) - body = bytes.NewReader(data) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, body) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(req) - if body != nil { - req.Header.Set("Content-Type", "application/json") - } - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - // Read body once to allow multiple decode attempts - bodyData, err := io.ReadAll(resp.Body) - if err != nil { - return nil, log.E(op, "failed to read response", err) - } - - // Try decoding as ClaimResponse first - var result ClaimResponse - if err := json.Unmarshal(bodyData, &result); err == nil && result.Task != nil { - return result.Task, nil - } - - // Try decoding as just a Task for simpler API responses - var task Task - if err := json.Unmarshal(bodyData, &task); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &task, nil -} - -// UpdateTask updates a task with new status, progress, or notes. -func (c *Client) UpdateTask(ctx context.Context, id string, update TaskUpdate) error { - const op = "agentic.Client.UpdateTask" - - if id == "" { - return log.E(op, "task ID is required", nil) - } - - endpoint := fmt.Sprintf("%s/api/tasks/%s", c.BaseURL, url.PathEscape(id)) - - data, err := json.Marshal(update) - if err != nil { - return log.E(op, "failed to marshal update", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(data)) - if err != nil { - return log.E(op, "failed to create request", err) - } - - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return log.E(op, "API error", err) - } - - return nil -} - -// CompleteTask marks a task as completed with the given result. -func (c *Client) CompleteTask(ctx context.Context, id string, result TaskResult) error { - const op = "agentic.Client.CompleteTask" - - if id == "" { - return log.E(op, "task ID is required", nil) - } - - endpoint := fmt.Sprintf("%s/api/tasks/%s/complete", c.BaseURL, url.PathEscape(id)) - - data, err := json.Marshal(result) - if err != nil { - return log.E(op, "failed to marshal result", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return log.E(op, "failed to create request", err) - } - - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return log.E(op, "API error", err) - } - - return nil -} - -// setHeaders adds common headers to the request. -func (c *Client) setHeaders(req *http.Request) { - req.Header.Set("Authorization", "Bearer "+c.Token) - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "core-agentic-client/1.0") -} - -// checkResponse checks if the response indicates an error. -func (c *Client) checkResponse(resp *http.Response) error { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - return nil - } - - body, _ := io.ReadAll(resp.Body) - - // Try to parse as APIError - var apiErr APIError - if err := json.Unmarshal(body, &apiErr); err == nil && apiErr.Message != "" { - apiErr.Code = resp.StatusCode - return &apiErr - } - - // Return generic error - return &APIError{ - Code: resp.StatusCode, - Message: fmt.Sprintf("HTTP %d: %s", resp.StatusCode, http.StatusText(resp.StatusCode)), - Details: string(body), - } -} - -// CreateTask creates a new task via POST /api/tasks. -func (c *Client) CreateTask(ctx context.Context, task Task) (*Task, error) { - const op = "agentic.Client.CreateTask" - - data, err := json.Marshal(task) - if err != nil { - return nil, log.E(op, "failed to marshal task", err) - } - - endpoint := c.BaseURL + "/api/tasks" - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var created Task - if err := json.NewDecoder(resp.Body).Decode(&created); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &created, nil -} - -// Ping tests the connection to the API server. -func (c *Client) Ping(ctx context.Context) error { - const op = "agentic.Client.Ping" - - endpoint := c.BaseURL + "/v1/health" - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return log.E(op, "failed to create request", err) - } - - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= 400 { - return log.E(op, fmt.Sprintf("server returned status %d", resp.StatusCode), nil) - } - - return nil -} diff --git a/pkg/lifecycle/client_test.go b/pkg/lifecycle/client_test.go deleted file mode 100644 index a9146fc..0000000 --- a/pkg/lifecycle/client_test.go +++ /dev/null @@ -1,356 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// Test fixtures -var testTask = Task{ - ID: "task-123", - Title: "Implement feature X", - Description: "Add the new feature X to the system", - Priority: PriorityHigh, - Status: StatusPending, - Labels: []string{"feature", "backend"}, - Files: []string{"pkg/feature/feature.go"}, - CreatedAt: time.Now().Add(-24 * time.Hour), - Project: "core", -} - -var testTasks = []Task{ - testTask, - { - ID: "task-456", - Title: "Fix bug Y", - Description: "Fix the bug in component Y", - Priority: PriorityCritical, - Status: StatusPending, - Labels: []string{"bug", "urgent"}, - CreatedAt: time.Now().Add(-2 * time.Hour), - Project: "core", - }, -} - -func TestNewClient_Good(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - - assert.Equal(t, "https://api.example.com", client.BaseURL) - assert.Equal(t, "test-token", client.Token) - assert.NotNil(t, client.HTTPClient) -} - -func TestNewClient_Good_TrailingSlash(t *testing.T) { - client := NewClient("https://api.example.com/", "test-token") - - assert.Equal(t, "https://api.example.com", client.BaseURL) -} - -func TestNewClientFromConfig_Good(t *testing.T) { - cfg := &Config{ - BaseURL: "https://api.example.com", - Token: "config-token", - AgentID: "agent-001", - } - - client := NewClientFromConfig(cfg) - - assert.Equal(t, "https://api.example.com", client.BaseURL) - assert.Equal(t, "config-token", client.Token) - assert.Equal(t, "agent-001", client.AgentID) -} - -func TestClient_ListTasks_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodGet, r.Method) - assert.Equal(t, "/api/tasks", r.URL.Path) - assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(testTasks) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - tasks, err := client.ListTasks(context.Background(), ListOptions{}) - - require.NoError(t, err) - assert.Len(t, tasks, 2) - assert.Equal(t, "task-123", tasks[0].ID) - assert.Equal(t, "task-456", tasks[1].ID) -} - -func TestClient_ListTasks_Good_WithFilters(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - assert.Equal(t, "pending", query.Get("status")) - assert.Equal(t, "high", query.Get("priority")) - assert.Equal(t, "core", query.Get("project")) - assert.Equal(t, "10", query.Get("limit")) - assert.Equal(t, "bug,urgent", query.Get("labels")) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode([]Task{testTask}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - opts := ListOptions{ - Status: StatusPending, - Priority: PriorityHigh, - Project: "core", - Limit: 10, - Labels: []string{"bug", "urgent"}, - } - - tasks, err := client.ListTasks(context.Background(), opts) - - require.NoError(t, err) - assert.Len(t, tasks, 1) -} - -func TestClient_ListTasks_Bad_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Message: "internal error"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - tasks, err := client.ListTasks(context.Background(), ListOptions{}) - - assert.Error(t, err) - assert.Nil(t, tasks) - assert.Contains(t, err.Error(), "internal error") -} - -func TestClient_GetTask_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodGet, r.Method) - assert.Equal(t, "/api/tasks/task-123", r.URL.Path) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(testTask) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task, err := client.GetTask(context.Background(), "task-123") - - require.NoError(t, err) - assert.Equal(t, "task-123", task.ID) - assert.Equal(t, "Implement feature X", task.Title) - assert.Equal(t, PriorityHigh, task.Priority) -} - -func TestClient_GetTask_Bad_EmptyID(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - task, err := client.GetTask(context.Background(), "") - - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "task ID is required") -} - -func TestClient_GetTask_Bad_NotFound(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - _ = json.NewEncoder(w).Encode(APIError{Message: "task not found"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task, err := client.GetTask(context.Background(), "nonexistent") - - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "task not found") -} - -func TestClient_ClaimTask_Good(t *testing.T) { - claimedTask := testTask - claimedTask.Status = StatusInProgress - claimedTask.ClaimedBy = "agent-001" - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/api/tasks/task-123/claim", r.URL.Path) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimedTask}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - client.AgentID = "agent-001" - task, err := client.ClaimTask(context.Background(), "task-123") - - require.NoError(t, err) - assert.Equal(t, StatusInProgress, task.Status) - assert.Equal(t, "agent-001", task.ClaimedBy) -} - -func TestClient_ClaimTask_Good_SimpleResponse(t *testing.T) { - // Some APIs might return just the task without wrapping - claimedTask := testTask - claimedTask.Status = StatusInProgress - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(claimedTask) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task, err := client.ClaimTask(context.Background(), "task-123") - - require.NoError(t, err) - assert.Equal(t, "task-123", task.ID) -} - -func TestClient_ClaimTask_Bad_EmptyID(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - task, err := client.ClaimTask(context.Background(), "") - - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "task ID is required") -} - -func TestClient_ClaimTask_Bad_AlreadyClaimed(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusConflict) - _ = json.NewEncoder(w).Encode(APIError{Message: "task already claimed"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task, err := client.ClaimTask(context.Background(), "task-123") - - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "task already claimed") -} - -func TestClient_UpdateTask_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPatch, r.Method) - assert.Equal(t, "/api/tasks/task-123", r.URL.Path) - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - - var update TaskUpdate - err := json.NewDecoder(r.Body).Decode(&update) - require.NoError(t, err) - assert.Equal(t, StatusInProgress, update.Status) - assert.Equal(t, 50, update.Progress) - - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.UpdateTask(context.Background(), "task-123", TaskUpdate{ - Status: StatusInProgress, - Progress: 50, - Notes: "Making progress", - }) - - assert.NoError(t, err) -} - -func TestClient_UpdateTask_Bad_EmptyID(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - err := client.UpdateTask(context.Background(), "", TaskUpdate{}) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "task ID is required") -} - -func TestClient_CompleteTask_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/api/tasks/task-123/complete", r.URL.Path) - - var result TaskResult - err := json.NewDecoder(r.Body).Decode(&result) - require.NoError(t, err) - assert.True(t, result.Success) - assert.Equal(t, "Feature implemented", result.Output) - - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.CompleteTask(context.Background(), "task-123", TaskResult{ - Success: true, - Output: "Feature implemented", - Artifacts: []string{"pkg/feature/feature.go"}, - }) - - assert.NoError(t, err) -} - -func TestClient_CompleteTask_Bad_EmptyID(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - err := client.CompleteTask(context.Background(), "", TaskResult{}) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "task ID is required") -} - -func TestClient_Ping_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/v1/health", r.URL.Path) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.Ping(context.Background()) - - assert.NoError(t, err) -} - -func TestClient_Ping_Bad_ServerDown(t *testing.T) { - client := NewClient("http://localhost:99999", "test-token") - client.HTTPClient.Timeout = 100 * time.Millisecond - - err := client.Ping(context.Background()) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "request failed") -} - -func TestAPIError_Error_Good(t *testing.T) { - err := &APIError{ - Code: 404, - Message: "task not found", - } - - assert.Equal(t, "task not found", err.Error()) - - err.Details = "task-123 does not exist" - assert.Equal(t, "task not found: task-123 does not exist", err.Error()) -} - -func TestTaskStatus_Good(t *testing.T) { - assert.Equal(t, TaskStatus("pending"), StatusPending) - assert.Equal(t, TaskStatus("in_progress"), StatusInProgress) - assert.Equal(t, TaskStatus("completed"), StatusCompleted) - assert.Equal(t, TaskStatus("blocked"), StatusBlocked) -} - -func TestTaskPriority_Good(t *testing.T) { - assert.Equal(t, TaskPriority("critical"), PriorityCritical) - assert.Equal(t, TaskPriority("high"), PriorityHigh) - assert.Equal(t, TaskPriority("medium"), PriorityMedium) - assert.Equal(t, TaskPriority("low"), PriorityLow) -} diff --git a/pkg/lifecycle/completion.go b/pkg/lifecycle/completion.go deleted file mode 100644 index 8dd2da7..0000000 --- a/pkg/lifecycle/completion.go +++ /dev/null @@ -1,338 +0,0 @@ -// Package agentic provides AI collaboration features for task management. -package lifecycle - -import ( - "bytes" - "context" - "fmt" - "os/exec" - "strings" - - "forge.lthn.ai/core/go-log" -) - -// PROptions contains options for creating a pull request. -type PROptions struct { - // Title is the PR title. - Title string `json:"title"` - // Body is the PR description. - Body string `json:"body"` - // Draft marks the PR as a draft. - Draft bool `json:"draft"` - // Labels are labels to add to the PR. - Labels []string `json:"labels"` - // Base is the base branch (defaults to main). - Base string `json:"base"` -} - -// AutoCommit creates a git commit with a task reference. -// The commit message follows the format: -// -// feat(scope): description -// -// Task: #123 -// Co-Authored-By: Claude -func AutoCommit(ctx context.Context, task *Task, dir string, message string) error { - const op = "agentic.AutoCommit" - - if task == nil { - return log.E(op, "task is required", nil) - } - - if message == "" { - return log.E(op, "commit message is required", nil) - } - - // Build full commit message - fullMessage := buildCommitMessage(task, message) - - // Stage all changes - if _, err := runGitCommandCtx(ctx, dir, "add", "-A"); err != nil { - return log.E(op, "failed to stage changes", err) - } - - // Create commit - if _, err := runGitCommandCtx(ctx, dir, "commit", "-m", fullMessage); err != nil { - return log.E(op, "failed to create commit", err) - } - - return nil -} - -// buildCommitMessage formats a commit message with task reference. -func buildCommitMessage(task *Task, message string) string { - var sb strings.Builder - - // Write the main message - sb.WriteString(message) - sb.WriteString("\n\n") - - // Add task reference - sb.WriteString("Task: #") - sb.WriteString(task.ID) - sb.WriteString("\n") - - // Add co-author - sb.WriteString("Co-Authored-By: Claude \n") - - return sb.String() -} - -// CreatePR creates a pull request using the gh CLI. -func CreatePR(ctx context.Context, task *Task, dir string, opts PROptions) (string, error) { - const op = "agentic.CreatePR" - - if task == nil { - return "", log.E(op, "task is required", nil) - } - - // Build title if not provided - title := opts.Title - if title == "" { - title = task.Title - } - - // Build body if not provided - body := opts.Body - if body == "" { - body = buildPRBody(task) - } - - // Build gh command arguments - args := []string{"pr", "create", "--title", title, "--body", body} - - if opts.Draft { - args = append(args, "--draft") - } - - if opts.Base != "" { - args = append(args, "--base", opts.Base) - } - - for _, label := range opts.Labels { - args = append(args, "--label", label) - } - - // Run gh pr create - output, err := runCommandCtx(ctx, dir, "gh", args...) - if err != nil { - return "", log.E(op, "failed to create PR", err) - } - - // Extract PR URL from output - prURL := strings.TrimSpace(output) - - return prURL, nil -} - -// buildPRBody creates a PR body from task details. -func buildPRBody(task *Task) string { - var sb strings.Builder - - sb.WriteString("## Summary\n\n") - sb.WriteString(task.Description) - sb.WriteString("\n\n") - - sb.WriteString("## Task Reference\n\n") - sb.WriteString("- Task ID: #") - sb.WriteString(task.ID) - sb.WriteString("\n") - sb.WriteString("- Priority: ") - sb.WriteString(string(task.Priority)) - sb.WriteString("\n") - - if len(task.Labels) > 0 { - sb.WriteString("- Labels: ") - sb.WriteString(strings.Join(task.Labels, ", ")) - sb.WriteString("\n") - } - - sb.WriteString("\n---\n") - sb.WriteString("Generated with AI assistance\n") - - return sb.String() -} - -// SyncStatus syncs the task status back to the agentic service. -func SyncStatus(ctx context.Context, client *Client, task *Task, update TaskUpdate) error { - const op = "agentic.SyncStatus" - - if client == nil { - return log.E(op, "client is required", nil) - } - - if task == nil { - return log.E(op, "task is required", nil) - } - - return client.UpdateTask(ctx, task.ID, update) -} - -// CommitAndSync commits changes and syncs task status. -func CommitAndSync(ctx context.Context, client *Client, task *Task, dir string, message string, progress int) error { - const op = "agentic.CommitAndSync" - - // Create commit - if err := AutoCommit(ctx, task, dir, message); err != nil { - return log.E(op, "failed to commit", err) - } - - // Sync status if client provided - if client != nil { - update := TaskUpdate{ - Status: StatusInProgress, - Progress: progress, - Notes: "Committed: " + message, - } - - if err := SyncStatus(ctx, client, task, update); err != nil { - // Log but don't fail on sync errors - return log.E(op, "commit succeeded but sync failed", err) - } - } - - return nil -} - -// PushChanges pushes committed changes to the remote. -func PushChanges(ctx context.Context, dir string) error { - const op = "agentic.PushChanges" - - _, err := runGitCommandCtx(ctx, dir, "push") - if err != nil { - return log.E(op, "failed to push changes", err) - } - - return nil -} - -// CreateBranch creates a new branch for the task. -func CreateBranch(ctx context.Context, task *Task, dir string) (string, error) { - const op = "agentic.CreateBranch" - - if task == nil { - return "", log.E(op, "task is required", nil) - } - - // Generate branch name from task - branchName := generateBranchName(task) - - // Create and checkout branch - _, err := runGitCommandCtx(ctx, dir, "checkout", "-b", branchName) - if err != nil { - return "", log.E(op, "failed to create branch", err) - } - - return branchName, nil -} - -// generateBranchName creates a branch name from task details. -func generateBranchName(task *Task) string { - // Determine prefix based on labels - prefix := "feat" - for _, label := range task.Labels { - switch strings.ToLower(label) { - case "bug", "bugfix", "fix": - prefix = "fix" - case "docs", "documentation": - prefix = "docs" - case "refactor": - prefix = "refactor" - case "test", "tests": - prefix = "test" - case "chore": - prefix = "chore" - } - } - - // Sanitize title for branch name - title := strings.ToLower(task.Title) - title = strings.Map(func(r rune) rune { - if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { - return r - } - if r == ' ' || r == '-' || r == '_' { - return '-' - } - return -1 - }, title) - - // Remove consecutive dashes - for strings.Contains(title, "--") { - title = strings.ReplaceAll(title, "--", "-") - } - title = strings.Trim(title, "-") - - // Truncate if too long - if len(title) > 40 { - title = title[:40] - title = strings.TrimRight(title, "-") - } - - return fmt.Sprintf("%s/%s-%s", prefix, task.ID, title) -} - -// runGitCommandCtx runs a git command with context. -func runGitCommandCtx(ctx context.Context, dir string, args ...string) (string, error) { - return runCommandCtx(ctx, dir, "git", args...) -} - -// runCommandCtx runs an arbitrary command with context. -func runCommandCtx(ctx context.Context, dir string, command string, args ...string) (string, error) { - cmd := exec.CommandContext(ctx, command, args...) - cmd.Dir = dir - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - if stderr.Len() > 0 { - return "", log.E("runCommandCtx", stderr.String(), err) - } - return "", err - } - - return stdout.String(), nil -} - -// GetCurrentBranch returns the current git branch name. -func GetCurrentBranch(ctx context.Context, dir string) (string, error) { - const op = "agentic.GetCurrentBranch" - - output, err := runGitCommandCtx(ctx, dir, "rev-parse", "--abbrev-ref", "HEAD") - if err != nil { - return "", log.E(op, "failed to get current branch", err) - } - - return strings.TrimSpace(output), nil -} - -// HasUncommittedChanges checks if there are uncommitted changes. -func HasUncommittedChanges(ctx context.Context, dir string) (bool, error) { - const op = "agentic.HasUncommittedChanges" - - output, err := runGitCommandCtx(ctx, dir, "status", "--porcelain") - if err != nil { - return false, log.E(op, "failed to get git status", err) - } - - return strings.TrimSpace(output) != "", nil -} - -// GetDiff returns the current diff for staged and unstaged changes. -func GetDiff(ctx context.Context, dir string, staged bool) (string, error) { - const op = "agentic.GetDiff" - - args := []string{"diff"} - if staged { - args = append(args, "--staged") - } - - output, err := runGitCommandCtx(ctx, dir, args...) - if err != nil { - return "", log.E(op, "failed to get diff", err) - } - - return output, nil -} diff --git a/pkg/lifecycle/completion_git_test.go b/pkg/lifecycle/completion_git_test.go deleted file mode 100644 index fd6dc77..0000000 --- a/pkg/lifecycle/completion_git_test.go +++ /dev/null @@ -1,474 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// initGitRepo creates a temporary git repo with an initial commit. -func initGitRepo(t *testing.T) string { - t.Helper() - dir := t.TempDir() - - // Initialise a git repo. - _, err := runCommandCtx(context.Background(), dir, "git", "init") - require.NoError(t, err, "git init should succeed") - - // Configure git identity for commits. - _, err = runCommandCtx(context.Background(), dir, "git", "config", "user.email", "test@example.com") - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "config", "user.name", "Test User") - require.NoError(t, err) - - // Create initial commit so HEAD exists. - readmePath := filepath.Join(dir, "README.md") - err = os.WriteFile(readmePath, []byte("# Test Repo\n"), 0644) - require.NoError(t, err) - - _, err = runCommandCtx(context.Background(), dir, "git", "add", "-A") - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "initial commit") - require.NoError(t, err) - - return dir -} - -// --- runCommandCtx / runGitCommandCtx tests --- - -func TestRunCommandCtx_Good(t *testing.T) { - output, err := runCommandCtx(context.Background(), "/tmp", "echo", "hello world") - require.NoError(t, err) - assert.Contains(t, output, "hello world") -} - -func TestRunCommandCtx_Bad_NonexistentCommand(t *testing.T) { - _, err := runCommandCtx(context.Background(), "/tmp", "nonexistent-command-xyz") - assert.Error(t, err) -} - -func TestRunCommandCtx_Bad_CommandFails(t *testing.T) { - _, err := runCommandCtx(context.Background(), "/tmp", "false") - assert.Error(t, err) -} - -func TestRunCommandCtx_Bad_StderrIncluded(t *testing.T) { - // git status in a non-git directory should produce stderr. - dir := t.TempDir() - _, err := runCommandCtx(context.Background(), dir, "git", "status") - assert.Error(t, err) -} - -func TestRunGitCommandCtx_Good(t *testing.T) { - dir := initGitRepo(t) - - output, err := runGitCommandCtx(context.Background(), dir, "log", "--oneline", "-1") - require.NoError(t, err) - assert.Contains(t, output, "initial commit") -} - -// --- GetCurrentBranch tests --- - -func TestGetCurrentBranch_Good(t *testing.T) { - dir := initGitRepo(t) - - branch, err := GetCurrentBranch(context.Background(), dir) - require.NoError(t, err) - // Depending on git config, default branch could be master or main. - assert.True(t, branch == "main" || branch == "master", - "expected main or master, got %q", branch) -} - -func TestGetCurrentBranch_Bad_NotAGitRepo(t *testing.T) { - dir := t.TempDir() - - _, err := GetCurrentBranch(context.Background(), dir) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to get current branch") -} - -// --- HasUncommittedChanges tests --- - -func TestHasUncommittedChanges_Good_Clean(t *testing.T) { - dir := initGitRepo(t) - - hasChanges, err := HasUncommittedChanges(context.Background(), dir) - require.NoError(t, err) - assert.False(t, hasChanges, "fresh repo with initial commit should be clean") -} - -func TestHasUncommittedChanges_Good_WithChanges(t *testing.T) { - dir := initGitRepo(t) - - // Create a new file. - err := os.WriteFile(filepath.Join(dir, "new-file.txt"), []byte("content"), 0644) - require.NoError(t, err) - - hasChanges, err := HasUncommittedChanges(context.Background(), dir) - require.NoError(t, err) - assert.True(t, hasChanges, "should detect untracked file") -} - -func TestHasUncommittedChanges_Good_WithModifiedFile(t *testing.T) { - dir := initGitRepo(t) - - // Modify the existing README. - err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# Updated\n"), 0644) - require.NoError(t, err) - - hasChanges, err := HasUncommittedChanges(context.Background(), dir) - require.NoError(t, err) - assert.True(t, hasChanges, "should detect modified file") -} - -func TestHasUncommittedChanges_Bad_NotAGitRepo(t *testing.T) { - dir := t.TempDir() - - _, err := HasUncommittedChanges(context.Background(), dir) - assert.Error(t, err) -} - -// --- GetDiff tests --- - -func TestGetDiff_Good_Unstaged(t *testing.T) { - dir := initGitRepo(t) - - // Modify a tracked file. - err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# Modified\n"), 0644) - require.NoError(t, err) - - diff, err := GetDiff(context.Background(), dir, false) - require.NoError(t, err) - assert.Contains(t, diff, "Modified", "diff should show the change") -} - -func TestGetDiff_Good_Staged(t *testing.T) { - dir := initGitRepo(t) - - // Modify and stage a file. - err := os.WriteFile(filepath.Join(dir, "README.md"), []byte("# Staged change\n"), 0644) - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "add", "README.md") - require.NoError(t, err) - - diff, err := GetDiff(context.Background(), dir, true) - require.NoError(t, err) - assert.Contains(t, diff, "Staged change", "staged diff should show the change") -} - -func TestGetDiff_Good_NoDiff(t *testing.T) { - dir := initGitRepo(t) - - diff, err := GetDiff(context.Background(), dir, false) - require.NoError(t, err) - assert.Empty(t, diff, "clean repo should have no diff") -} - -func TestGetDiff_Bad_NotAGitRepo(t *testing.T) { - dir := t.TempDir() - - _, err := GetDiff(context.Background(), dir, false) - assert.Error(t, err) -} - -// --- AutoCommit tests (with real git) --- - -func TestAutoCommit_Good(t *testing.T) { - dir := initGitRepo(t) - - // Create a file to commit. - err := os.WriteFile(filepath.Join(dir, "feature.go"), []byte("package main\n"), 0644) - require.NoError(t, err) - - task := &Task{ID: "T-100", Title: "Add feature"} - err = AutoCommit(context.Background(), task, dir, "feat: add feature module") - require.NoError(t, err) - - // Verify commit was created. - output, err := runGitCommandCtx(context.Background(), dir, "log", "--oneline", "-1") - require.NoError(t, err) - assert.Contains(t, output, "feat: add feature module") - - // Verify task reference in full message. - fullLog, err := runGitCommandCtx(context.Background(), dir, "log", "-1", "--pretty=format:%B") - require.NoError(t, err) - assert.Contains(t, fullLog, "Task: #T-100") - assert.Contains(t, fullLog, "Co-Authored-By: Claude ") -} - -func TestAutoCommit_Bad_NoChangesToCommit(t *testing.T) { - dir := initGitRepo(t) - - // No changes to commit. - task := &Task{ID: "T-200", Title: "No changes"} - err := AutoCommit(context.Background(), task, dir, "feat: nothing") - assert.Error(t, err, "should fail when there is nothing to commit") -} - -// --- CreateBranch tests (with real git) --- - -func TestCreateBranch_Good(t *testing.T) { - dir := initGitRepo(t) - - task := &Task{ - ID: "BR-42", - Title: "Implement new feature", - Labels: []string{"enhancement"}, - } - - branchName, err := CreateBranch(context.Background(), task, dir) - require.NoError(t, err) - assert.Equal(t, "feat/BR-42-implement-new-feature", branchName) - - // Verify we're on the new branch. - currentBranch, err := GetCurrentBranch(context.Background(), dir) - require.NoError(t, err) - assert.Equal(t, branchName, currentBranch) -} - -func TestCreateBranch_Good_BugLabel(t *testing.T) { - dir := initGitRepo(t) - - task := &Task{ - ID: "BR-43", - Title: "Fix login bug", - Labels: []string{"bug"}, - } - - branchName, err := CreateBranch(context.Background(), task, dir) - require.NoError(t, err) - assert.Equal(t, "fix/BR-43-fix-login-bug", branchName) -} - -// --- PushChanges test --- - -func TestPushChanges_Bad_NoRemote(t *testing.T) { - dir := initGitRepo(t) - - // No remote configured, push should fail. - err := PushChanges(context.Background(), dir) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to push changes") -} - -// --- CommitAndSync tests --- - -func TestCommitAndSync_Good_WithoutClient(t *testing.T) { - dir := initGitRepo(t) - - // Create a file to commit. - err := os.WriteFile(filepath.Join(dir, "sync.go"), []byte("package sync\n"), 0644) - require.NoError(t, err) - - task := &Task{ID: "CS-1", Title: "Sync test"} - - // nil client: should commit but skip sync. - err = CommitAndSync(context.Background(), nil, task, dir, "feat: sync test", 50) - require.NoError(t, err) - - // Verify commit. - output, err := runGitCommandCtx(context.Background(), dir, "log", "--oneline", "-1") - require.NoError(t, err) - assert.Contains(t, output, "feat: sync test") -} - -func TestCommitAndSync_Good_WithClient(t *testing.T) { - dir := initGitRepo(t) - - // Create a file to commit. - err := os.WriteFile(filepath.Join(dir, "synced.go"), []byte("package synced\n"), 0644) - require.NoError(t, err) - - var receivedUpdate TaskUpdate - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPatch { - _ = json.NewDecoder(r.Body).Decode(&receivedUpdate) - w.WriteHeader(http.StatusOK) - } - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task := &Task{ID: "CS-2", Title: "Sync with client"} - - err = CommitAndSync(context.Background(), client, task, dir, "feat: synced", 75) - require.NoError(t, err) - - // Verify the update was sent. - assert.Equal(t, StatusInProgress, receivedUpdate.Status) - assert.Equal(t, 75, receivedUpdate.Progress) - assert.Contains(t, receivedUpdate.Notes, "feat: synced") -} - -func TestCommitAndSync_Bad_CommitFails(t *testing.T) { - dir := initGitRepo(t) - - // No changes to commit. - task := &Task{ID: "CS-3", Title: "Will fail"} - - err := CommitAndSync(context.Background(), nil, task, dir, "feat: no changes", 50) - assert.Error(t, err, "should fail when commit fails") -} - -func TestCommitAndSync_Bad_SyncFails(t *testing.T) { - dir := initGitRepo(t) - - // Create a file to commit. - err := os.WriteFile(filepath.Join(dir, "fail-sync.go"), []byte("package failsync\n"), 0644) - require.NoError(t, err) - - // Server returns an error. - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Message: "sync failed"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task := &Task{ID: "CS-4", Title: "Sync fails"} - - err = CommitAndSync(context.Background(), client, task, dir, "feat: sync-fail", 50) - assert.Error(t, err, "should report sync failure") - assert.Contains(t, err.Error(), "sync failed") -} - -// --- SyncStatus with working client --- - -func TestSyncStatus_Good(t *testing.T) { - var receivedUpdate TaskUpdate - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewDecoder(r.Body).Decode(&receivedUpdate) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task := &Task{ID: "sync-1", Title: "Sync test"} - - err := SyncStatus(context.Background(), client, task, TaskUpdate{ - Status: StatusCompleted, - Progress: 100, - Notes: "All done", - }) - require.NoError(t, err) - assert.Equal(t, StatusCompleted, receivedUpdate.Status) - assert.Equal(t, 100, receivedUpdate.Progress) -} - -// --- CreatePR with default title/body --- - -func TestCreatePR_Good_DefaultTitleFromTask(t *testing.T) { - // CreatePR requires gh CLI which may not be available. - // Test the option building logic by checking that the title - // defaults to the task title. - task := &Task{ - ID: "PR-1", - Title: "Add authentication", - Description: "OAuth2 login", - Priority: PriorityHigh, - } - - opts := PROptions{} - - // Verify the defaulting logic that would be used. - title := opts.Title - if title == "" { - title = task.Title - } - assert.Equal(t, "Add authentication", title) - - body := opts.Body - if body == "" { - body = buildPRBody(task) - } - assert.Contains(t, body, "OAuth2 login") -} - -func TestCreatePR_Good_CustomOptions(t *testing.T) { - opts := PROptions{ - Title: "Custom title", - Body: "Custom body", - Draft: true, - Labels: []string{"enhancement", "v2"}, - Base: "develop", - } - - assert.Equal(t, "Custom title", opts.Title) - assert.True(t, opts.Draft) - assert.Equal(t, "develop", opts.Base) - assert.Len(t, opts.Labels, 2) -} - -// --- Client checkResponse edge cases --- - -func TestClient_CheckResponse_Good_GenericError(t *testing.T) { - // Test checkResponse with a non-JSON error body. - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) - _, _ = w.Write([]byte("plain text error")) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - _, err := client.GetTask(context.Background(), "test-task") - assert.Error(t, err) - assert.Contains(t, err.Error(), "Bad Gateway") -} - -func TestClient_CheckResponse_Good_EmptyBody(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusForbidden) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - _, err := client.GetTask(context.Background(), "test-task") - assert.Error(t, err) - assert.Contains(t, err.Error(), "Forbidden") -} - -// --- Client Ping edge case --- - -func TestClient_Ping_Bad_ServerReturns4xx(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.Ping(context.Background()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "status 401") -} - -// --- Client ClaimTask without AgentID --- - -func TestClient_ClaimTask_Good_NoAgentID(t *testing.T) { - claimedTask := Task{ - ID: "task-no-agent", - Status: StatusInProgress, - } - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify no body sent when AgentID is empty. - assert.Equal(t, http.MethodPost, r.Method) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(claimedTask) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - // Explicitly leave AgentID empty. - client.AgentID = "" - - task, err := client.ClaimTask(context.Background(), "task-no-agent") - require.NoError(t, err) - assert.Equal(t, "task-no-agent", task.ID) -} diff --git a/pkg/lifecycle/completion_test.go b/pkg/lifecycle/completion_test.go deleted file mode 100644 index 434429f..0000000 --- a/pkg/lifecycle/completion_test.go +++ /dev/null @@ -1,199 +0,0 @@ -package lifecycle - -import ( - "context" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestBuildCommitMessage(t *testing.T) { - task := &Task{ - ID: "ABC123", - Title: "Test Task", - } - - message := buildCommitMessage(task, "add new feature") - - assert.Contains(t, message, "add new feature") - assert.Contains(t, message, "Task: #ABC123") - assert.Contains(t, message, "Co-Authored-By: Claude ") -} - -func TestBuildPRBody(t *testing.T) { - task := &Task{ - ID: "PR-456", - Title: "Add authentication", - Description: "Implement user authentication with OAuth2", - Priority: PriorityHigh, - Labels: []string{"enhancement", "security"}, - } - - body := buildPRBody(task) - - assert.Contains(t, body, "## Summary") - assert.Contains(t, body, "Implement user authentication with OAuth2") - assert.Contains(t, body, "## Task Reference") - assert.Contains(t, body, "Task ID: #PR-456") - assert.Contains(t, body, "Priority: high") - assert.Contains(t, body, "Labels: enhancement, security") - assert.Contains(t, body, "Generated with AI assistance") -} - -func TestBuildPRBody_NoLabels(t *testing.T) { - task := &Task{ - ID: "PR-789", - Title: "Fix bug", - Description: "Fix the login bug", - Priority: PriorityMedium, - Labels: nil, - } - - body := buildPRBody(task) - - assert.Contains(t, body, "## Summary") - assert.Contains(t, body, "Fix the login bug") - assert.NotContains(t, body, "Labels:") -} - -func TestGenerateBranchName(t *testing.T) { - tests := []struct { - name string - task *Task - expected string - }{ - { - name: "feature task", - task: &Task{ - ID: "123", - Title: "Add user authentication", - Labels: []string{"enhancement"}, - }, - expected: "feat/123-add-user-authentication", - }, - { - name: "bug fix task", - task: &Task{ - ID: "456", - Title: "Fix login error", - Labels: []string{"bug"}, - }, - expected: "fix/456-fix-login-error", - }, - { - name: "docs task", - task: &Task{ - ID: "789", - Title: "Update README", - Labels: []string{"documentation"}, - }, - expected: "docs/789-update-readme", - }, - { - name: "refactor task", - task: &Task{ - ID: "101", - Title: "Refactor auth module", - Labels: []string{"refactor"}, - }, - expected: "refactor/101-refactor-auth-module", - }, - { - name: "test task", - task: &Task{ - ID: "202", - Title: "Add unit tests", - Labels: []string{"test"}, - }, - expected: "test/202-add-unit-tests", - }, - { - name: "chore task", - task: &Task{ - ID: "303", - Title: "Update dependencies", - Labels: []string{"chore"}, - }, - expected: "chore/303-update-dependencies", - }, - { - name: "long title truncated", - task: &Task{ - ID: "404", - Title: "This is a very long title that should be truncated to fit the branch name limit", - Labels: nil, - }, - expected: "feat/404-this-is-a-very-long-title-that-should-be", - }, - { - name: "special characters removed", - task: &Task{ - ID: "505", - Title: "Fix: user's auth (OAuth2) [important]", - Labels: nil, - }, - expected: "feat/505-fix-users-auth-oauth2-important", - }, - { - name: "no labels defaults to feat", - task: &Task{ - ID: "606", - Title: "New feature", - Labels: nil, - }, - expected: "feat/606-new-feature", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := generateBranchName(tt.task) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestAutoCommit_Bad_NilTask(t *testing.T) { - err := AutoCommit(context.TODO(), nil, ".", "test message") - assert.Error(t, err) - assert.Contains(t, err.Error(), "task is required") -} - -func TestAutoCommit_Bad_EmptyMessage(t *testing.T) { - task := &Task{ID: "123", Title: "Test"} - err := AutoCommit(context.TODO(), task, ".", "") - assert.Error(t, err) - assert.Contains(t, err.Error(), "commit message is required") -} - -func TestSyncStatus_Bad_NilClient(t *testing.T) { - task := &Task{ID: "123", Title: "Test"} - update := TaskUpdate{Status: StatusInProgress} - - err := SyncStatus(context.TODO(), nil, task, update) - assert.Error(t, err) - assert.Contains(t, err.Error(), "client is required") -} - -func TestSyncStatus_Bad_NilTask(t *testing.T) { - client := &Client{BaseURL: "http://test"} - update := TaskUpdate{Status: StatusInProgress} - - err := SyncStatus(context.TODO(), client, nil, update) - assert.Error(t, err) - assert.Contains(t, err.Error(), "task is required") -} - -func TestCreateBranch_Bad_NilTask(t *testing.T) { - branch, err := CreateBranch(context.TODO(), nil, ".") - assert.Error(t, err) - assert.Empty(t, branch) - assert.Contains(t, err.Error(), "task is required") -} - -func TestCreatePR_Bad_NilTask(t *testing.T) { - url, err := CreatePR(context.TODO(), nil, ".", PROptions{}) - assert.Error(t, err) - assert.Empty(t, url) - assert.Contains(t, err.Error(), "task is required") -} diff --git a/pkg/lifecycle/config.go b/pkg/lifecycle/config.go deleted file mode 100644 index 767a635..0000000 --- a/pkg/lifecycle/config.go +++ /dev/null @@ -1,293 +0,0 @@ -package lifecycle - -import ( - "os" - "path/filepath" - "strings" - - "forge.lthn.ai/core/go-io" - "forge.lthn.ai/core/go-log" - "gopkg.in/yaml.v3" -) - -// Config holds the configuration for connecting to the core-agentic service. -type Config struct { - // BaseURL is the URL of the core-agentic API server. - BaseURL string `yaml:"base_url" json:"base_url"` - // Token is the authentication token for API requests. - Token string `yaml:"token" json:"token"` - // DefaultProject is the project to use when none is specified. - DefaultProject string `yaml:"default_project" json:"default_project"` - // AgentID is the identifier for this agent (optional, used for claiming tasks). - AgentID string `yaml:"agent_id" json:"agent_id"` -} - -// configFileName is the name of the YAML config file. -const configFileName = "agentic.yaml" - -// envFileName is the name of the environment file. -const envFileName = ".env" - -// DefaultBaseURL is the default API endpoint if none is configured. -// Set AGENTIC_BASE_URL to override: -// - Lab: https://api.lthn.sh -// - Prod: https://api.lthn.ai -const DefaultBaseURL = "https://api.lthn.sh" - -// LoadConfig loads the agentic configuration from the specified directory. -// It first checks for a .env file, then falls back to ~/.core/agentic.yaml. -// If dir is empty, it checks the current directory first. -// -// Environment variables take precedence: -// - AGENTIC_BASE_URL: API base URL -// - AGENTIC_TOKEN: Authentication token -// - AGENTIC_PROJECT: Default project -// - AGENTIC_AGENT_ID: Agent identifier -func LoadConfig(dir string) (*Config, error) { - cfg := &Config{ - BaseURL: DefaultBaseURL, - } - - // Try loading from .env file in the specified directory - if dir != "" { - envPath := filepath.Join(dir, envFileName) - if err := loadEnvFile(envPath, cfg); err == nil { - // Successfully loaded from .env - applyEnvOverrides(cfg) - if cfg.Token != "" { - return cfg, nil - } - } - } - - // Try loading from current directory .env - if dir == "" { - cwd, err := os.Getwd() - if err == nil { - envPath := filepath.Join(cwd, envFileName) - if err := loadEnvFile(envPath, cfg); err == nil { - applyEnvOverrides(cfg) - if cfg.Token != "" { - return cfg, nil - } - } - } - } - - // Try loading from ~/.core/agentic.yaml - homeDir, err := os.UserHomeDir() - if err != nil { - return nil, log.E("agentic.LoadConfig", "failed to get home directory", err) - } - - configPath := filepath.Join(homeDir, ".core", configFileName) - if err := loadYAMLConfig(configPath, cfg); err != nil && !os.IsNotExist(err) { - return nil, log.E("agentic.LoadConfig", "failed to load config", err) - } - - // Apply environment variable overrides - applyEnvOverrides(cfg) - - // Validate configuration - if cfg.Token == "" { - return nil, log.E("agentic.LoadConfig", "no authentication token configured", nil) - } - - return cfg, nil -} - -// loadEnvFile reads a .env file and extracts agentic configuration. -func loadEnvFile(path string, cfg *Config) error { - content, err := io.Local.Read(path) - if err != nil { - return err - } - - for line := range strings.SplitSeq(content, "\n") { - line = strings.TrimSpace(line) - - // Skip empty lines and comments - if line == "" || strings.HasPrefix(line, "#") { - continue - } - - // Parse KEY=value - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue - } - - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - - // Remove quotes if present - value = strings.Trim(value, `"'`) - - switch key { - case "AGENTIC_BASE_URL": - cfg.BaseURL = value - case "AGENTIC_TOKEN": - cfg.Token = value - case "AGENTIC_PROJECT": - cfg.DefaultProject = value - case "AGENTIC_AGENT_ID": - cfg.AgentID = value - } - } - - return nil -} - -// loadYAMLConfig reads configuration from a YAML file. -func loadYAMLConfig(path string, cfg *Config) error { - content, err := io.Local.Read(path) - if err != nil { - return err - } - - return yaml.Unmarshal([]byte(content), cfg) -} - -// applyEnvOverrides applies environment variable overrides to the config. -func applyEnvOverrides(cfg *Config) { - if v := os.Getenv("AGENTIC_BASE_URL"); v != "" { - cfg.BaseURL = v - } - if v := os.Getenv("AGENTIC_TOKEN"); v != "" { - cfg.Token = v - } - if v := os.Getenv("AGENTIC_PROJECT"); v != "" { - cfg.DefaultProject = v - } - if v := os.Getenv("AGENTIC_AGENT_ID"); v != "" { - cfg.AgentID = v - } -} - -// SaveConfig saves the configuration to ~/.core/agentic.yaml. -func SaveConfig(cfg *Config) error { - homeDir, err := os.UserHomeDir() - if err != nil { - return log.E("agentic.SaveConfig", "failed to get home directory", err) - } - - configDir := filepath.Join(homeDir, ".core") - if err := io.Local.EnsureDir(configDir); err != nil { - return log.E("agentic.SaveConfig", "failed to create config directory", err) - } - - configPath := filepath.Join(configDir, configFileName) - - data, err := yaml.Marshal(cfg) - if err != nil { - return log.E("agentic.SaveConfig", "failed to marshal config", err) - } - - if err := io.Local.Write(configPath, string(data)); err != nil { - return log.E("agentic.SaveConfig", "failed to write config file", err) - } - - return nil -} - -// ConfigPath returns the path to the config file in the user's home directory. -func ConfigPath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", log.E("agentic.ConfigPath", "failed to get home directory", err) - } - return filepath.Join(homeDir, ".core", configFileName), nil -} - -// AllowanceConfig controls allowance store backend selection. -type AllowanceConfig struct { - // StoreBackend is the storage backend: "memory", "sqlite", or "redis". Default: "memory". - StoreBackend string `yaml:"store_backend" json:"store_backend"` - // StorePath is the file path for the SQLite database. - // Default: ~/.config/agentic/allowance.db (only used when StoreBackend is "sqlite"). - StorePath string `yaml:"store_path" json:"store_path"` - // RedisAddr is the host:port for the Redis server (only used when StoreBackend is "redis"). - RedisAddr string `yaml:"redis_addr" json:"redis_addr"` -} - -// DefaultAllowanceStorePath returns the default SQLite path for allowance data. -func DefaultAllowanceStorePath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", log.E("agentic.DefaultAllowanceStorePath", "failed to get home directory", err) - } - return filepath.Join(homeDir, ".config", "agentic", "allowance.db"), nil -} - -// NewAllowanceStoreFromConfig creates an AllowanceStore based on the given config. -// It returns a MemoryStore for "memory" (or empty) backend and a SQLiteStore for "sqlite". -func NewAllowanceStoreFromConfig(cfg AllowanceConfig) (AllowanceStore, error) { - switch cfg.StoreBackend { - case "", "memory": - return NewMemoryStore(), nil - case "sqlite": - dbPath := cfg.StorePath - if dbPath == "" { - var err error - dbPath, err = DefaultAllowanceStorePath() - if err != nil { - return nil, err - } - } - return NewSQLiteStore(dbPath) - case "redis": - return NewRedisStore(cfg.RedisAddr) - default: - return nil, &APIError{ - Code: 400, - Message: "unsupported store backend: " + cfg.StoreBackend, - } - } -} - -// RegistryConfig controls agent registry backend selection. -type RegistryConfig struct { - // RegistryBackend is the storage backend: "memory", "sqlite", or "redis". Default: "memory". - RegistryBackend string `yaml:"registry_backend" json:"registry_backend"` - // RegistryPath is the file path for the SQLite database. - // Default: ~/.config/agentic/registry.db (only used when RegistryBackend is "sqlite"). - RegistryPath string `yaml:"registry_path" json:"registry_path"` - // RegistryRedisAddr is the host:port for the Redis server (only used when RegistryBackend is "redis"). - RegistryRedisAddr string `yaml:"registry_redis_addr" json:"registry_redis_addr"` -} - -// DefaultRegistryPath returns the default SQLite path for registry data. -func DefaultRegistryPath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", log.E("agentic.DefaultRegistryPath", "failed to get home directory", err) - } - return filepath.Join(homeDir, ".config", "agentic", "registry.db"), nil -} - -// NewAgentRegistryFromConfig creates an AgentRegistry based on the given config. -// It returns a MemoryRegistry for "memory" (or empty) backend, a SQLiteRegistry -// for "sqlite", and a RedisRegistry for "redis". -func NewAgentRegistryFromConfig(cfg RegistryConfig) (AgentRegistry, error) { - switch cfg.RegistryBackend { - case "", "memory": - return NewMemoryRegistry(), nil - case "sqlite": - dbPath := cfg.RegistryPath - if dbPath == "" { - var err error - dbPath, err = DefaultRegistryPath() - if err != nil { - return nil, err - } - } - return NewSQLiteRegistry(dbPath) - case "redis": - return NewRedisRegistry(cfg.RegistryRedisAddr) - default: - return nil, &APIError{ - Code: 400, - Message: "unsupported registry backend: " + cfg.RegistryBackend, - } - } -} diff --git a/pkg/lifecycle/config_test.go b/pkg/lifecycle/config_test.go deleted file mode 100644 index 3e76e2d..0000000 --- a/pkg/lifecycle/config_test.go +++ /dev/null @@ -1,455 +0,0 @@ -package lifecycle - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestLoadConfig_Good_FromEnvFile(t *testing.T) { - // Create temp directory with .env file - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - envContent := ` -AGENTIC_BASE_URL=https://test.api.com -AGENTIC_TOKEN=test-token-123 -AGENTIC_PROJECT=my-project -AGENTIC_AGENT_ID=agent-001 -` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - cfg, err := LoadConfig(tmpDir) - - require.NoError(t, err) - assert.Equal(t, "https://test.api.com", cfg.BaseURL) - assert.Equal(t, "test-token-123", cfg.Token) - assert.Equal(t, "my-project", cfg.DefaultProject) - assert.Equal(t, "agent-001", cfg.AgentID) -} - -func TestLoadConfig_Good_FromEnvVars(t *testing.T) { - // Create temp directory with .env file (partial config) - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - envContent := ` -AGENTIC_TOKEN=env-file-token -` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - // Set environment variables that should override - _ = os.Setenv("AGENTIC_BASE_URL", "https://env-override.com") - _ = os.Setenv("AGENTIC_TOKEN", "env-override-token") - defer func() { - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_TOKEN") - }() - - cfg, err := LoadConfig(tmpDir) - - require.NoError(t, err) - assert.Equal(t, "https://env-override.com", cfg.BaseURL) - assert.Equal(t, "env-override-token", cfg.Token) -} - -func TestLoadConfig_Bad_NoToken(t *testing.T) { - // Create temp directory without config - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Create empty .env - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(""), 0644) - require.NoError(t, err) - - // Ensure no env vars are set - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_BASE_URL") - - _, err = LoadConfig(tmpDir) - - assert.Error(t, err) - assert.Contains(t, err.Error(), "no authentication token") -} - -func TestLoadConfig_Good_EnvFileWithQuotes(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Test with quoted values - envContent := ` -AGENTIC_TOKEN="quoted-token" -AGENTIC_BASE_URL='single-quoted-url' -` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - cfg, err := LoadConfig(tmpDir) - - require.NoError(t, err) - assert.Equal(t, "quoted-token", cfg.Token) - assert.Equal(t, "single-quoted-url", cfg.BaseURL) -} - -func TestLoadConfig_Good_EnvFileWithComments(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - envContent := ` -# This is a comment -AGENTIC_TOKEN=token-with-comments - -# Another comment -AGENTIC_PROJECT=commented-project -` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - cfg, err := LoadConfig(tmpDir) - - require.NoError(t, err) - assert.Equal(t, "token-with-comments", cfg.Token) - assert.Equal(t, "commented-project", cfg.DefaultProject) -} - -func TestSaveConfig_Good(t *testing.T) { - // Create temp home directory - tmpHome, err := os.MkdirTemp("", "agentic-home") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpHome) }() - - // Override HOME for the test - originalHome := os.Getenv("HOME") - _ = os.Setenv("HOME", tmpHome) - defer func() { _ = os.Setenv("HOME", originalHome) }() - - cfg := &Config{ - BaseURL: "https://saved.api.com", - Token: "saved-token", - DefaultProject: "saved-project", - AgentID: "saved-agent", - } - - err = SaveConfig(cfg) - require.NoError(t, err) - - // Verify file was created - configPath := filepath.Join(tmpHome, ".core", "agentic.yaml") - _, err = os.Stat(configPath) - assert.NoError(t, err) - - // Read back the config - data, err := os.ReadFile(configPath) - require.NoError(t, err) - assert.Contains(t, string(data), "saved.api.com") - assert.Contains(t, string(data), "saved-token") -} - -func TestConfigPath_Good(t *testing.T) { - path, err := ConfigPath() - - require.NoError(t, err) - assert.Contains(t, path, ".core") - assert.Contains(t, path, "agentic.yaml") -} - -func TestLoadConfig_Good_DefaultBaseURL(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Only provide token, should use default base URL - envContent := ` -AGENTIC_TOKEN=test-token -` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - // Clear any env overrides - _ = os.Unsetenv("AGENTIC_BASE_URL") - - cfg, err := LoadConfig(tmpDir) - - require.NoError(t, err) - assert.Equal(t, DefaultBaseURL, cfg.BaseURL) -} - -func TestLoadConfig_Good_FromYAMLFallback(t *testing.T) { - // Set up a temp home with ~/.core/agentic.yaml - tmpHome, err := os.MkdirTemp("", "agentic-home") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpHome) }() - - originalHome := os.Getenv("HOME") - _ = os.Setenv("HOME", tmpHome) - defer func() { _ = os.Setenv("HOME", originalHome) }() - - // Clear all env vars so we fall through to YAML. - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - - // Create ~/.core/agentic.yaml - configDir := filepath.Join(tmpHome, ".core") - err = os.MkdirAll(configDir, 0755) - require.NoError(t, err) - - yamlContent := `base_url: https://yaml.api.com -token: yaml-token -default_project: yaml-project -agent_id: yaml-agent -` - err = os.WriteFile(filepath.Join(configDir, "agentic.yaml"), []byte(yamlContent), 0644) - require.NoError(t, err) - - // Load from a dir with no .env to force YAML fallback. - tmpDir, err := os.MkdirTemp("", "agentic-noenv") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - cfg, err := LoadConfig(tmpDir) - require.NoError(t, err) - assert.Equal(t, "https://yaml.api.com", cfg.BaseURL) - assert.Equal(t, "yaml-token", cfg.Token) - assert.Equal(t, "yaml-project", cfg.DefaultProject) - assert.Equal(t, "yaml-agent", cfg.AgentID) -} - -func TestLoadConfig_Good_EnvOverridesYAML(t *testing.T) { - // Set up a temp home with ~/.core/agentic.yaml - tmpHome, err := os.MkdirTemp("", "agentic-home") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpHome) }() - - originalHome := os.Getenv("HOME") - _ = os.Setenv("HOME", tmpHome) - defer func() { _ = os.Setenv("HOME", originalHome) }() - - // Create ~/.core/agentic.yaml - configDir := filepath.Join(tmpHome, ".core") - err = os.MkdirAll(configDir, 0755) - require.NoError(t, err) - - yamlContent := `base_url: https://yaml.api.com -token: yaml-token -` - err = os.WriteFile(filepath.Join(configDir, "agentic.yaml"), []byte(yamlContent), 0644) - require.NoError(t, err) - - // Set env overrides for project and agent. - _ = os.Setenv("AGENTIC_PROJECT", "env-project") - _ = os.Setenv("AGENTIC_AGENT_ID", "env-agent") - defer func() { - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - }() - - tmpDir, err := os.MkdirTemp("", "agentic-noenv") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - cfg, err := LoadConfig(tmpDir) - require.NoError(t, err) - assert.Equal(t, "env-project", cfg.DefaultProject, "env var should override YAML") - assert.Equal(t, "env-agent", cfg.AgentID, "env var should override YAML") -} - -func TestLoadConfig_Good_EnvFileWithTokenNoOverride(t *testing.T) { - // Test that .env with a token returns immediately without - // falling through to YAML. - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - envContent := `AGENTIC_TOKEN=env-file-only` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - - cfg, err := LoadConfig(tmpDir) - require.NoError(t, err) - assert.Equal(t, "env-file-only", cfg.Token) - assert.Equal(t, DefaultBaseURL, cfg.BaseURL) -} - -func TestLoadConfig_Good_EnvFileWithMalformedLines(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - // Lines without = sign should be skipped. - envContent := ` -AGENTIC_TOKEN=valid-token -MALFORMED_LINE_NO_EQUALS -ANOTHER_BAD_LINE -AGENTIC_PROJECT=valid-project -` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - cfg, err := LoadConfig(tmpDir) - require.NoError(t, err) - assert.Equal(t, "valid-token", cfg.Token) - assert.Equal(t, "valid-project", cfg.DefaultProject) -} - -func TestApplyEnvOverrides_Good_AllVars(t *testing.T) { - _ = os.Setenv("AGENTIC_BASE_URL", "https://override-url.com") - _ = os.Setenv("AGENTIC_TOKEN", "override-token") - _ = os.Setenv("AGENTIC_PROJECT", "override-project") - _ = os.Setenv("AGENTIC_AGENT_ID", "override-agent") - defer func() { - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - }() - - cfg := &Config{} - applyEnvOverrides(cfg) - - assert.Equal(t, "https://override-url.com", cfg.BaseURL) - assert.Equal(t, "override-token", cfg.Token) - assert.Equal(t, "override-project", cfg.DefaultProject) - assert.Equal(t, "override-agent", cfg.AgentID) -} - -func TestApplyEnvOverrides_Good_NoVarsSet(t *testing.T) { - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - - cfg := &Config{ - BaseURL: "original-url", - Token: "original-token", - } - applyEnvOverrides(cfg) - - assert.Equal(t, "original-url", cfg.BaseURL, "should not change without env var") - assert.Equal(t, "original-token", cfg.Token, "should not change without env var") -} - -func TestSaveConfig_Good_RoundTrip(t *testing.T) { - tmpHome, err := os.MkdirTemp("", "agentic-home") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpHome) }() - - originalHome := os.Getenv("HOME") - _ = os.Setenv("HOME", tmpHome) - defer func() { _ = os.Setenv("HOME", originalHome) }() - - // Clear env vars so LoadConfig falls through to YAML. - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - - original := &Config{ - BaseURL: "https://roundtrip.api.com", - Token: "roundtrip-token", - DefaultProject: "roundtrip-project", - AgentID: "roundtrip-agent", - } - - err = SaveConfig(original) - require.NoError(t, err) - - // Load it back by pointing to a dir with no .env. - tmpDir, err := os.MkdirTemp("", "agentic-noenv") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - loaded, err := LoadConfig(tmpDir) - require.NoError(t, err) - assert.Equal(t, original.BaseURL, loaded.BaseURL) - assert.Equal(t, original.Token, loaded.Token) - assert.Equal(t, original.DefaultProject, loaded.DefaultProject) - assert.Equal(t, original.AgentID, loaded.AgentID) -} - -func TestLoadConfig_Good_EmptyDirUsesCurrentDir(t *testing.T) { - // Create a temp directory with .env and chdir into it. - tmpDir, err := os.MkdirTemp("", "agentic-cwd") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - envContent := `AGENTIC_TOKEN=cwd-token -AGENTIC_BASE_URL=https://cwd.api.com -` - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte(envContent), 0644) - require.NoError(t, err) - - // Clear env vars. - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - - // Save and restore cwd. - originalCwd, err := os.Getwd() - require.NoError(t, err) - defer func() { _ = os.Chdir(originalCwd) }() - - err = os.Chdir(tmpDir) - require.NoError(t, err) - - cfg, err := LoadConfig("") - require.NoError(t, err) - assert.Equal(t, "cwd-token", cfg.Token) - assert.Equal(t, "https://cwd.api.com", cfg.BaseURL) -} - -func TestLoadConfig_Good_EnvFileNoToken_FallsToYAML(t *testing.T) { - tmpHome, err := os.MkdirTemp("", "agentic-home") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpHome) }() - - originalHome := os.Getenv("HOME") - _ = os.Setenv("HOME", tmpHome) - defer func() { _ = os.Setenv("HOME", originalHome) }() - - _ = os.Unsetenv("AGENTIC_TOKEN") - _ = os.Unsetenv("AGENTIC_BASE_URL") - _ = os.Unsetenv("AGENTIC_PROJECT") - _ = os.Unsetenv("AGENTIC_AGENT_ID") - - // Create .env without a token. - tmpDir, err := os.MkdirTemp("", "agentic-test") - require.NoError(t, err) - defer func() { _ = os.RemoveAll(tmpDir) }() - - err = os.WriteFile(filepath.Join(tmpDir, ".env"), []byte("AGENTIC_PROJECT=env-proj\n"), 0644) - require.NoError(t, err) - - // Create YAML with token. - configDir := filepath.Join(tmpHome, ".core") - err = os.MkdirAll(configDir, 0755) - require.NoError(t, err) - - yamlContent := `token: yaml-fallback-token -` - err = os.WriteFile(filepath.Join(configDir, "agentic.yaml"), []byte(yamlContent), 0644) - require.NoError(t, err) - - cfg, err := LoadConfig(tmpDir) - require.NoError(t, err) - assert.Equal(t, "yaml-fallback-token", cfg.Token) -} diff --git a/pkg/lifecycle/context.go b/pkg/lifecycle/context.go deleted file mode 100644 index 47bb054..0000000 --- a/pkg/lifecycle/context.go +++ /dev/null @@ -1,335 +0,0 @@ -// Package agentic provides AI collaboration features for task management. -package lifecycle - -import ( - "bytes" - "os" - "os/exec" - "path/filepath" - "regexp" - "strings" - - "forge.lthn.ai/core/go-io" - "forge.lthn.ai/core/go-log" -) - -// FileContent represents the content of a file for AI context. -type FileContent struct { - // Path is the relative path to the file. - Path string `json:"path"` - // Content is the file content. - Content string `json:"content"` - // Language is the detected programming language. - Language string `json:"language"` -} - -// TaskContext contains gathered context for AI collaboration. -type TaskContext struct { - // Task is the task being worked on. - Task *Task `json:"task"` - // Files is a list of relevant file contents. - Files []FileContent `json:"files"` - // GitStatus is the current git status output. - GitStatus string `json:"git_status"` - // RecentCommits is the recent commit log. - RecentCommits string `json:"recent_commits"` - // RelatedCode contains code snippets related to the task. - RelatedCode []FileContent `json:"related_code"` -} - -// BuildTaskContext gathers context for AI collaboration on a task. -func BuildTaskContext(task *Task, dir string) (*TaskContext, error) { - const op = "agentic.BuildTaskContext" - - if task == nil { - return nil, log.E(op, "task is required", nil) - } - - if dir == "" { - cwd, err := os.Getwd() - if err != nil { - return nil, log.E(op, "failed to get working directory", err) - } - dir = cwd - } - - ctx := &TaskContext{ - Task: task, - } - - // Gather files mentioned in the task - files, err := GatherRelatedFiles(task, dir) - if err != nil { - // Non-fatal: continue without files - files = nil - } - ctx.Files = files - - // Get git status - gitStatus, _ := runGitCommand(dir, "status", "--porcelain") - ctx.GitStatus = gitStatus - - // Get recent commits - recentCommits, _ := runGitCommand(dir, "log", "--oneline", "-10") - ctx.RecentCommits = recentCommits - - // Find related code by searching for keywords - relatedCode, err := findRelatedCode(task, dir) - if err != nil { - relatedCode = nil - } - ctx.RelatedCode = relatedCode - - return ctx, nil -} - -// GatherRelatedFiles reads files mentioned in the task. -func GatherRelatedFiles(task *Task, dir string) ([]FileContent, error) { - const op = "agentic.GatherRelatedFiles" - - if task == nil { - return nil, log.E(op, "task is required", nil) - } - - var files []FileContent - - // Read files explicitly mentioned in the task - for _, relPath := range task.Files { - fullPath := filepath.Join(dir, relPath) - - content, err := io.Local.Read(fullPath) - if err != nil { - // Skip files that don't exist - continue - } - - files = append(files, FileContent{ - Path: relPath, - Content: content, - Language: detectLanguage(relPath), - }) - } - - return files, nil -} - -// findRelatedCode searches for code related to the task by keywords. -func findRelatedCode(task *Task, dir string) ([]FileContent, error) { - const op = "agentic.findRelatedCode" - - if task == nil { - return nil, log.E(op, "task is required", nil) - } - - // Extract keywords from title and description - keywords := extractKeywords(task.Title + " " + task.Description) - if len(keywords) == 0 { - return nil, nil - } - - var files []FileContent - seen := make(map[string]bool) - - // Search for each keyword using git grep - for _, keyword := range keywords { - if len(keyword) < 3 { - continue - } - - output, err := runGitCommand(dir, "grep", "-l", "-i", keyword, "--", "*.go", "*.ts", "*.js", "*.py") - if err != nil { - continue - } - - // Parse matched files - for line := range strings.SplitSeq(output, "\n") { - line = strings.TrimSpace(line) - if line == "" || seen[line] { - continue - } - seen[line] = true - - // Limit to 10 related files - if len(files) >= 10 { - break - } - - fullPath := filepath.Join(dir, line) - content, err := io.Local.Read(fullPath) - if err != nil { - continue - } - - // Truncate large files - if len(content) > 5000 { - content = content[:5000] + "\n... (truncated)" - } - - files = append(files, FileContent{ - Path: line, - Content: content, - Language: detectLanguage(line), - }) - } - - if len(files) >= 10 { - break - } - } - - return files, nil -} - -// extractKeywords extracts meaningful words from text for searching. -func extractKeywords(text string) []string { - // Remove common words and extract identifiers - text = strings.ToLower(text) - - // Split by non-alphanumeric characters - re := regexp.MustCompile(`[^a-zA-Z0-9]+`) - words := re.Split(text, -1) - - // Filter stop words and short words - stopWords := map[string]bool{ - "the": true, "a": true, "an": true, "and": true, "or": true, "but": true, - "in": true, "on": true, "at": true, "to": true, "for": true, "of": true, - "with": true, "by": true, "from": true, "is": true, "are": true, "was": true, - "be": true, "been": true, "being": true, "have": true, "has": true, "had": true, - "do": true, "does": true, "did": true, "will": true, "would": true, "could": true, - "should": true, "may": true, "might": true, "must": true, "shall": true, - "this": true, "that": true, "these": true, "those": true, "it": true, - "add": true, "create": true, "update": true, "fix": true, "remove": true, - "implement": true, "new": true, "file": true, "code": true, - } - - var keywords []string - for _, word := range words { - word = strings.TrimSpace(word) - if len(word) >= 3 && !stopWords[word] { - keywords = append(keywords, word) - } - } - - // Limit to first 5 keywords - if len(keywords) > 5 { - keywords = keywords[:5] - } - - return keywords -} - -// detectLanguage detects the programming language from a file extension. -func detectLanguage(path string) string { - ext := strings.ToLower(filepath.Ext(path)) - - languages := map[string]string{ - ".go": "go", - ".ts": "typescript", - ".tsx": "typescript", - ".js": "javascript", - ".jsx": "javascript", - ".py": "python", - ".rs": "rust", - ".java": "java", - ".kt": "kotlin", - ".swift": "swift", - ".c": "c", - ".cpp": "cpp", - ".h": "c", - ".hpp": "cpp", - ".rb": "ruby", - ".php": "php", - ".cs": "csharp", - ".fs": "fsharp", - ".scala": "scala", - ".sh": "bash", - ".bash": "bash", - ".zsh": "zsh", - ".yaml": "yaml", - ".yml": "yaml", - ".json": "json", - ".xml": "xml", - ".html": "html", - ".css": "css", - ".scss": "scss", - ".sql": "sql", - ".md": "markdown", - } - - if lang, ok := languages[ext]; ok { - return lang - } - return "text" -} - -// runGitCommand runs a git command and returns the output. -func runGitCommand(dir string, args ...string) (string, error) { - cmd := exec.Command("git", args...) - cmd.Dir = dir - - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - if err := cmd.Run(); err != nil { - return "", err - } - - return stdout.String(), nil -} - -// FormatContext formats the TaskContext for AI consumption. -func (tc *TaskContext) FormatContext() string { - var sb strings.Builder - - sb.WriteString("# Task Context\n\n") - - // Task info - sb.WriteString("## Task\n") - sb.WriteString("ID: " + tc.Task.ID + "\n") - sb.WriteString("Title: " + tc.Task.Title + "\n") - sb.WriteString("Priority: " + string(tc.Task.Priority) + "\n") - sb.WriteString("Status: " + string(tc.Task.Status) + "\n") - sb.WriteString("\n### Description\n") - sb.WriteString(tc.Task.Description + "\n\n") - - // Files - if len(tc.Files) > 0 { - sb.WriteString("## Task Files\n") - for _, f := range tc.Files { - sb.WriteString("### " + f.Path + " (" + f.Language + ")\n") - sb.WriteString("```" + f.Language + "\n") - sb.WriteString(f.Content) - sb.WriteString("\n```\n\n") - } - } - - // Git status - if tc.GitStatus != "" { - sb.WriteString("## Git Status\n") - sb.WriteString("```\n") - sb.WriteString(tc.GitStatus) - sb.WriteString("\n```\n\n") - } - - // Recent commits - if tc.RecentCommits != "" { - sb.WriteString("## Recent Commits\n") - sb.WriteString("```\n") - sb.WriteString(tc.RecentCommits) - sb.WriteString("\n```\n\n") - } - - // Related code - if len(tc.RelatedCode) > 0 { - sb.WriteString("## Related Code\n") - for _, f := range tc.RelatedCode { - sb.WriteString("### " + f.Path + " (" + f.Language + ")\n") - sb.WriteString("```" + f.Language + "\n") - sb.WriteString(f.Content) - sb.WriteString("\n```\n\n") - } - } - - return sb.String() -} diff --git a/pkg/lifecycle/context_git_test.go b/pkg/lifecycle/context_git_test.go deleted file mode 100644 index 16243bf..0000000 --- a/pkg/lifecycle/context_git_test.go +++ /dev/null @@ -1,248 +0,0 @@ -package lifecycle - -import ( - "context" - "os" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// initGitRepoWithCode creates a git repo with searchable Go code. -func initGitRepoWithCode(t *testing.T) string { - t.Helper() - dir := initGitRepo(t) - - // Create Go files with known content for git grep. - files := map[string]string{ - "auth.go": `package main - -// Authenticate validates user credentials. -func Authenticate(user, pass string) bool { - return user != "" && pass != "" -} -`, - "handler.go": `package main - -// HandleRequest processes HTTP requests for authentication. -func HandleRequest() { - // authentication logic here -} -`, - "util.go": `package main - -// TokenValidator checks JWT tokens. -func TokenValidator(token string) bool { - return len(token) > 0 -} -`, - } - - for name, content := range files { - err := os.WriteFile(filepath.Join(dir, name), []byte(content), 0644) - require.NoError(t, err) - } - - // Stage and commit all files so git grep can find them. - _, err := runCommandCtx(context.Background(), dir, "git", "add", "-A") - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "add code files") - require.NoError(t, err) - - return dir -} - -func TestFindRelatedCode_Good_MatchesKeywords(t *testing.T) { - dir := initGitRepoWithCode(t) - - task := &Task{ - ID: "code-1", - Title: "Fix authentication handler", - Description: "The authentication handler needs refactoring", - } - - files, err := findRelatedCode(task, dir) - require.NoError(t, err) - assert.NotEmpty(t, files, "should find files matching 'authentication' keyword") - - // Verify language detection. - for _, f := range files { - assert.Equal(t, "go", f.Language) - } -} - -func TestFindRelatedCode_Good_NoKeywords(t *testing.T) { - dir := initGitRepoWithCode(t) - - task := &Task{ - ID: "code-2", - Title: "do it", // too short, all stop words - Description: "fix the bug in the code", - } - - files, err := findRelatedCode(task, dir) - require.NoError(t, err) - // Keywords extracted from "do it fix the bug in the code" -- most are stop words. - // Only "bug" is 3+ chars and not a stop word, but may not match any files. - // Result can be nil or empty -- both are acceptable. - _ = files -} - -func TestFindRelatedCode_Bad_NilTask(t *testing.T) { - files, err := findRelatedCode(nil, ".") - assert.Error(t, err) - assert.Nil(t, files) -} - -func TestFindRelatedCode_Good_LimitsTo10Files(t *testing.T) { - dir := initGitRepoWithCode(t) - - // Create 15 files all containing the keyword "validation". - for i := range 15 { - name := filepath.Join(dir, "validation_"+string(rune('a'+i))+".go") - content := "package main\n// validation logic\nfunc Validate" + string(rune('A'+i)) + "() {}\n" - err := os.WriteFile(name, []byte(content), 0644) - require.NoError(t, err) - } - - _, err := runCommandCtx(context.Background(), dir, "git", "add", "-A") - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "add validation files") - require.NoError(t, err) - - task := &Task{ - ID: "code-3", - Title: "validation refactoring", - Description: "Refactor all validation logic", - } - - files, err := findRelatedCode(task, dir) - require.NoError(t, err) - assert.LessOrEqual(t, len(files), 10, "should limit to 10 related files") -} - -func TestFindRelatedCode_Good_TruncatesLargeFiles(t *testing.T) { - dir := initGitRepoWithCode(t) - - // Create a file larger than 5000 chars containing a searchable keyword. - largeContent := "package main\n// largecontent\n" - for len(largeContent) < 6000 { - largeContent += "// This is filler content for testing truncation purposes.\n" - } - err := os.WriteFile(filepath.Join(dir, "large.go"), []byte(largeContent), 0644) - require.NoError(t, err) - - _, err = runCommandCtx(context.Background(), dir, "git", "add", "-A") - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "add large file") - require.NoError(t, err) - - task := &Task{ - ID: "code-4", - Title: "largecontent analysis", - Description: "Analyse the largecontent module", - } - - files, err := findRelatedCode(task, dir) - require.NoError(t, err) - - for _, f := range files { - if f.Path == "large.go" { - assert.True(t, len(f.Content) <= 5020, "content should be truncated") - assert.Contains(t, f.Content, "... (truncated)") - return - } - } - // If large.go wasn't found by git grep, that's acceptable. -} - -func TestBuildTaskContext_Good_WithGitRepo(t *testing.T) { - dir := initGitRepoWithCode(t) - - task := &Task{ - ID: "ctx-1", - Title: "Test context building with authentication", - Description: "Build context in a git repo with searchable code", - Priority: PriorityMedium, - Status: StatusPending, - Files: []string{"auth.go"}, - CreatedAt: time.Now(), - } - - ctx, err := BuildTaskContext(task, dir) - require.NoError(t, err) - assert.NotNil(t, ctx) - assert.Equal(t, task, ctx.Task) - - // Should have gathered the auth.go file. - assert.Len(t, ctx.Files, 1) - assert.Equal(t, "auth.go", ctx.Files[0].Path) - assert.Contains(t, ctx.Files[0].Content, "Authenticate") - - // Should have recent commits. - assert.NotEmpty(t, ctx.RecentCommits) - - // Should have found related code. - assert.NotEmpty(t, ctx.RelatedCode, "should find code related to 'authentication'") -} - -func TestBuildTaskContext_Good_EmptyDir(t *testing.T) { - task := &Task{ - ID: "ctx-2", - Title: "Test with empty dir", - Description: "Testing", - Priority: PriorityLow, - Status: StatusPending, - CreatedAt: time.Now(), - } - - // Empty dir defaults to cwd -- BuildTaskContext handles errors gracefully. - ctx, err := BuildTaskContext(task, "") - require.NoError(t, err) - assert.NotNil(t, ctx) -} - -func TestFormatContext_Good_EmptySections(t *testing.T) { - task := &Task{ - ID: "fmt-1", - Title: "Minimal task", - Description: "No files, no git", - Priority: PriorityLow, - Status: StatusPending, - } - - ctx := &TaskContext{ - Task: task, - Files: nil, - GitStatus: "", - RecentCommits: "", - RelatedCode: nil, - } - - formatted := ctx.FormatContext() - - assert.Contains(t, formatted, "# Task Context") - assert.Contains(t, formatted, "fmt-1") - assert.NotContains(t, formatted, "## Task Files") - assert.NotContains(t, formatted, "## Git Status") - assert.NotContains(t, formatted, "## Recent Commits") - assert.NotContains(t, formatted, "## Related Code") -} - -func TestRunGitCommand_Good(t *testing.T) { - dir := initGitRepo(t) - - output, err := runGitCommand(dir, "log", "--oneline", "-1") - require.NoError(t, err) - assert.Contains(t, output, "initial commit") -} - -func TestRunGitCommand_Bad_NotAGitRepo(t *testing.T) { - dir := t.TempDir() - - _, err := runGitCommand(dir, "status") - assert.Error(t, err) -} diff --git a/pkg/lifecycle/context_test.go b/pkg/lifecycle/context_test.go deleted file mode 100644 index 97ff9df..0000000 --- a/pkg/lifecycle/context_test.go +++ /dev/null @@ -1,214 +0,0 @@ -package lifecycle - -import ( - "os" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestBuildTaskContext_Good(t *testing.T) { - // Create a temp directory with some files - tmpDir := t.TempDir() - - // Create a test file - testFile := filepath.Join(tmpDir, "main.go") - err := os.WriteFile(testFile, []byte("package main\n\nfunc main() {}\n"), 0644) - require.NoError(t, err) - - task := &Task{ - ID: "test-123", - Title: "Test Task", - Description: "A test task description", - Priority: PriorityMedium, - Status: StatusPending, - Files: []string{"main.go"}, - CreatedAt: time.Now(), - } - - ctx, err := BuildTaskContext(task, tmpDir) - require.NoError(t, err) - assert.NotNil(t, ctx) - assert.Equal(t, task, ctx.Task) - assert.Len(t, ctx.Files, 1) - assert.Equal(t, "main.go", ctx.Files[0].Path) - assert.Equal(t, "go", ctx.Files[0].Language) -} - -func TestBuildTaskContext_Bad_NilTask(t *testing.T) { - ctx, err := BuildTaskContext(nil, ".") - assert.Error(t, err) - assert.Nil(t, ctx) - assert.Contains(t, err.Error(), "task is required") -} - -func TestGatherRelatedFiles_Good(t *testing.T) { - tmpDir := t.TempDir() - - // Create test files - files := map[string]string{ - "app.go": "package app\n\nfunc Run() {}\n", - "config.ts": "export const config = {};\n", - "README.md": "# Project\n", - } - - for name, content := range files { - path := filepath.Join(tmpDir, name) - err := os.WriteFile(path, []byte(content), 0644) - require.NoError(t, err) - } - - task := &Task{ - ID: "task-1", - Title: "Test", - Files: []string{"app.go", "config.ts"}, - } - - gathered, err := GatherRelatedFiles(task, tmpDir) - require.NoError(t, err) - assert.Len(t, gathered, 2) - - // Check languages detected correctly - foundGo := false - foundTS := false - for _, f := range gathered { - if f.Path == "app.go" { - foundGo = true - assert.Equal(t, "go", f.Language) - } - if f.Path == "config.ts" { - foundTS = true - assert.Equal(t, "typescript", f.Language) - } - } - assert.True(t, foundGo, "should find app.go") - assert.True(t, foundTS, "should find config.ts") -} - -func TestGatherRelatedFiles_Bad_NilTask(t *testing.T) { - files, err := GatherRelatedFiles(nil, ".") - assert.Error(t, err) - assert.Nil(t, files) -} - -func TestGatherRelatedFiles_Good_MissingFiles(t *testing.T) { - tmpDir := t.TempDir() - - task := &Task{ - ID: "task-1", - Title: "Test", - Files: []string{"nonexistent.go", "also-missing.ts"}, - } - - // Should not error, just return empty list - gathered, err := GatherRelatedFiles(task, tmpDir) - require.NoError(t, err) - assert.Empty(t, gathered) -} - -func TestDetectLanguage(t *testing.T) { - tests := []struct { - path string - expected string - }{ - {"main.go", "go"}, - {"app.ts", "typescript"}, - {"app.tsx", "typescript"}, - {"script.js", "javascript"}, - {"script.jsx", "javascript"}, - {"main.py", "python"}, - {"lib.rs", "rust"}, - {"App.java", "java"}, - {"config.yaml", "yaml"}, - {"config.yml", "yaml"}, - {"data.json", "json"}, - {"index.html", "html"}, - {"styles.css", "css"}, - {"styles.scss", "scss"}, - {"query.sql", "sql"}, - {"README.md", "markdown"}, - {"unknown.xyz", "text"}, - {"", "text"}, - } - - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - result := detectLanguage(tt.path) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestExtractKeywords(t *testing.T) { - tests := []struct { - name string - text string - expected int // minimum number of keywords expected - }{ - { - name: "simple title", - text: "Add user authentication feature", - expected: 2, - }, - { - name: "with stop words", - text: "The quick brown fox jumps over the lazy dog", - expected: 3, - }, - { - name: "technical text", - text: "Implement OAuth2 authentication with JWT tokens", - expected: 3, - }, - { - name: "empty", - text: "", - expected: 0, - }, - { - name: "only stop words", - text: "the a an and or but in on at", - expected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - keywords := extractKeywords(tt.text) - assert.GreaterOrEqual(t, len(keywords), tt.expected) - // Keywords should not exceed 5 - assert.LessOrEqual(t, len(keywords), 5) - }) - } -} - -func TestTaskContext_FormatContext(t *testing.T) { - task := &Task{ - ID: "test-456", - Title: "Test Formatting", - Description: "This is a test description", - Priority: PriorityHigh, - Status: StatusInProgress, - } - - ctx := &TaskContext{ - Task: task, - Files: []FileContent{{Path: "main.go", Content: "package main", Language: "go"}}, - GitStatus: " M main.go", - RecentCommits: "abc123 Initial commit", - RelatedCode: []FileContent{{Path: "util.go", Content: "package util", Language: "go"}}, - } - - formatted := ctx.FormatContext() - - assert.Contains(t, formatted, "# Task Context") - assert.Contains(t, formatted, "test-456") - assert.Contains(t, formatted, "Test Formatting") - assert.Contains(t, formatted, "## Task Files") - assert.Contains(t, formatted, "## Git Status") - assert.Contains(t, formatted, "## Recent Commits") - assert.Contains(t, formatted, "## Related Code") -} diff --git a/pkg/lifecycle/coverage_boost_test.go b/pkg/lifecycle/coverage_boost_test.go deleted file mode 100644 index 849333c..0000000 --- a/pkg/lifecycle/coverage_boost_test.go +++ /dev/null @@ -1,757 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - - "forge.lthn.ai/core/go/pkg/core" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// ============================================================================ -// service.go — NewService, OnStartup, handleTask, doCommit, doPrompt -// ============================================================================ - -func TestNewService_Good(t *testing.T) { - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - factory := NewService(opts) - - result, err := factory(c) - require.NoError(t, err) - require.NotNil(t, result) - - svc, ok := result.(*Service) - require.True(t, ok, "factory should return *Service") - assert.Equal(t, opts.DefaultTools, svc.Opts().DefaultTools) - assert.Equal(t, opts.AllowEdit, svc.Opts().AllowEdit) -} - -func TestNewService_Good_CustomOpts(t *testing.T) { - c, err := core.New() - require.NoError(t, err) - - opts := ServiceOptions{ - DefaultTools: []string{"Bash", "Read", "Write", "Edit"}, - AllowEdit: true, - } - factory := NewService(opts) - - result, err := factory(c) - require.NoError(t, err) - - svc := result.(*Service) - assert.True(t, svc.Opts().AllowEdit) - assert.Len(t, svc.Opts().DefaultTools, 4) -} - -func TestOnStartup_Good(t *testing.T) { - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - err = svc.OnStartup(context.Background()) - assert.NoError(t, err) -} - -// mockClaude creates a mock "claude" binary that exits with code 1 and -// prepends its directory to PATH, restoring PATH when the test finishes. -func mockClaude(t *testing.T) { - t.Helper() - mockBin := filepath.Join(t.TempDir(), "claude") - err := os.WriteFile(mockBin, []byte("#!/bin/sh\nexit 1\n"), 0755) - require.NoError(t, err) - - origPath := os.Getenv("PATH") - t.Setenv("PATH", filepath.Dir(mockBin)+":"+origPath) -} - -func TestHandleTask_Good_TaskCommit(t *testing.T) { - mockClaude(t) - - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - task := TaskCommit{ - Path: t.TempDir(), - Name: "test", - CanEdit: false, - } - - result, handled, err := svc.handleTask(c, task) - assert.Nil(t, result) - assert.True(t, handled, "TaskCommit should be handled") - assert.Error(t, err, "mock claude should exit 1") -} - -func TestHandleTask_Good_TaskCommitCanEdit(t *testing.T) { - mockClaude(t) - - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - task := TaskCommit{ - Path: t.TempDir(), - Name: "test-edit", - CanEdit: true, - } - - result, handled, err := svc.handleTask(c, task) - assert.Nil(t, result) - assert.True(t, handled) - assert.Error(t, err, "mock claude should exit 1") -} - -func TestHandleTask_Good_TaskPrompt(t *testing.T) { - mockClaude(t) - - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - task := TaskPrompt{ - Prompt: "test prompt", - WorkDir: t.TempDir(), - } - - result, handled, err := svc.handleTask(c, task) - assert.Nil(t, result) - assert.True(t, handled, "TaskPrompt should be handled") - assert.Error(t, err, "mock claude should exit 1") -} - -func TestHandleTask_Good_TaskPromptWithTaskID(t *testing.T) { - mockClaude(t) - - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - task := TaskPrompt{ - Prompt: "test prompt", - WorkDir: t.TempDir(), - taskID: "task-123", - } - - result, handled, err := svc.handleTask(c, task) - assert.Nil(t, result) - assert.True(t, handled) - assert.Error(t, err, "mock claude should exit 1") -} - -func TestHandleTask_Good_TaskPromptWithCustomTools(t *testing.T) { - mockClaude(t) - - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - task := TaskPrompt{ - Prompt: "test prompt", - WorkDir: t.TempDir(), - AllowedTools: []string{"Bash", "Read"}, - } - - result, handled, err := svc.handleTask(c, task) - assert.Nil(t, result) - assert.True(t, handled) - assert.Error(t, err, "mock claude should exit 1") -} - -func TestHandleTask_Good_TaskPromptEmptyDefaultTools(t *testing.T) { - mockClaude(t) - - c, err := core.New() - require.NoError(t, err) - - opts := ServiceOptions{ - DefaultTools: nil, // empty tools - } - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - task := TaskPrompt{ - Prompt: "test prompt", - WorkDir: t.TempDir(), - } - - result, handled, err := svc.handleTask(c, task) - assert.Nil(t, result) - assert.True(t, handled) - assert.Error(t, err, "mock claude should exit 1") -} - -func TestHandleTask_Good_UnknownTask(t *testing.T) { - c, err := core.New() - require.NoError(t, err) - - opts := DefaultServiceOptions() - svc := &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - } - - // A string is not a recognised task type. - result, handled, err := svc.handleTask(c, "unknown-task") - assert.Nil(t, result) - assert.False(t, handled, "unknown task should not be handled") - assert.NoError(t, err) -} - -// ============================================================================ -// completion.go — CreatePR full coverage -// ============================================================================ - -func TestCreatePR_Good_WithGhMock(t *testing.T) { - dir := initGitRepo(t) - - // Create a script that pretends to be "gh" - mockBin := filepath.Join(t.TempDir(), "gh") - mockScript := `#!/bin/sh -echo "https://github.com/owner/repo/pull/42" -` - err := os.WriteFile(mockBin, []byte(mockScript), 0755) - require.NoError(t, err) - - // Prepend mock bin directory to PATH - origPath := os.Getenv("PATH") - _ = os.Setenv("PATH", filepath.Dir(mockBin)+":"+origPath) - defer func() { _ = os.Setenv("PATH", origPath) }() - - task := &Task{ - ID: "PR-10", - Title: "Test PR", - Description: "Test PR description", - Priority: PriorityMedium, - } - - prURL, err := CreatePR(context.Background(), task, dir, PROptions{}) - require.NoError(t, err) - assert.Equal(t, "https://github.com/owner/repo/pull/42", prURL) -} - -func TestCreatePR_Good_WithAllOptions(t *testing.T) { - dir := initGitRepo(t) - - mockBin := filepath.Join(t.TempDir(), "gh") - mockScript := `#!/bin/sh -echo "https://github.com/owner/repo/pull/99" -` - err := os.WriteFile(mockBin, []byte(mockScript), 0755) - require.NoError(t, err) - - origPath := os.Getenv("PATH") - _ = os.Setenv("PATH", filepath.Dir(mockBin)+":"+origPath) - defer func() { _ = os.Setenv("PATH", origPath) }() - - task := &Task{ - ID: "PR-20", - Title: "Full options PR", - Description: "All options test", - Priority: PriorityHigh, - Labels: []string{"enhancement", "v2"}, - } - - opts := PROptions{ - Title: "Custom title", - Body: "Custom body", - Draft: true, - Labels: []string{"enhancement", "v2"}, - Base: "develop", - } - - prURL, err := CreatePR(context.Background(), task, dir, opts) - require.NoError(t, err) - assert.Equal(t, "https://github.com/owner/repo/pull/99", prURL) -} - -func TestCreatePR_Good_DefaultTitleAndBody(t *testing.T) { - dir := initGitRepo(t) - - mockBin := filepath.Join(t.TempDir(), "gh") - mockScript := `#!/bin/sh -echo "https://github.com/owner/repo/pull/55" -` - err := os.WriteFile(mockBin, []byte(mockScript), 0755) - require.NoError(t, err) - - origPath := os.Getenv("PATH") - _ = os.Setenv("PATH", filepath.Dir(mockBin)+":"+origPath) - defer func() { _ = os.Setenv("PATH", origPath) }() - - task := &Task{ - ID: "PR-30", - Title: "Default title from task", - Description: "Default body from task description", - Priority: PriorityCritical, - } - - // Empty PROptions — title and body should default from task. - prURL, err := CreatePR(context.Background(), task, dir, PROptions{}) - require.NoError(t, err) - assert.Contains(t, prURL, "pull/55") -} - -func TestCreatePR_Bad_GhFails(t *testing.T) { - dir := initGitRepo(t) - - mockBin := filepath.Join(t.TempDir(), "gh") - mockScript := `#!/bin/sh -echo "error: not logged in" >&2 -exit 1 -` - err := os.WriteFile(mockBin, []byte(mockScript), 0755) - require.NoError(t, err) - - origPath := os.Getenv("PATH") - _ = os.Setenv("PATH", filepath.Dir(mockBin)+":"+origPath) - defer func() { _ = os.Setenv("PATH", origPath) }() - - task := &Task{ - ID: "PR-40", - Title: "Failing PR", - } - - prURL, err := CreatePR(context.Background(), task, dir, PROptions{}) - assert.Error(t, err) - assert.Empty(t, prURL) - assert.Contains(t, err.Error(), "failed to create PR") -} - -func TestCreatePR_Bad_GhNotFound(t *testing.T) { - dir := initGitRepo(t) - - // Set PATH to an empty directory so "gh" is not found. - emptyDir := t.TempDir() - origPath := os.Getenv("PATH") - _ = os.Setenv("PATH", emptyDir) - defer func() { _ = os.Setenv("PATH", origPath) }() - - task := &Task{ - ID: "PR-50", - Title: "No gh binary", - } - - prURL, err := CreatePR(context.Background(), task, dir, PROptions{}) - assert.Error(t, err) - assert.Empty(t, prURL) -} - -// ============================================================================ -// completion.go — PushChanges success path -// ============================================================================ - -func TestPushChanges_Good_WithRemote(t *testing.T) { - // Create a bare remote repo and a working repo that pushes to it. - remoteDir := t.TempDir() - _, err := runCommandCtx(context.Background(), remoteDir, "git", "init", "--bare") - require.NoError(t, err) - - dir := initGitRepo(t) - _, err = runCommandCtx(context.Background(), dir, "git", "remote", "add", "origin", remoteDir) - require.NoError(t, err) - - // Push the initial commit to set up upstream. - branch, err := GetCurrentBranch(context.Background(), dir) - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "push", "-u", "origin", branch) - require.NoError(t, err) - - // Create a new commit to push. - err = os.WriteFile(filepath.Join(dir, "push-test.txt"), []byte("push me\n"), 0644) - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "add", "-A") - require.NoError(t, err) - _, err = runCommandCtx(context.Background(), dir, "git", "commit", "-m", "push test") - require.NoError(t, err) - - err = PushChanges(context.Background(), dir) - assert.NoError(t, err) -} - -// ============================================================================ -// completion.go — CreateBranch in non-git dir -// ============================================================================ - -func TestCreateBranch_Bad_NotAGitRepo(t *testing.T) { - dir := t.TempDir() - - task := &Task{ - ID: "BR-99", - Title: "Not a repo", - } - - branchName, err := CreateBranch(context.Background(), task, dir) - assert.Error(t, err) - assert.Empty(t, branchName) - assert.Contains(t, err.Error(), "failed to create branch") -} - -// ============================================================================ -// config.go — SaveConfig error paths, ConfigPath -// ============================================================================ - -func TestSaveConfig_Good_CreatesConfigDir(t *testing.T) { - tmpHome := t.TempDir() - - originalHome := os.Getenv("HOME") - _ = os.Setenv("HOME", tmpHome) - defer func() { _ = os.Setenv("HOME", originalHome) }() - - cfg := &Config{ - BaseURL: "https://test.example.com", - Token: "test-token-123", - } - - err := SaveConfig(cfg) - require.NoError(t, err) - - // Verify .core directory was created. - info, err := os.Stat(filepath.Join(tmpHome, ".core")) - require.NoError(t, err) - assert.True(t, info.IsDir()) -} - -func TestSaveConfig_Good_OverwritesExisting(t *testing.T) { - tmpHome := t.TempDir() - - originalHome := os.Getenv("HOME") - _ = os.Setenv("HOME", tmpHome) - defer func() { _ = os.Setenv("HOME", originalHome) }() - - // Write first config. - cfg1 := &Config{Token: "first-token"} - err := SaveConfig(cfg1) - require.NoError(t, err) - - // Overwrite with second config. - cfg2 := &Config{Token: "second-token"} - err = SaveConfig(cfg2) - require.NoError(t, err) - - // Verify second config is saved. - data, err := os.ReadFile(filepath.Join(tmpHome, ".core", "agentic.yaml")) - require.NoError(t, err) - assert.Contains(t, string(data), "second-token") - assert.NotContains(t, string(data), "first-token") -} - -func TestConfigPath_Good_ContainsExpectedComponents(t *testing.T) { - path, err := ConfigPath() - require.NoError(t, err) - - // Path should end with .core/agentic.yaml. - assert.True(t, filepath.IsAbs(path), "path should be absolute") - dir, file := filepath.Split(path) - assert.Equal(t, "agentic.yaml", file) - assert.Contains(t, dir, ".core") -} - -// ============================================================================ -// allowance_service.go — ResetAgent error path -// ============================================================================ - -// resetErrorStore extends errorStore with a ResetUsage failure mode. -type resetErrorStore struct { - *MemoryStore - failReset bool -} - -func (e *resetErrorStore) ResetUsage(agentID string) error { - if e.failReset { - return errors.New("simulated ResetUsage error") - } - return e.MemoryStore.ResetUsage(agentID) -} - -func TestResetAgent_Bad_StoreError(t *testing.T) { - store := &resetErrorStore{ - MemoryStore: NewMemoryStore(), - failReset: true, - } - svc := NewAllowanceService(store) - - err := svc.ResetAgent("agent-1") - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to reset usage") -} - -func TestResetAgent_Good_Success(t *testing.T) { - store := &resetErrorStore{ - MemoryStore: NewMemoryStore(), - failReset: false, - } - svc := NewAllowanceService(store) - - // Pre-populate some usage. - _ = store.IncrementUsage("agent-1", 5000, 3) - - err := svc.ResetAgent("agent-1") - require.NoError(t, err) - - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, int64(0), usage.TokensUsed) - assert.Equal(t, 0, usage.JobsStarted) -} - -// ============================================================================ -// client.go — error paths for ListTasks, GetTask, ClaimTask, UpdateTask, -// CompleteTask, Ping -// ============================================================================ - -func TestClient_ListTasks_Bad_ConnectionRefused(t *testing.T) { - client := NewClient("http://127.0.0.1:1", "test-token") - client.HTTPClient.Timeout = 100 * 1000000 // 100ms in nanoseconds - - tasks, err := client.ListTasks(context.Background(), ListOptions{}) - assert.Error(t, err) - assert.Nil(t, tasks) - assert.Contains(t, err.Error(), "request failed") -} - -func TestClient_ListTasks_Bad_InvalidJSON(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("not valid json")) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - tasks, err := client.ListTasks(context.Background(), ListOptions{}) - assert.Error(t, err) - assert.Nil(t, tasks) - assert.Contains(t, err.Error(), "failed to decode response") -} - -func TestClient_GetTask_Bad_ConnectionRefused(t *testing.T) { - client := NewClient("http://127.0.0.1:1", "test-token") - client.HTTPClient.Timeout = 100 * 1000000 - - task, err := client.GetTask(context.Background(), "task-1") - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "request failed") -} - -func TestClient_GetTask_Bad_InvalidJSON(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("{invalid")) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task, err := client.GetTask(context.Background(), "task-1") - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "failed to decode response") -} - -func TestClient_ClaimTask_Bad_ConnectionRefused(t *testing.T) { - client := NewClient("http://127.0.0.1:1", "test-token") - client.HTTPClient.Timeout = 100 * 1000000 - - task, err := client.ClaimTask(context.Background(), "task-1") - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "request failed") -} - -func TestClient_ClaimTask_Bad_InvalidJSON(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("completely broken json")) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - client.AgentID = "agent-1" - task, err := client.ClaimTask(context.Background(), "task-1") - assert.Error(t, err) - assert.Nil(t, task) - assert.Contains(t, err.Error(), "failed to decode response") -} - -func TestClient_UpdateTask_Bad_ConnectionRefused(t *testing.T) { - client := NewClient("http://127.0.0.1:1", "test-token") - client.HTTPClient.Timeout = 100 * 1000000 - - err := client.UpdateTask(context.Background(), "task-1", TaskUpdate{ - Status: StatusInProgress, - }) - assert.Error(t, err) - assert.Contains(t, err.Error(), "request failed") -} - -func TestClient_UpdateTask_Bad_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Message: "server error"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.UpdateTask(context.Background(), "task-1", TaskUpdate{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "server error") -} - -func TestClient_CompleteTask_Bad_ConnectionRefused(t *testing.T) { - client := NewClient("http://127.0.0.1:1", "test-token") - client.HTTPClient.Timeout = 100 * 1000000 - - err := client.CompleteTask(context.Background(), "task-1", TaskResult{ - Success: true, - }) - assert.Error(t, err) - assert.Contains(t, err.Error(), "request failed") -} - -func TestClient_CompleteTask_Bad_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(APIError{Message: "bad request"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.CompleteTask(context.Background(), "task-1", TaskResult{}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "bad request") -} - -func TestClient_Ping_Bad_ServerReturns5xx(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusServiceUnavailable) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - err := client.Ping(context.Background()) - assert.Error(t, err) - assert.Contains(t, err.Error(), "status 503") -} - -// ============================================================================ -// context.go — BuildTaskContext edge cases -// ============================================================================ - -func TestBuildTaskContext_Good_FilesGatherError(t *testing.T) { - // Task with files but in a non-existent directory. - task := &Task{ - ID: "ctx-err-1", - Title: "Files error test", - Files: []string{"nonexistent.go"}, - } - - dir := t.TempDir() - ctx, err := BuildTaskContext(task, dir) - require.NoError(t, err, "BuildTaskContext should not fail even if files are missing") - assert.NotNil(t, ctx) - assert.Empty(t, ctx.Files, "no files should be gathered") -} - -// ============================================================================ -// completion.go — generateBranchName edge cases -// ============================================================================ - -func TestGenerateBranchName_Good_TestsLabel(t *testing.T) { - task := &Task{ - ID: "GEN-1", - Title: "Add tests for core", - Labels: []string{"tests"}, - } - name := generateBranchName(task) - assert.Equal(t, "test/GEN-1-add-tests-for-core", name) -} - -func TestGenerateBranchName_Good_EmptyTitle(t *testing.T) { - task := &Task{ - ID: "GEN-2", - Title: "", - Labels: nil, - } - name := generateBranchName(task) - assert.Equal(t, "feat/GEN-2-", name) -} - -func TestGenerateBranchName_Good_BugfixLabel(t *testing.T) { - task := &Task{ - ID: "GEN-3", - Title: "Fix memory leak", - Labels: []string{"bugfix"}, - } - name := generateBranchName(task) - assert.Equal(t, "fix/GEN-3-fix-memory-leak", name) -} - -func TestGenerateBranchName_Good_DocsLabel(t *testing.T) { - task := &Task{ - ID: "GEN-4", - Title: "Update docs", - Labels: []string{"docs"}, - } - name := generateBranchName(task) - assert.Equal(t, "docs/GEN-4-update-docs", name) -} - -func TestGenerateBranchName_Good_FixLabel(t *testing.T) { - task := &Task{ - ID: "GEN-5", - Title: "Fix something", - Labels: []string{"fix"}, - } - name := generateBranchName(task) - assert.Equal(t, "fix/GEN-5-fix-something", name) -} - -// ============================================================================ -// AutoCommit additional edge cases -// ============================================================================ - -func TestAutoCommit_Bad_NotAGitRepo(t *testing.T) { - dir := t.TempDir() - - task := &Task{ID: "AC-1", Title: "Not a repo"} - err := AutoCommit(context.Background(), task, dir, "feat: test") - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to stage changes") -} diff --git a/pkg/lifecycle/dispatcher.go b/pkg/lifecycle/dispatcher.go deleted file mode 100644 index f220f49..0000000 --- a/pkg/lifecycle/dispatcher.go +++ /dev/null @@ -1,259 +0,0 @@ -package lifecycle - -import ( - "cmp" - "context" - "slices" - "time" - - "forge.lthn.ai/core/go-log" -) - -const ( - // DefaultMaxRetries is the default number of dispatch attempts before dead-lettering. - DefaultMaxRetries = 3 - // baseBackoff is the base duration for exponential backoff between retries. - baseBackoff = 5 * time.Second -) - -// Dispatcher orchestrates task dispatch by combining the agent registry, -// task router, allowance service, and API client. -type Dispatcher struct { - registry AgentRegistry - router TaskRouter - allowance *AllowanceService - client *Client // can be nil for tests - events EventEmitter -} - -// NewDispatcher creates a new Dispatcher with the given dependencies. -func NewDispatcher(registry AgentRegistry, router TaskRouter, allowance *AllowanceService, client *Client) *Dispatcher { - return &Dispatcher{ - registry: registry, - router: router, - allowance: allowance, - client: client, - } -} - -// SetEventEmitter attaches an event emitter to the dispatcher for lifecycle notifications. -func (d *Dispatcher) SetEventEmitter(em EventEmitter) { - d.events = em -} - -// emit is a convenience helper that publishes an event if an emitter is set. -func (d *Dispatcher) emit(ctx context.Context, event Event) { - if d.events != nil { - if event.Timestamp.IsZero() { - event.Timestamp = time.Now().UTC() - } - _ = d.events.Emit(ctx, event) - } -} - -// Dispatch assigns a task to the best available agent. It queries the registry -// for available agents, routes the task, checks the agent's allowance, claims -// the task via the API client (if present), and records usage. Returns the -// assigned agent ID. -func (d *Dispatcher) Dispatch(ctx context.Context, task *Task) (string, error) { - const op = "Dispatcher.Dispatch" - - // 1. Get available agents from registry. - agents := d.registry.List() - - // 2. Route task to best agent. - agentID, err := d.router.Route(task, agents) - if err != nil { - d.emit(ctx, Event{ - Type: EventDispatchFailedNoAgent, - TaskID: task.ID, - }) - return "", log.E(op, "routing failed", err) - } - - // 3. Check allowance for the selected agent. - check, err := d.allowance.Check(agentID, "") - if err != nil { - return "", log.E(op, "allowance check failed", err) - } - if !check.Allowed { - d.emit(ctx, Event{ - Type: EventDispatchFailedQuota, - TaskID: task.ID, - AgentID: agentID, - Payload: check.Reason, - }) - return "", log.E(op, "agent quota exceeded: "+check.Reason, nil) - } - - // 4. Claim the task via the API client (if available). - if d.client != nil { - if _, err := d.client.ClaimTask(ctx, task.ID); err != nil { - return "", log.E(op, "failed to claim task", err) - } - d.emit(ctx, Event{ - Type: EventTaskClaimed, - TaskID: task.ID, - AgentID: agentID, - }) - } - - // 5. Record job start usage. - if err := d.allowance.RecordUsage(UsageReport{ - AgentID: agentID, - JobID: task.ID, - Event: QuotaEventJobStarted, - Timestamp: time.Now().UTC(), - }); err != nil { - return "", log.E(op, "failed to record usage", err) - } - - d.emit(ctx, Event{ - Type: EventTaskDispatched, - TaskID: task.ID, - AgentID: agentID, - }) - - return agentID, nil -} - -// priorityRank maps a TaskPriority to a numeric rank for sorting. -// Lower values are dispatched first. -func priorityRank(p TaskPriority) int { - switch p { - case PriorityCritical: - return 0 - case PriorityHigh: - return 1 - case PriorityMedium: - return 2 - case PriorityLow: - return 3 - default: - return 4 - } -} - -// sortTasksByPriority sorts tasks by priority (Critical first) then by -// CreatedAt (oldest first) as a tie-breaker. Uses slices.SortStableFunc for determinism. -func sortTasksByPriority(tasks []Task) { - slices.SortStableFunc(tasks, func(a, b Task) int { - ri, rj := priorityRank(a.Priority), priorityRank(b.Priority) - if ri != rj { - return cmp.Compare(ri, rj) - } - if a.CreatedAt.Before(b.CreatedAt) { - return -1 - } - if a.CreatedAt.After(b.CreatedAt) { - return 1 - } - return 0 - }) -} - -// backoffDuration returns the exponential backoff duration for the given retry -// count. First retry waits baseBackoff (5s), second waits 10s, third 20s, etc. -func backoffDuration(retryCount int) time.Duration { - if retryCount <= 0 { - return 0 - } - d := baseBackoff - for range retryCount - 1 { - d *= 2 - } - return d -} - -// shouldSkipRetry returns true if a task has been retried and the backoff -// period has not yet elapsed since the last attempt. -func shouldSkipRetry(task *Task, now time.Time) bool { - if task.RetryCount <= 0 { - return false - } - if task.LastAttempt == nil { - return false - } - return task.LastAttempt.Add(backoffDuration(task.RetryCount)).After(now) -} - -// effectiveMaxRetries returns the max retries for a task, using DefaultMaxRetries -// when the task does not specify one. -func effectiveMaxRetries(task *Task) int { - if task.MaxRetries > 0 { - return task.MaxRetries - } - return DefaultMaxRetries -} - -// DispatchLoop polls for pending tasks at the given interval and dispatches -// each one. Tasks are sorted by priority (Critical > High > Medium > Low) with -// oldest-first tie-breaking. Failed dispatches are retried with exponential -// backoff. Tasks exceeding their retry limit are dead-lettered with StatusFailed. -// It runs until the context is cancelled and returns ctx.Err(). -func (d *Dispatcher) DispatchLoop(ctx context.Context, interval time.Duration) error { - const op = "Dispatcher.DispatchLoop" - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - if d.client == nil { - continue - } - - tasks, err := d.client.ListTasks(ctx, ListOptions{Status: StatusPending}) - if err != nil { - // Log but continue — transient API errors should not stop the loop. - _ = log.E(op, "failed to list pending tasks", err) - continue - } - - // Sort by priority then by creation time. - sortTasksByPriority(tasks) - - now := time.Now().UTC() - for i := range tasks { - if ctx.Err() != nil { - return ctx.Err() - } - - task := &tasks[i] - - // Check if backoff period has not elapsed for retried tasks. - if shouldSkipRetry(task, now) { - continue - } - - if _, err := d.Dispatch(ctx, task); err != nil { - // Increment retry count and record the attempt time. - task.RetryCount++ - attemptTime := now - task.LastAttempt = &attemptTime - - maxRetries := effectiveMaxRetries(task) - if task.RetryCount >= maxRetries { - // Dead-letter: mark as failed via the API. - if updateErr := d.client.UpdateTask(ctx, task.ID, TaskUpdate{ - Status: StatusFailed, - Notes: "max retries exceeded", - }); updateErr != nil { - _ = log.E(op, "failed to dead-letter task "+task.ID, updateErr) - } - d.emit(ctx, Event{ - Type: EventTaskDeadLettered, - TaskID: task.ID, - Payload: "max retries exceeded", - }) - } else { - _ = log.E(op, "failed to dispatch task "+task.ID, err) - } - } - } - } - } -} diff --git a/pkg/lifecycle/dispatcher_test.go b/pkg/lifecycle/dispatcher_test.go deleted file mode 100644 index fe45686..0000000 --- a/pkg/lifecycle/dispatcher_test.go +++ /dev/null @@ -1,578 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// setupDispatcher creates a Dispatcher with a memory registry, default router, -// and memory allowance store, pre-loaded with agents and allowances. -func setupDispatcher(t *testing.T, client *Client) (*Dispatcher, *MemoryRegistry, *MemoryStore) { - t.Helper() - - reg := NewMemoryRegistry() - router := NewDefaultRouter() - store := NewMemoryStore() - svc := NewAllowanceService(store) - - d := NewDispatcher(reg, router, svc, client) - return d, reg, store -} - -func registerAgent(t *testing.T, reg *MemoryRegistry, store *MemoryStore, id string, caps []string, maxLoad int) { - t.Helper() - _ = reg.Register(AgentInfo{ - ID: id, - Name: id, - Capabilities: caps, - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - MaxLoad: maxLoad, - }) - _ = store.SetAllowance(&AgentAllowance{ - AgentID: id, - DailyTokenLimit: 100000, - DailyJobLimit: 50, - ConcurrentJobs: 5, - }) -} - -// --- Dispatch tests --- - -func TestDispatcher_Dispatch_Good_NilClient(t *testing.T) { - d, reg, store := setupDispatcher(t, nil) - registerAgent(t, reg, store, "agent-1", []string{"go"}, 5) - - task := &Task{ID: "task-1", Labels: []string{"go"}, Priority: PriorityMedium} - agentID, err := d.Dispatch(context.Background(), task) - require.NoError(t, err) - assert.Equal(t, "agent-1", agentID) - - // Verify usage was recorded. - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, 1, usage.JobsStarted) - assert.Equal(t, 1, usage.ActiveJobs) -} - -func TestDispatcher_Dispatch_Good_WithHTTPClient(t *testing.T) { - claimedTask := Task{ID: "task-1", Status: StatusInProgress, ClaimedBy: "agent-1"} - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost && r.URL.Path == "/api/tasks/task-1/claim" { - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimedTask}) - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - d, reg, store := setupDispatcher(t, client) - registerAgent(t, reg, store, "agent-1", nil, 5) - - task := &Task{ID: "task-1", Priority: PriorityHigh} - agentID, err := d.Dispatch(context.Background(), task) - require.NoError(t, err) - assert.Equal(t, "agent-1", agentID) - - // Verify usage recorded. - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, 1, usage.JobsStarted) -} - -func TestDispatcher_Dispatch_Good_PicksBestAgent(t *testing.T) { - d, reg, store := setupDispatcher(t, nil) - registerAgent(t, reg, store, "heavy", []string{"go"}, 5) - registerAgent(t, reg, store, "light", []string{"go"}, 5) - - // Give "heavy" some load. - _ = reg.Register(AgentInfo{ - ID: "heavy", - Name: "heavy", - Capabilities: []string{"go"}, - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - CurrentLoad: 4, - MaxLoad: 5, - }) - - task := &Task{ID: "task-1", Labels: []string{"go"}, Priority: PriorityMedium} - agentID, err := d.Dispatch(context.Background(), task) - require.NoError(t, err) - assert.Equal(t, "light", agentID) // light has score 1.0, heavy has 0.2 -} - -func TestDispatcher_Dispatch_Bad_NoAgents(t *testing.T) { - d, _, _ := setupDispatcher(t, nil) - - task := &Task{ID: "task-1", Priority: PriorityMedium} - _, err := d.Dispatch(context.Background(), task) - require.Error(t, err) -} - -func TestDispatcher_Dispatch_Bad_AllowanceExceeded(t *testing.T) { - d, reg, store := setupDispatcher(t, nil) - registerAgent(t, reg, store, "agent-1", nil, 5) - - // Exhaust the agent's daily job limit. - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyJobLimit: 1, - }) - _ = store.IncrementUsage("agent-1", 0, 1) - - task := &Task{ID: "task-1", Priority: PriorityMedium} - _, err := d.Dispatch(context.Background(), task) - require.Error(t, err) - assert.Contains(t, err.Error(), "quota exceeded") -} - -func TestDispatcher_Dispatch_Bad_ClaimFails(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusConflict) - _ = json.NewEncoder(w).Encode(APIError{Code: 409, Message: "already claimed"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - d, reg, store := setupDispatcher(t, client) - registerAgent(t, reg, store, "agent-1", nil, 5) - - task := &Task{ID: "task-1", Priority: PriorityMedium} - _, err := d.Dispatch(context.Background(), task) - require.Error(t, err) - assert.Contains(t, err.Error(), "claim task") - - // Verify usage was NOT recorded when claim fails. - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, 0, usage.JobsStarted) -} - -// --- DispatchLoop tests --- - -func TestDispatcher_DispatchLoop_Good_Cancellation(t *testing.T) { - d, _, _ := setupDispatcher(t, nil) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately. - - err := d.DispatchLoop(ctx, 100*time.Millisecond) - require.ErrorIs(t, err, context.Canceled) -} - -func TestDispatcher_DispatchLoop_Good_DispatchesPendingTasks(t *testing.T) { - pendingTasks := []Task{ - {ID: "task-1", Status: StatusPending, Priority: PriorityMedium}, - {ID: "task-2", Status: StatusPending, Priority: PriorityHigh}, - } - - var mu sync.Mutex - claimedIDs := make(map[string]bool) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/tasks": - w.Header().Set("Content-Type", "application/json") - mu.Lock() - // Return only tasks not yet claimed. - var remaining []Task - for _, t := range pendingTasks { - if !claimedIDs[t.ID] { - remaining = append(remaining, t) - } - } - mu.Unlock() - _ = json.NewEncoder(w).Encode(remaining) - - case r.Method == http.MethodPost: - // Extract task ID from claim URL. - w.Header().Set("Content-Type", "application/json") - // Parse the task ID from the path. - for _, t := range pendingTasks { - if r.URL.Path == "/api/tasks/"+t.ID+"/claim" { - mu.Lock() - claimedIDs[t.ID] = true - mu.Unlock() - claimed := t - claimed.Status = StatusInProgress - _ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimed}) - return - } - } - w.WriteHeader(http.StatusNotFound) - - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - d, reg, store := setupDispatcher(t, client) - registerAgent(t, reg, store, "agent-1", nil, 10) - - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) - defer cancel() - - err := d.DispatchLoop(ctx, 50*time.Millisecond) - require.ErrorIs(t, err, context.DeadlineExceeded) - - // Verify tasks were claimed. - mu.Lock() - defer mu.Unlock() - assert.True(t, claimedIDs["task-1"]) - assert.True(t, claimedIDs["task-2"]) -} - -func TestDispatcher_DispatchLoop_Good_NilClientSkipsTick(t *testing.T) { - d, _, _ := setupDispatcher(t, nil) - - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() - - err := d.DispatchLoop(ctx, 50*time.Millisecond) - require.ErrorIs(t, err, context.DeadlineExceeded) - // No panics — nil client is handled gracefully. -} - -// --- Concurrent dispatch --- - -func TestDispatcher_Dispatch_Good_Concurrent(t *testing.T) { - d, reg, store := setupDispatcher(t, nil) - registerAgent(t, reg, store, "agent-1", nil, 0) - // Override allowance to truly unlimited (registerAgent hardcodes ConcurrentJobs: 5) - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyJobLimit: 100, - ConcurrentJobs: 0, // 0 = unlimited - }) - - var wg sync.WaitGroup - for i := range 10 { - wg.Add(1) - go func(n int) { - defer wg.Done() - task := &Task{ID: "task-" + string(rune('a'+n)), Priority: PriorityMedium} - _, _ = d.Dispatch(context.Background(), task) - }(i) - } - wg.Wait() - - // Verify usage was recorded for all dispatches. - usage, _ := store.GetUsage("agent-1") - assert.Equal(t, 10, usage.JobsStarted) -} - -// --- Phase 7: Priority sorting tests --- - -func TestSortTasksByPriority_Good(t *testing.T) { - base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) - tasks := []Task{ - {ID: "low-old", Priority: PriorityLow, CreatedAt: base}, - {ID: "critical-new", Priority: PriorityCritical, CreatedAt: base.Add(2 * time.Hour)}, - {ID: "medium-old", Priority: PriorityMedium, CreatedAt: base}, - {ID: "high-old", Priority: PriorityHigh, CreatedAt: base}, - {ID: "critical-old", Priority: PriorityCritical, CreatedAt: base}, - } - - sortTasksByPriority(tasks) - - // Critical tasks first, oldest critical before newer critical. - assert.Equal(t, "critical-old", tasks[0].ID) - assert.Equal(t, "critical-new", tasks[1].ID) - // Then high. - assert.Equal(t, "high-old", tasks[2].ID) - // Then medium. - assert.Equal(t, "medium-old", tasks[3].ID) - // Then low. - assert.Equal(t, "low-old", tasks[4].ID) -} - -// --- Phase 7: Backoff duration tests --- - -func TestBackoffDuration_Good(t *testing.T) { - // retryCount=0 → 0 (no backoff). - assert.Equal(t, time.Duration(0), backoffDuration(0)) - // retryCount=1 → 5s (base). - assert.Equal(t, 5*time.Second, backoffDuration(1)) - // retryCount=2 → 10s. - assert.Equal(t, 10*time.Second, backoffDuration(2)) - // retryCount=3 → 20s. - assert.Equal(t, 20*time.Second, backoffDuration(3)) - // retryCount=4 → 40s. - assert.Equal(t, 40*time.Second, backoffDuration(4)) -} - -// --- Phase 7: shouldSkipRetry tests --- - -func TestShouldSkipRetry_Good(t *testing.T) { - now := time.Now().UTC() - recent := now.Add(-2 * time.Second) // 2s ago, backoff for retry 1 is 5s → skip. - - task := &Task{ - ID: "task-1", - RetryCount: 1, - LastAttempt: &recent, - } - assert.True(t, shouldSkipRetry(task, now)) - - // After backoff elapses, should NOT skip. - old := now.Add(-10 * time.Second) // 10s ago, backoff for retry 1 is 5s → ready. - task.LastAttempt = &old - assert.False(t, shouldSkipRetry(task, now)) -} - -func TestShouldSkipRetry_Bad_NoRetry(t *testing.T) { - now := time.Now().UTC() - - // RetryCount=0 → never skip. - task := &Task{ID: "task-1", RetryCount: 0} - assert.False(t, shouldSkipRetry(task, now)) - - // RetryCount=0 even with a LastAttempt set → never skip. - recent := now.Add(-1 * time.Second) - task.LastAttempt = &recent - assert.False(t, shouldSkipRetry(task, now)) - - // RetryCount>0 but nil LastAttempt → never skip. - task2 := &Task{ID: "task-2", RetryCount: 2, LastAttempt: nil} - assert.False(t, shouldSkipRetry(task2, now)) -} - -// --- Phase 7: DispatchLoop priority order test --- - -func TestDispatcher_DispatchLoop_Good_PriorityOrder(t *testing.T) { - base := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) - pendingTasks := []Task{ - {ID: "low-1", Status: StatusPending, Priority: PriorityLow, CreatedAt: base}, - {ID: "critical-1", Status: StatusPending, Priority: PriorityCritical, CreatedAt: base}, - {ID: "medium-1", Status: StatusPending, Priority: PriorityMedium, CreatedAt: base}, - {ID: "high-1", Status: StatusPending, Priority: PriorityHigh, CreatedAt: base}, - {ID: "critical-2", Status: StatusPending, Priority: PriorityCritical, CreatedAt: base.Add(time.Second)}, - } - - var mu sync.Mutex - var claimOrder []string - claimedIDs := make(map[string]bool) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/tasks": - w.Header().Set("Content-Type", "application/json") - mu.Lock() - var remaining []Task - for _, tk := range pendingTasks { - if !claimedIDs[tk.ID] { - remaining = append(remaining, tk) - } - } - mu.Unlock() - _ = json.NewEncoder(w).Encode(remaining) - - case r.Method == http.MethodPost: - w.Header().Set("Content-Type", "application/json") - for _, tk := range pendingTasks { - if r.URL.Path == "/api/tasks/"+tk.ID+"/claim" { - mu.Lock() - claimedIDs[tk.ID] = true - claimOrder = append(claimOrder, tk.ID) - mu.Unlock() - claimed := tk - claimed.Status = StatusInProgress - _ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimed}) - return - } - } - w.WriteHeader(http.StatusNotFound) - - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - d, reg, store := setupDispatcher(t, client) - registerAgent(t, reg, store, "agent-1", nil, 10) - - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) - defer cancel() - - err := d.DispatchLoop(ctx, 50*time.Millisecond) - require.ErrorIs(t, err, context.DeadlineExceeded) - - mu.Lock() - defer mu.Unlock() - - // All 5 tasks should have been claimed. - require.Len(t, claimOrder, 5) - // Critical tasks first (oldest before newest), then high, medium, low. - assert.Equal(t, "critical-1", claimOrder[0]) - assert.Equal(t, "critical-2", claimOrder[1]) - assert.Equal(t, "high-1", claimOrder[2]) - assert.Equal(t, "medium-1", claimOrder[3]) - assert.Equal(t, "low-1", claimOrder[4]) -} - -// --- Phase 7: DispatchLoop retry backoff test --- - -func TestDispatcher_DispatchLoop_Good_RetryBackoff(t *testing.T) { - // A task with a recent LastAttempt and RetryCount=1 should be skipped - // because the backoff period (5s) has not elapsed. - recentAttempt := time.Now().UTC() - pendingTasks := []Task{ - { - ID: "retrying-task", - Status: StatusPending, - Priority: PriorityHigh, - CreatedAt: time.Now().UTC().Add(-time.Hour), - RetryCount: 1, - LastAttempt: &recentAttempt, - }, - { - ID: "fresh-task", - Status: StatusPending, - Priority: PriorityLow, - CreatedAt: time.Now().UTC(), - }, - } - - var mu sync.Mutex - claimedIDs := make(map[string]bool) - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/tasks": - w.Header().Set("Content-Type", "application/json") - mu.Lock() - var remaining []Task - for _, tk := range pendingTasks { - if !claimedIDs[tk.ID] { - remaining = append(remaining, tk) - } - } - mu.Unlock() - _ = json.NewEncoder(w).Encode(remaining) - - case r.Method == http.MethodPost: - w.Header().Set("Content-Type", "application/json") - for _, tk := range pendingTasks { - if r.URL.Path == "/api/tasks/"+tk.ID+"/claim" { - mu.Lock() - claimedIDs[tk.ID] = true - mu.Unlock() - claimed := tk - claimed.Status = StatusInProgress - _ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimed}) - return - } - } - w.WriteHeader(http.StatusNotFound) - - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - d, reg, store := setupDispatcher(t, client) - registerAgent(t, reg, store, "agent-1", nil, 10) - - // Run the loop for a short period — not long enough for the 5s backoff to elapse. - ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond) - defer cancel() - - err := d.DispatchLoop(ctx, 50*time.Millisecond) - require.ErrorIs(t, err, context.DeadlineExceeded) - - mu.Lock() - defer mu.Unlock() - - // The fresh task should have been claimed. - assert.True(t, claimedIDs["fresh-task"]) - // The retrying task should NOT have been claimed because backoff has not elapsed. - assert.False(t, claimedIDs["retrying-task"]) -} - -// --- Phase 7: DispatchLoop dead-letter test --- - -func TestDispatcher_DispatchLoop_Good_DeadLetter(t *testing.T) { - // A task with RetryCount at MaxRetries-1 that fails dispatch should be dead-lettered. - pendingTasks := []Task{ - { - ID: "doomed-task", - Status: StatusPending, - Priority: PriorityHigh, - CreatedAt: time.Now().UTC().Add(-time.Hour), - MaxRetries: 1, // Will fail after 1 attempt. - RetryCount: 0, - }, - } - - var mu sync.Mutex - var deadLettered bool - var deadLetterNotes string - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.Method == http.MethodGet && r.URL.Path == "/api/tasks": - w.Header().Set("Content-Type", "application/json") - mu.Lock() - done := deadLettered - mu.Unlock() - if done { - // Return empty list once dead-lettered. - _ = json.NewEncoder(w).Encode([]Task{}) - } else { - _ = json.NewEncoder(w).Encode(pendingTasks) - } - - case r.Method == http.MethodPost && r.URL.Path == "/api/tasks/doomed-task/claim": - // Claim always fails to trigger retry logic. - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Code: 500, Message: "server error"}) - - case r.Method == http.MethodPatch && r.URL.Path == "/api/tasks/doomed-task": - // This is the UpdateTask call for dead-lettering. - var update TaskUpdate - _ = json.NewDecoder(r.Body).Decode(&update) - mu.Lock() - deadLettered = true - deadLetterNotes = update.Notes - mu.Unlock() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) - - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - d, reg, store := setupDispatcher(t, client) - registerAgent(t, reg, store, "agent-1", nil, 10) - - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) - defer cancel() - - err := d.DispatchLoop(ctx, 50*time.Millisecond) - require.ErrorIs(t, err, context.DeadlineExceeded) - - mu.Lock() - defer mu.Unlock() - - assert.True(t, deadLettered, "task should have been dead-lettered") - assert.Equal(t, "max retries exceeded", deadLetterNotes) -} diff --git a/pkg/lifecycle/embed.go b/pkg/lifecycle/embed.go deleted file mode 100644 index 8c624f8..0000000 --- a/pkg/lifecycle/embed.go +++ /dev/null @@ -1,19 +0,0 @@ -package lifecycle - -import ( - "embed" - "strings" -) - -//go:embed prompts/*.md -var promptsFS embed.FS - -// Prompt returns the content of an embedded prompt file. -// Name should be without the .md extension (e.g., "commit"). -func Prompt(name string) string { - data, err := promptsFS.ReadFile("prompts/" + name + ".md") - if err != nil { - return "" - } - return strings.TrimSpace(string(data)) -} diff --git a/pkg/lifecycle/embed_test.go b/pkg/lifecycle/embed_test.go deleted file mode 100644 index 715261e..0000000 --- a/pkg/lifecycle/embed_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package lifecycle - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestPrompt_Good_CommitExists(t *testing.T) { - content := Prompt("commit") - assert.NotEmpty(t, content, "commit prompt should exist") - assert.Contains(t, content, "Commit") -} - -func TestPrompt_Bad_NonexistentReturnsEmpty(t *testing.T) { - content := Prompt("nonexistent-prompt-that-does-not-exist") - assert.Empty(t, content, "nonexistent prompt should return empty string") -} - -func TestPrompt_Good_ContentIsTrimmed(t *testing.T) { - content := Prompt("commit") - // Should not start or end with whitespace. - assert.Equal(t, content[0:1] != " " && content[0:1] != "\n", true, "should not start with whitespace") - lastChar := content[len(content)-1:] - assert.Equal(t, lastChar != " " && lastChar != "\n", true, "should not end with whitespace") -} diff --git a/pkg/lifecycle/events.go b/pkg/lifecycle/events.go deleted file mode 100644 index 60a6b59..0000000 --- a/pkg/lifecycle/events.go +++ /dev/null @@ -1,114 +0,0 @@ -package lifecycle - -import ( - "context" - "sync" - "time" -) - -// EventType identifies the kind of lifecycle event. -type EventType string - -const ( - // EventTaskDispatched is emitted when a task is successfully routed and claimed. - EventTaskDispatched EventType = "task_dispatched" - // EventTaskClaimed is emitted when a task claim succeeds via the API client. - EventTaskClaimed EventType = "task_claimed" - // EventDispatchFailedNoAgent is emitted when no eligible agent is available. - EventDispatchFailedNoAgent EventType = "dispatch_failed_no_agent" - // EventDispatchFailedQuota is emitted when an agent's quota is exceeded. - EventDispatchFailedQuota EventType = "dispatch_failed_quota" - // EventTaskDeadLettered is emitted when a task exceeds its retry limit. - EventTaskDeadLettered EventType = "task_dead_lettered" - // EventQuotaWarning is emitted when an agent reaches 80%+ quota usage. - EventQuotaWarning EventType = "quota_warning" - // EventQuotaExceeded is emitted when an agent exceeds their quota. - EventQuotaExceeded EventType = "quota_exceeded" - // EventUsageRecorded is emitted when usage is recorded for an agent. - EventUsageRecorded EventType = "usage_recorded" -) - -// Event represents a lifecycle event in the agentic system. -type Event struct { - // Type identifies what happened. - Type EventType `json:"type"` - // TaskID is the task involved, if any. - TaskID string `json:"task_id,omitempty"` - // AgentID is the agent involved, if any. - AgentID string `json:"agent_id,omitempty"` - // Timestamp is when the event occurred. - Timestamp time.Time `json:"timestamp"` - // Payload carries additional event-specific data. - Payload any `json:"payload,omitempty"` -} - -// EventEmitter is the interface for publishing lifecycle events. -type EventEmitter interface { - // Emit publishes an event. Implementations should be non-blocking. - Emit(ctx context.Context, event Event) error -} - -// ChannelEmitter is an in-process EventEmitter backed by a buffered channel. -// Events are dropped (not blocked) when the buffer is full. -type ChannelEmitter struct { - ch chan Event -} - -// NewChannelEmitter creates a ChannelEmitter with the given buffer size. -func NewChannelEmitter(bufSize int) *ChannelEmitter { - if bufSize < 1 { - bufSize = 64 - } - return &ChannelEmitter{ch: make(chan Event, bufSize)} -} - -// Emit sends an event to the channel. If the buffer is full, the event is -// dropped silently to avoid blocking the dispatch path. -func (e *ChannelEmitter) Emit(_ context.Context, event Event) error { - select { - case e.ch <- event: - default: - // Buffer full — drop the event rather than blocking. - } - return nil -} - -// Events returns the underlying channel for consumers to read from. -func (e *ChannelEmitter) Events() <-chan Event { - return e.ch -} - -// Close closes the underlying channel, signalling consumers to stop reading. -func (e *ChannelEmitter) Close() { - close(e.ch) -} - -// MultiEmitter fans out events to multiple emitters. Emission continues even -// if one emitter fails — errors are collected but not returned. -type MultiEmitter struct { - mu sync.RWMutex - emitters []EventEmitter -} - -// NewMultiEmitter creates a MultiEmitter that fans out to the given emitters. -func NewMultiEmitter(emitters ...EventEmitter) *MultiEmitter { - return &MultiEmitter{emitters: emitters} -} - -// Emit sends the event to all registered emitters. Non-blocking: each emitter -// is called in sequence but ChannelEmitter.Emit is itself non-blocking. -func (m *MultiEmitter) Emit(ctx context.Context, event Event) error { - m.mu.RLock() - defer m.mu.RUnlock() - for _, em := range m.emitters { - _ = em.Emit(ctx, event) - } - return nil -} - -// Add appends an emitter to the fan-out list. -func (m *MultiEmitter) Add(emitter EventEmitter) { - m.mu.Lock() - defer m.mu.Unlock() - m.emitters = append(m.emitters, emitter) -} diff --git a/pkg/lifecycle/events_integration_test.go b/pkg/lifecycle/events_integration_test.go deleted file mode 100644 index 0ee714e..0000000 --- a/pkg/lifecycle/events_integration_test.go +++ /dev/null @@ -1,283 +0,0 @@ -package lifecycle - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- Dispatcher event emission tests --- - -func TestDispatcher_EmitsTaskDispatched(t *testing.T) { - em := NewChannelEmitter(10) - d, reg, store := setupDispatcher(t, nil) - d.SetEventEmitter(em) - registerAgent(t, reg, store, "agent-1", []string{"go"}, 5) - - task := &Task{ID: "t1", Labels: []string{"go"}} - agentID, err := d.Dispatch(context.Background(), task) - require.NoError(t, err) - assert.Equal(t, "agent-1", agentID) - - // Should have received EventTaskDispatched. - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventTaskDispatched, got[0].Type) - assert.Equal(t, "t1", got[0].TaskID) - assert.Equal(t, "agent-1", got[0].AgentID) -} - -func TestDispatcher_EmitsDispatchFailedNoAgent(t *testing.T) { - em := NewChannelEmitter(10) - d, _, _ := setupDispatcher(t, nil) - d.SetEventEmitter(em) - // No agents registered. - - task := &Task{ID: "t2", Labels: []string{"go"}} - _, err := d.Dispatch(context.Background(), task) - require.Error(t, err) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventDispatchFailedNoAgent, got[0].Type) - assert.Equal(t, "t2", got[0].TaskID) -} - -func TestDispatcher_EmitsDispatchFailedQuota(t *testing.T) { - em := NewChannelEmitter(10) - d, reg, store := setupDispatcher(t, nil) - d.SetEventEmitter(em) - - // Register agent with zero daily job limit (will be exceeded immediately). - _ = reg.Register(AgentInfo{ - ID: "agent-q", Name: "agent-q", Capabilities: []string{"go"}, - Status: AgentAvailable, LastHeartbeat: time.Now().UTC(), MaxLoad: 5, - }) - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-q", - DailyJobLimit: 1, - ConcurrentJobs: 5, - }) - // Use up the single job. - _ = store.IncrementUsage("agent-q", 0, 1) - - task := &Task{ID: "t3", Labels: []string{"go"}} - _, err := d.Dispatch(context.Background(), task) - require.Error(t, err) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventDispatchFailedQuota, got[0].Type) - assert.Equal(t, "t3", got[0].TaskID) - assert.Equal(t, "agent-q", got[0].AgentID) -} - -func TestDispatcher_NoEventsWithoutEmitter(t *testing.T) { - // Verify no panic when emitter is nil. - d, reg, store := setupDispatcher(t, nil) - registerAgent(t, reg, store, "agent-1", []string{"go"}, 5) - - task := &Task{ID: "t4", Labels: []string{"go"}} - _, err := d.Dispatch(context.Background(), task) - require.NoError(t, err) - // No panic = pass. -} - -// --- AllowanceService event emission tests --- - -func TestAllowanceService_EmitsQuotaExceeded(t *testing.T) { - em := NewChannelEmitter(10) - store := NewMemoryStore() - svc := NewAllowanceService(store) - svc.SetEventEmitter(em) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100, - }) - // Use all tokens. - _ = store.IncrementUsage("agent-1", 100, 0) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventQuotaExceeded, got[0].Type) - assert.Equal(t, "agent-1", got[0].AgentID) -} - -func TestAllowanceService_EmitsQuotaWarning(t *testing.T) { - em := NewChannelEmitter(10) - store := NewMemoryStore() - svc := NewAllowanceService(store) - svc.SetEventEmitter(em) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100, - }) - // Use 85% of tokens — should trigger warning. - _ = store.IncrementUsage("agent-1", 85, 0) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.True(t, result.Allowed) - assert.Equal(t, AllowanceWarning, result.Status) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventQuotaWarning, got[0].Type) - assert.Equal(t, "agent-1", got[0].AgentID) -} - -func TestAllowanceService_EmitsUsageRecorded(t *testing.T) { - em := NewChannelEmitter(10) - store := NewMemoryStore() - svc := NewAllowanceService(store) - svc.SetEventEmitter(em) - - _ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1"}) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Event: QuotaEventJobStarted, - Timestamp: time.Now().UTC(), - }) - require.NoError(t, err) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventUsageRecorded, got[0].Type) - assert.Equal(t, "agent-1", got[0].AgentID) -} - -func TestAllowanceService_EmitsUsageRecordedOnCompletion(t *testing.T) { - em := NewChannelEmitter(10) - store := NewMemoryStore() - svc := NewAllowanceService(store) - svc.SetEventEmitter(em) - - _ = store.SetAllowance(&AgentAllowance{AgentID: "agent-1"}) - // Start a job first. - _ = store.IncrementUsage("agent-1", 0, 1) - - err := svc.RecordUsage(UsageReport{ - AgentID: "agent-1", - JobID: "job-1", - Model: "claude-sonnet", - TokensIn: 500, - TokensOut: 200, - Event: QuotaEventJobCompleted, - Timestamp: time.Now().UTC(), - }) - require.NoError(t, err) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventUsageRecorded, got[0].Type) -} - -func TestAllowanceService_QuotaExceededOnJobLimit(t *testing.T) { - em := NewChannelEmitter(10) - store := NewMemoryStore() - svc := NewAllowanceService(store) - svc.SetEventEmitter(em) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyJobLimit: 2, - }) - _ = store.IncrementUsage("agent-1", 0, 2) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventQuotaExceeded, got[0].Type) - assert.Contains(t, got[0].Payload, "daily job limit") -} - -func TestAllowanceService_QuotaExceededOnConcurrent(t *testing.T) { - em := NewChannelEmitter(10) - store := NewMemoryStore() - svc := NewAllowanceService(store) - svc.SetEventEmitter(em) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - ConcurrentJobs: 1, - }) - _ = store.IncrementUsage("agent-1", 0, 1) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventQuotaExceeded, got[0].Type) - assert.Contains(t, got[0].Payload, "concurrent") -} - -func TestAllowanceService_QuotaExceededOnModelAllowlist(t *testing.T) { - em := NewChannelEmitter(10) - store := NewMemoryStore() - svc := NewAllowanceService(store) - svc.SetEventEmitter(em) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - ModelAllowlist: []string{"claude-sonnet"}, - }) - - result, err := svc.Check("agent-1", "gpt-4") - require.NoError(t, err) - assert.False(t, result.Allowed) - - got := drainEvents(em, 1, time.Second) - require.Len(t, got, 1) - assert.Equal(t, EventQuotaExceeded, got[0].Type) - assert.Contains(t, got[0].Payload, "allowlist") -} - -func TestAllowanceService_NoEventsWithoutEmitter(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - // No emitter set. - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "agent-1", - DailyTokenLimit: 100, - }) - _ = store.IncrementUsage("agent-1", 100, 0) - - result, err := svc.Check("agent-1", "") - require.NoError(t, err) - assert.False(t, result.Allowed) - // No panic = pass. -} - -// --- Helpers --- - -// drainEvents reads up to n events from the emitter within the timeout. -func drainEvents(em *ChannelEmitter, n int, timeout time.Duration) []Event { - var events []Event - deadline := time.After(timeout) - for range n { - select { - case e := <-em.Events(): - events = append(events, e) - case <-deadline: - return events - } - } - return events -} diff --git a/pkg/lifecycle/events_test.go b/pkg/lifecycle/events_test.go deleted file mode 100644 index b6f3f5a..0000000 --- a/pkg/lifecycle/events_test.go +++ /dev/null @@ -1,153 +0,0 @@ -package lifecycle - -import ( - "context" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestChannelEmitter_EmitAndReceive(t *testing.T) { - em := NewChannelEmitter(10) - ctx := context.Background() - - event := Event{ - Type: EventTaskDispatched, - TaskID: "task-1", - AgentID: "agent-1", - Timestamp: time.Now().UTC(), - Payload: "test payload", - } - - err := em.Emit(ctx, event) - require.NoError(t, err) - - select { - case got := <-em.Events(): - assert.Equal(t, EventTaskDispatched, got.Type) - assert.Equal(t, "task-1", got.TaskID) - assert.Equal(t, "agent-1", got.AgentID) - assert.Equal(t, "test payload", got.Payload) - case <-time.After(time.Second): - t.Fatal("timed out waiting for event") - } -} - -func TestChannelEmitter_BufferOverflowDrops(t *testing.T) { - em := NewChannelEmitter(2) - ctx := context.Background() - - // Fill the buffer. - require.NoError(t, em.Emit(ctx, Event{Type: EventTaskDispatched, TaskID: "1"})) - require.NoError(t, em.Emit(ctx, Event{Type: EventTaskDispatched, TaskID: "2"})) - - // Third event should be dropped, not block. - err := em.Emit(ctx, Event{Type: EventTaskDispatched, TaskID: "3"}) - require.NoError(t, err) - - // Only 2 events in the channel. - assert.Len(t, em.ch, 2) -} - -func TestChannelEmitter_DefaultBufferSize(t *testing.T) { - em := NewChannelEmitter(0) - assert.Equal(t, 64, cap(em.ch)) -} - -func TestMultiEmitter_FanOut(t *testing.T) { - em1 := NewChannelEmitter(10) - em2 := NewChannelEmitter(10) - multi := NewMultiEmitter(em1, em2) - ctx := context.Background() - - event := Event{ - Type: EventQuotaWarning, - AgentID: "agent-x", - } - err := multi.Emit(ctx, event) - require.NoError(t, err) - - // Both emitters should have received the event. - select { - case got := <-em1.Events(): - assert.Equal(t, EventQuotaWarning, got.Type) - case <-time.After(time.Second): - t.Fatal("em1: timed out") - } - - select { - case got := <-em2.Events(): - assert.Equal(t, EventQuotaWarning, got.Type) - case <-time.After(time.Second): - t.Fatal("em2: timed out") - } -} - -func TestMultiEmitter_Add(t *testing.T) { - em1 := NewChannelEmitter(10) - multi := NewMultiEmitter(em1) - ctx := context.Background() - - em2 := NewChannelEmitter(10) - multi.Add(em2) - - err := multi.Emit(ctx, Event{Type: EventUsageRecorded}) - require.NoError(t, err) - - assert.Len(t, em1.ch, 1) - assert.Len(t, em2.ch, 1) -} - -func TestMultiEmitter_ContinuesOnFailure(t *testing.T) { - failing := &failingEmitter{} - good := NewChannelEmitter(10) - multi := NewMultiEmitter(failing, good) - ctx := context.Background() - - err := multi.Emit(ctx, Event{Type: EventTaskClaimed}) - require.NoError(t, err) // MultiEmitter swallows errors. - - // The good emitter should still have received the event. - assert.Len(t, good.ch, 1) -} - -func TestChannelEmitter_ConcurrentEmit(t *testing.T) { - em := NewChannelEmitter(100) - ctx := context.Background() - - var wg sync.WaitGroup - for range 50 { - wg.Go(func() { - _ = em.Emit(ctx, Event{Type: EventTaskDispatched}) - }) - } - wg.Wait() - - assert.Equal(t, 50, len(em.ch)) -} - -func TestEventTypes_AllDefined(t *testing.T) { - types := []EventType{ - EventTaskDispatched, - EventTaskClaimed, - EventDispatchFailedNoAgent, - EventDispatchFailedQuota, - EventTaskDeadLettered, - EventQuotaWarning, - EventQuotaExceeded, - EventUsageRecorded, - } - for _, et := range types { - assert.NotEmpty(t, string(et)) - } -} - -// failingEmitter always returns an error. -type failingEmitter struct{} - -func (f *failingEmitter) Emit(_ context.Context, _ Event) error { - return &APIError{Code: 500, Message: "emitter failed"} -} diff --git a/pkg/lifecycle/lifecycle_test.go b/pkg/lifecycle/lifecycle_test.go deleted file mode 100644 index 3e67504..0000000 --- a/pkg/lifecycle/lifecycle_test.go +++ /dev/null @@ -1,302 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TestTaskLifecycle_ClaimProcessComplete tests the full task lifecycle: -// claim a pending task, check allowance, record usage events, complete the task. -func TestTaskLifecycle_ClaimProcessComplete(t *testing.T) { - // Set up allowance infrastructure. - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "lifecycle-agent", - DailyTokenLimit: 100000, - DailyJobLimit: 10, - ConcurrentJobs: 3, - }) - - // Phase 1: Pre-dispatch allowance check should pass. - check, err := svc.Check("lifecycle-agent", "") - require.NoError(t, err) - assert.True(t, check.Allowed) - assert.Equal(t, AllowanceOK, check.Status) - - // Phase 2: Simulate claiming a task via the HTTP client. - pendingTask := Task{ - ID: "lifecycle-001", - Title: "Full lifecycle test", - Priority: PriorityHigh, - Status: StatusPending, - Project: "core", - } - - claimedTask := pendingTask - claimedTask.Status = StatusInProgress - claimedTask.ClaimedBy = "lifecycle-agent" - now := time.Now().UTC() - claimedTask.ClaimedAt = &now - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch { - case r.Method == http.MethodPost && r.URL.Path == "/api/tasks/lifecycle-001/claim": - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(ClaimResponse{Task: &claimedTask}) - - case r.Method == http.MethodPatch && r.URL.Path == "/api/tasks/lifecycle-001": - w.WriteHeader(http.StatusOK) - - case r.Method == http.MethodPost && r.URL.Path == "/api/tasks/lifecycle-001/complete": - w.WriteHeader(http.StatusOK) - - default: - w.WriteHeader(http.StatusNotFound) - } - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - client.AgentID = "lifecycle-agent" - - // Claim the task. - claimed, err := client.ClaimTask(context.Background(), "lifecycle-001") - require.NoError(t, err) - assert.Equal(t, StatusInProgress, claimed.Status) - assert.Equal(t, "lifecycle-agent", claimed.ClaimedBy) - - // Phase 3: Record job start in the allowance system. - err = svc.RecordUsage(UsageReport{ - AgentID: "lifecycle-agent", - JobID: "lifecycle-001", - Event: QuotaEventJobStarted, - }) - require.NoError(t, err) - - usage, _ := store.GetUsage("lifecycle-agent") - assert.Equal(t, 1, usage.ActiveJobs) - assert.Equal(t, 1, usage.JobsStarted) - - // Phase 4: Update task progress. - err = client.UpdateTask(context.Background(), "lifecycle-001", TaskUpdate{ - Status: StatusInProgress, - Progress: 50, - Notes: "Halfway through", - }) - require.NoError(t, err) - - // Phase 5: Record job completion with token usage. - err = svc.RecordUsage(UsageReport{ - AgentID: "lifecycle-agent", - JobID: "lifecycle-001", - Model: "claude-sonnet", - TokensIn: 5000, - TokensOut: 3000, - Event: QuotaEventJobCompleted, - }) - require.NoError(t, err) - - usage, _ = store.GetUsage("lifecycle-agent") - assert.Equal(t, 0, usage.ActiveJobs) - assert.Equal(t, int64(8000), usage.TokensUsed) - - // Phase 6: Complete the task via the API. - err = client.CompleteTask(context.Background(), "lifecycle-001", TaskResult{ - Success: true, - Output: "Task completed successfully", - Artifacts: []string{"output.go"}, - }) - require.NoError(t, err) - - // Phase 7: Verify allowance is still within limits. - check, err = svc.Check("lifecycle-agent", "") - require.NoError(t, err) - assert.True(t, check.Allowed) - assert.Equal(t, AllowanceOK, check.Status) - assert.Equal(t, int64(92000), check.RemainingTokens) - assert.Equal(t, 9, check.RemainingJobs) -} - -// TestTaskLifecycle_ClaimProcessFail tests the lifecycle when a job fails -// and verifies that 50% of tokens are returned. -func TestTaskLifecycle_ClaimProcessFail(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "fail-agent", - DailyTokenLimit: 50000, - DailyJobLimit: 5, - ConcurrentJobs: 2, - }) - - // Start job. - err := svc.RecordUsage(UsageReport{ - AgentID: "fail-agent", - JobID: "fail-001", - Event: QuotaEventJobStarted, - }) - require.NoError(t, err) - - // Job fails with 10000 tokens consumed. - err = svc.RecordUsage(UsageReport{ - AgentID: "fail-agent", - JobID: "fail-001", - Model: "claude-sonnet", - TokensIn: 6000, - TokensOut: 4000, - Event: QuotaEventJobFailed, - }) - require.NoError(t, err) - - // Verify 50% returned: 10000 charged, 5000 returned = 5000 net. - usage, _ := store.GetUsage("fail-agent") - assert.Equal(t, int64(5000), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) - - // Verify model usage is net: 10000 - 5000 = 5000. - modelUsage, _ := store.GetModelUsage("claude-sonnet") - assert.Equal(t, int64(5000), modelUsage) - - // Check allowance - should still have room. - check, err := svc.Check("fail-agent", "") - require.NoError(t, err) - assert.True(t, check.Allowed) - assert.Equal(t, int64(45000), check.RemainingTokens) -} - -// TestTaskLifecycle_ClaimProcessCancel tests the lifecycle when a job is -// cancelled and verifies that 100% of tokens are returned. -func TestTaskLifecycle_ClaimProcessCancel(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "cancel-agent", - DailyTokenLimit: 50000, - DailyJobLimit: 5, - ConcurrentJobs: 2, - }) - - // Start job. - err := svc.RecordUsage(UsageReport{ - AgentID: "cancel-agent", - JobID: "cancel-001", - Event: QuotaEventJobStarted, - }) - require.NoError(t, err) - - // Job cancelled with 8000 tokens consumed. - err = svc.RecordUsage(UsageReport{ - AgentID: "cancel-agent", - JobID: "cancel-001", - TokensIn: 5000, - TokensOut: 3000, - Event: QuotaEventJobCancelled, - }) - require.NoError(t, err) - - // Verify 100% returned: tokens should be 0 (only job start had 0 tokens). - usage, _ := store.GetUsage("cancel-agent") - assert.Equal(t, int64(0), usage.TokensUsed) - assert.Equal(t, 0, usage.ActiveJobs) - - // Model usage should be zero for cancelled jobs. - modelUsage, _ := store.GetModelUsage("claude-sonnet") - assert.Equal(t, int64(0), modelUsage) -} - -// TestTaskLifecycle_MultipleAgentsConcurrent verifies that multiple agents -// can operate on the same store concurrently without data races. -func TestTaskLifecycle_MultipleAgentsConcurrent(t *testing.T) { - store := NewMemoryStore() - svc := NewAllowanceService(store) - - agents := []string{"agent-a", "agent-b", "agent-c"} - for _, agentID := range agents { - _ = store.SetAllowance(&AgentAllowance{ - AgentID: agentID, - DailyTokenLimit: 100000, - DailyJobLimit: 50, - ConcurrentJobs: 5, - }) - } - - var wg sync.WaitGroup - for _, agentID := range agents { - wg.Add(1) - go func(aid string) { - defer wg.Done() - for range 10 { - // Check allowance. - result, err := svc.Check(aid, "") - assert.NoError(t, err) - assert.True(t, result.Allowed) - - // Start job. - _ = svc.RecordUsage(UsageReport{ - AgentID: aid, - JobID: aid + "-job", - Event: QuotaEventJobStarted, - }) - - // Complete job. - _ = svc.RecordUsage(UsageReport{ - AgentID: aid, - JobID: aid + "-job", - Model: "claude-sonnet", - TokensIn: 100, - TokensOut: 50, - Event: QuotaEventJobCompleted, - }) - } - }(agentID) - } - wg.Wait() - - // Verify each agent has consistent usage. - for _, agentID := range agents { - usage, err := store.GetUsage(agentID) - require.NoError(t, err) - assert.Equal(t, int64(1500), usage.TokensUsed) // 10 jobs x 150 tokens - assert.Equal(t, 10, usage.JobsStarted) // 10 starts - assert.Equal(t, 0, usage.ActiveJobs) // all completed - } -} - -// TestTaskLifecycle_ClaimedByFilter verifies that ListTasks with ClaimedBy -// filter sends the correct query parameter. -func TestTaskLifecycle_ClaimedByFilter(t *testing.T) { - claimedTask := Task{ - ID: "claimed-task-1", - Title: "Agent's task", - Status: StatusInProgress, - ClaimedBy: "agent-x", - } - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "agent-x", r.URL.Query().Get("claimed_by")) - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode([]Task{claimedTask}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - tasks, err := client.ListTasks(context.Background(), ListOptions{ - ClaimedBy: "agent-x", - }) - - require.NoError(t, err) - require.Len(t, tasks, 1) - assert.Equal(t, "agent-x", tasks[0].ClaimedBy) -} diff --git a/pkg/lifecycle/logs.go b/pkg/lifecycle/logs.go deleted file mode 100644 index b3e378c..0000000 --- a/pkg/lifecycle/logs.go +++ /dev/null @@ -1,47 +0,0 @@ -package lifecycle - -import ( - "context" - "fmt" - "io" - "time" - - "forge.lthn.ai/core/go-log" -) - -// StreamLogs polls a task's status and writes updates to writer. It polls at -// the given interval until the task reaches a terminal state (completed or -// blocked) or the context is cancelled. Returns ctx.Err() on cancellation. -func StreamLogs(ctx context.Context, client *Client, taskID string, interval time.Duration, writer io.Writer) error { - const op = "agentic.StreamLogs" - - if taskID == "" { - return log.E(op, "task ID is required", nil) - } - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - task, err := client.GetTask(ctx, taskID) - if err != nil { - // Write the error but continue polling -- transient failures - // should not stop the stream. - _, _ = fmt.Fprintf(writer, "[%s] Error: %s\n", time.Now().UTC().Format("2006-01-02 15:04:05"), err) - continue - } - - line := fmt.Sprintf("[%s] Status: %s", time.Now().UTC().Format("2006-01-02 15:04:05"), task.Status) - _, _ = fmt.Fprintln(writer, line) - - // Stop on terminal states. - if task.Status == StatusCompleted || task.Status == StatusBlocked { - return nil - } - } - } -} diff --git a/pkg/lifecycle/logs_test.go b/pkg/lifecycle/logs_test.go deleted file mode 100644 index 1f85cd9..0000000 --- a/pkg/lifecycle/logs_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package lifecycle - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestStreamLogs_Good_CompletedTask(t *testing.T) { - var calls atomic.Int32 - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/api/tasks/task-1", r.URL.Path) - n := calls.Add(1) - - task := Task{ID: "task-1"} - switch { - case n <= 2: - task.Status = StatusInProgress - default: - task.Status = StatusCompleted - } - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - var buf bytes.Buffer - - err := StreamLogs(context.Background(), client, "task-1", 10*time.Millisecond, &buf) - require.NoError(t, err) - - output := buf.String() - assert.Contains(t, output, "Status: in_progress") - assert.Contains(t, output, "Status: completed") - assert.GreaterOrEqual(t, int(calls.Load()), 3) -} - -func TestStreamLogs_Good_BlockedTask(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - task := Task{ID: "task-2", Status: StatusBlocked} - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - var buf bytes.Buffer - - err := StreamLogs(context.Background(), client, "task-2", 10*time.Millisecond, &buf) - require.NoError(t, err) - assert.Contains(t, buf.String(), "Status: blocked") -} - -func TestStreamLogs_Good_ContextCancellation(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - task := Task{ID: "task-3", Status: StatusInProgress} - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - var buf bytes.Buffer - - ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) - defer cancel() - - err := StreamLogs(ctx, client, "task-3", 20*time.Millisecond, &buf) - require.ErrorIs(t, err, context.DeadlineExceeded) - assert.Contains(t, buf.String(), "Status: in_progress") -} - -func TestStreamLogs_Good_TransientErrorContinues(t *testing.T) { - var calls atomic.Int32 - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - n := calls.Add(1) - - if n == 1 { - // First call: server error. - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Message: "transient"}) - return - } - // Second call: completed. - task := Task{ID: "task-4", Status: StatusCompleted} - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - var buf bytes.Buffer - - err := StreamLogs(context.Background(), client, "task-4", 10*time.Millisecond, &buf) - require.NoError(t, err) - - output := buf.String() - assert.Contains(t, output, "Error:") - assert.Contains(t, output, "Status: completed") -} - -func TestStreamLogs_Bad_EmptyTaskID(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - var buf bytes.Buffer - - err := StreamLogs(context.Background(), client, "", 10*time.Millisecond, &buf) - assert.Error(t, err) - assert.Contains(t, err.Error(), "task ID is required") -} - -func TestStreamLogs_Good_ImmediateCancel(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - task := Task{ID: "task-5", Status: StatusInProgress} - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - var buf bytes.Buffer - - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately. - - err := StreamLogs(ctx, client, "task-5", 10*time.Millisecond, &buf) - require.ErrorIs(t, err, context.Canceled) -} diff --git a/pkg/lifecycle/plan_dispatcher.go b/pkg/lifecycle/plan_dispatcher.go deleted file mode 100644 index 26ecbc1..0000000 --- a/pkg/lifecycle/plan_dispatcher.go +++ /dev/null @@ -1,197 +0,0 @@ -package lifecycle - -import ( - "context" - "time" - - "forge.lthn.ai/core/go-log" -) - -// PlanDispatcher orchestrates plan-based work by polling active plans, -// starting sessions, and routing work to agents. It wraps the existing -// agent registry, router, and allowance service alongside the API client. -type PlanDispatcher struct { - registry AgentRegistry - router TaskRouter - allowance *AllowanceService - client *Client - events EventEmitter - agentType string // e.g. "opus", "haiku", "codex" -} - -// NewPlanDispatcher creates a PlanDispatcher for the given agent type. -func NewPlanDispatcher( - agentType string, - registry AgentRegistry, - router TaskRouter, - allowance *AllowanceService, - client *Client, -) *PlanDispatcher { - return &PlanDispatcher{ - agentType: agentType, - registry: registry, - router: router, - allowance: allowance, - client: client, - } -} - -// SetEventEmitter attaches an event emitter for lifecycle notifications. -func (pd *PlanDispatcher) SetEventEmitter(em EventEmitter) { - pd.events = em -} - -func (pd *PlanDispatcher) emit(ctx context.Context, event Event) { - if pd.events != nil { - if event.Timestamp.IsZero() { - event.Timestamp = time.Now().UTC() - } - _ = pd.events.Emit(ctx, event) - } -} - -// PlanDispatchLoop polls for active plans at the given interval and picks up -// the first plan with a pending or in-progress phase. It starts a session, -// marks the phase in-progress, and returns the plan + session for the caller -// to work on. Runs until context is cancelled. -func (pd *PlanDispatcher) PlanDispatchLoop(ctx context.Context, interval time.Duration) error { - const op = "PlanDispatcher.PlanDispatchLoop" - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - plan, session, err := pd.pickUpWork(ctx) - if err != nil { - _ = log.E(op, "failed to pick up work", err) - continue - } - if plan == nil { - continue // no work available - } - - pd.emit(ctx, Event{ - Type: EventTaskDispatched, - TaskID: plan.Slug, - AgentID: session.SessionID, - Payload: map[string]string{ - "plan": plan.Slug, - "agent_type": pd.agentType, - }, - }) - } - } -} - -// pickUpWork finds the first active plan with workable phases, starts a session, -// and marks the next phase in-progress. Returns nil if no work is available. -func (pd *PlanDispatcher) pickUpWork(ctx context.Context) (*Plan, *sessionStartResponse, error) { - const op = "PlanDispatcher.pickUpWork" - - plans, err := pd.client.ListPlans(ctx, ListPlanOptions{Status: PlanActive}) - if err != nil { - return nil, nil, log.E(op, "failed to list active plans", err) - } - - for _, plan := range plans { - // Check agent allowance before taking work. - if pd.allowance != nil { - check, err := pd.allowance.Check(pd.agentType, "") - if err != nil || !check.Allowed { - continue - } - } - - // Get full plan with phases. - fullPlan, err := pd.client.GetPlan(ctx, plan.Slug) - if err != nil { - _ = log.E(op, "failed to get plan "+plan.Slug, err) - continue - } - - // Find the next workable phase. - phase := nextWorkablePhase(fullPlan.Phases) - if phase == nil { - continue - } - - // Start session for this plan. - session, err := pd.client.StartSession(ctx, StartSessionRequest{ - AgentType: pd.agentType, - PlanSlug: plan.Slug, - }) - if err != nil { - _ = log.E(op, "failed to start session for "+plan.Slug, err) - continue - } - - // Mark phase as in-progress. - if phase.Status == PhasePending { - if err := pd.client.UpdatePhaseStatus(ctx, plan.Slug, phase.Name, PhaseInProgress, ""); err != nil { - _ = log.E(op, "failed to update phase status", err) - } - } - - // Record job start. - if pd.allowance != nil { - _ = pd.allowance.RecordUsage(UsageReport{ - AgentID: pd.agentType, - JobID: plan.Slug, - Event: QuotaEventJobStarted, - Timestamp: time.Now().UTC(), - }) - } - - return fullPlan, session, nil - } - - return nil, nil, nil -} - -// CompleteWork ends a session and optionally marks the current phase as completed. -func (pd *PlanDispatcher) CompleteWork(ctx context.Context, planSlug, sessionID, phaseName string, summary string) error { - const op = "PlanDispatcher.CompleteWork" - - // Mark phase completed. - if phaseName != "" { - if err := pd.client.UpdatePhaseStatus(ctx, planSlug, phaseName, PhaseCompleted, ""); err != nil { - _ = log.E(op, "failed to complete phase", err) - } - } - - // End session. - if err := pd.client.EndSession(ctx, sessionID, "completed", summary); err != nil { - return log.E(op, "failed to end session", err) - } - - // Record job completion. - if pd.allowance != nil { - _ = pd.allowance.RecordUsage(UsageReport{ - AgentID: pd.agentType, - JobID: planSlug, - Event: QuotaEventJobCompleted, - Timestamp: time.Now().UTC(), - }) - } - - return nil -} - -// nextWorkablePhase returns the first phase that is pending or in-progress. -func nextWorkablePhase(phases []Phase) *Phase { - for i := range phases { - switch phases[i].Status { - case PhasePending: - if phases[i].CanStart { - return &phases[i] - } - case PhaseInProgress: - return &phases[i] - } - } - return nil -} diff --git a/pkg/lifecycle/plans.go b/pkg/lifecycle/plans.go deleted file mode 100644 index 483d1c9..0000000 --- a/pkg/lifecycle/plans.go +++ /dev/null @@ -1,525 +0,0 @@ -package lifecycle - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - - "forge.lthn.ai/core/go-log" -) - -// PlanStatus represents the state of a plan. -type PlanStatus string - -const ( - PlanDraft PlanStatus = "draft" - PlanActive PlanStatus = "active" - PlanPaused PlanStatus = "paused" - PlanCompleted PlanStatus = "completed" - PlanArchived PlanStatus = "archived" -) - -// PhaseStatus represents the state of a phase within a plan. -type PhaseStatus string - -const ( - PhasePending PhaseStatus = "pending" - PhaseInProgress PhaseStatus = "in_progress" - PhaseCompleted PhaseStatus = "completed" - PhaseBlocked PhaseStatus = "blocked" - PhaseSkipped PhaseStatus = "skipped" -) - -// Plan represents an agent plan from the PHP API. -type Plan struct { - Slug string `json:"slug"` - Title string `json:"title"` - Description string `json:"description,omitempty"` - Status PlanStatus `json:"status"` - CurrentPhase int `json:"current_phase,omitempty"` - Progress Progress `json:"progress,omitempty"` - Phases []Phase `json:"phases,omitempty"` - Metadata any `json:"metadata,omitempty"` - CreatedAt string `json:"created_at,omitempty"` - UpdatedAt string `json:"updated_at,omitempty"` -} - -// Phase represents a phase within a plan. -type Phase struct { - ID int `json:"id,omitempty"` - Order int `json:"order"` - Name string `json:"name"` - Description string `json:"description,omitempty"` - Status PhaseStatus `json:"status"` - Tasks []PhaseTask `json:"tasks,omitempty"` - TaskProgress TaskProgress `json:"task_progress,omitempty"` - RemainingTasks []string `json:"remaining_tasks,omitempty"` - Dependencies []int `json:"dependencies,omitempty"` - DependencyBlockers []string `json:"dependency_blockers,omitempty"` - CanStart bool `json:"can_start,omitempty"` - Checkpoints []any `json:"checkpoints,omitempty"` - StartedAt string `json:"started_at,omitempty"` - CompletedAt string `json:"completed_at,omitempty"` - Metadata any `json:"metadata,omitempty"` -} - -// PhaseTask represents a task within a phase. Tasks are stored as a JSON array -// in the phase and may be simple strings or objects with status/notes. -type PhaseTask struct { - Name string `json:"name"` - Status string `json:"status,omitempty"` - Notes string `json:"notes,omitempty"` -} - -// UnmarshalJSON handles the fact that tasks can be either strings or objects. -func (t *PhaseTask) UnmarshalJSON(data []byte) error { - // Try string first - var s string - if err := json.Unmarshal(data, &s); err == nil { - t.Name = s - t.Status = "pending" - return nil - } - - // Try object - type taskAlias PhaseTask - var obj taskAlias - if err := json.Unmarshal(data, &obj); err != nil { - return err - } - *t = PhaseTask(obj) - return nil -} - -// Progress represents plan progress metrics. -type Progress struct { - Total int `json:"total"` - Completed int `json:"completed"` - InProgress int `json:"in_progress"` - Pending int `json:"pending"` - Percentage int `json:"percentage"` -} - -// TaskProgress represents task-level progress within a phase. -type TaskProgress struct { - Total int `json:"total"` - Completed int `json:"completed"` - Pending int `json:"pending"` - Percentage int `json:"percentage"` -} - -// ListPlanOptions specifies filters for listing plans. -type ListPlanOptions struct { - Status PlanStatus `json:"status,omitempty"` - IncludeArchived bool `json:"include_archived,omitempty"` -} - -// CreatePlanRequest is the payload for creating a new plan. -type CreatePlanRequest struct { - Title string `json:"title"` - Slug string `json:"slug,omitempty"` - Description string `json:"description,omitempty"` - Context map[string]any `json:"context,omitempty"` - Phases []CreatePhaseInput `json:"phases,omitempty"` -} - -// CreatePhaseInput is a phase definition for plan creation. -type CreatePhaseInput struct { - Name string `json:"name"` - Description string `json:"description,omitempty"` - Tasks []string `json:"tasks,omitempty"` -} - -// planListResponse wraps the list endpoint response. -type planListResponse struct { - Plans []Plan `json:"plans"` - Total int `json:"total"` -} - -// planCreateResponse wraps the create endpoint response. -type planCreateResponse struct { - Slug string `json:"slug"` - Title string `json:"title"` - Status string `json:"status"` - Phases int `json:"phases"` -} - -// planUpdateResponse wraps the update endpoint response. -type planUpdateResponse struct { - Slug string `json:"slug"` - Status string `json:"status"` -} - -// planArchiveResponse wraps the archive endpoint response. -type planArchiveResponse struct { - Slug string `json:"slug"` - Status string `json:"status"` - ArchivedAt string `json:"archived_at,omitempty"` -} - -// ListPlans retrieves plans matching the given options. -func (c *Client) ListPlans(ctx context.Context, opts ListPlanOptions) ([]Plan, error) { - const op = "agentic.Client.ListPlans" - - params := url.Values{} - if opts.Status != "" { - params.Set("status", string(opts.Status)) - } - if opts.IncludeArchived { - params.Set("include_archived", "1") - } - - endpoint := c.BaseURL + "/v1/plans" - if len(params) > 0 { - endpoint += "?" + params.Encode() - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result planListResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return result.Plans, nil -} - -// GetPlan retrieves a plan by slug (returns full detail with phases). -func (c *Client) GetPlan(ctx context.Context, slug string) (*Plan, error) { - const op = "agentic.Client.GetPlan" - - if slug == "" { - return nil, log.E(op, "plan slug is required", nil) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s", c.BaseURL, url.PathEscape(slug)) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var plan Plan - if err := json.NewDecoder(resp.Body).Decode(&plan); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &plan, nil -} - -// CreatePlan creates a new plan with optional phases. -func (c *Client) CreatePlan(ctx context.Context, req CreatePlanRequest) (*planCreateResponse, error) { - const op = "agentic.Client.CreatePlan" - - if req.Title == "" { - return nil, log.E(op, "title is required", nil) - } - - data, err := json.Marshal(req) - if err != nil { - return nil, log.E(op, "failed to marshal request", err) - } - - endpoint := c.BaseURL + "/v1/plans" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(httpReq) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result planCreateResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &result, nil -} - -// UpdatePlanStatus changes a plan's status. -func (c *Client) UpdatePlanStatus(ctx context.Context, slug string, status PlanStatus) error { - const op = "agentic.Client.UpdatePlanStatus" - - if slug == "" { - return log.E(op, "plan slug is required", nil) - } - - data, err := json.Marshal(map[string]string{"status": string(status)}) - if err != nil { - return log.E(op, "failed to marshal request", err) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s", c.BaseURL, url.PathEscape(slug)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(data)) - if err != nil { - return log.E(op, "failed to create request", err) - } - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - return c.checkResponse(resp) -} - -// ArchivePlan archives a plan with an optional reason. -func (c *Client) ArchivePlan(ctx context.Context, slug string, reason string) error { - const op = "agentic.Client.ArchivePlan" - - if slug == "" { - return log.E(op, "plan slug is required", nil) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s", c.BaseURL, url.PathEscape(slug)) - - var body *bytes.Reader - if reason != "" { - data, _ := json.Marshal(map[string]string{"reason": reason}) - body = bytes.NewReader(data) - } - - var reqBody *bytes.Reader - if body != nil { - reqBody = body - } - - var httpReq *http.Request - var err error - if reqBody != nil { - httpReq, err = http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, reqBody) - if err != nil { - return log.E(op, "failed to create request", err) - } - httpReq.Header.Set("Content-Type", "application/json") - } else { - httpReq, err = http.NewRequestWithContext(ctx, http.MethodDelete, endpoint, nil) - if err != nil { - return log.E(op, "failed to create request", err) - } - } - c.setHeaders(httpReq) - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - return c.checkResponse(resp) -} - -// GetPhase retrieves a specific phase within a plan. -func (c *Client) GetPhase(ctx context.Context, planSlug string, phase string) (*Phase, error) { - const op = "agentic.Client.GetPhase" - - if planSlug == "" || phase == "" { - return nil, log.E(op, "plan slug and phase identifier are required", nil) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s", - c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase)) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result Phase - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &result, nil -} - -// UpdatePhaseStatus changes a phase's status. -func (c *Client) UpdatePhaseStatus(ctx context.Context, planSlug, phase string, status PhaseStatus, notes string) error { - const op = "agentic.Client.UpdatePhaseStatus" - - if planSlug == "" || phase == "" { - return log.E(op, "plan slug and phase identifier are required", nil) - } - - payload := map[string]string{"status": string(status)} - if notes != "" { - payload["notes"] = notes - } - data, err := json.Marshal(payload) - if err != nil { - return log.E(op, "failed to marshal request", err) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s", - c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(data)) - if err != nil { - return log.E(op, "failed to create request", err) - } - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - return c.checkResponse(resp) -} - -// AddCheckpoint adds a checkpoint note to a phase. -func (c *Client) AddCheckpoint(ctx context.Context, planSlug, phase, note string, checkpointCtx map[string]any) error { - const op = "agentic.Client.AddCheckpoint" - - if planSlug == "" || phase == "" || note == "" { - return log.E(op, "plan slug, phase, and note are required", nil) - } - - payload := map[string]any{"note": note} - if len(checkpointCtx) > 0 { - payload["context"] = checkpointCtx - } - data, err := json.Marshal(payload) - if err != nil { - return log.E(op, "failed to marshal request", err) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s/checkpoint", - c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return log.E(op, "failed to create request", err) - } - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - return c.checkResponse(resp) -} - -// UpdateTaskStatus updates a task within a phase. -func (c *Client) UpdateTaskStatus(ctx context.Context, planSlug, phase string, taskIdx int, status string, notes string) error { - const op = "agentic.Client.UpdateTaskStatus" - - if planSlug == "" || phase == "" { - return log.E(op, "plan slug and phase are required", nil) - } - - payload := map[string]any{} - if status != "" { - payload["status"] = status - } - if notes != "" { - payload["notes"] = notes - } - data, err := json.Marshal(payload) - if err != nil { - return log.E(op, "failed to marshal request", err) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s/tasks/%d", - c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase), taskIdx) - - req, err := http.NewRequestWithContext(ctx, http.MethodPatch, endpoint, bytes.NewReader(data)) - if err != nil { - return log.E(op, "failed to create request", err) - } - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - return c.checkResponse(resp) -} - -// ToggleTask toggles a task between pending and completed. -func (c *Client) ToggleTask(ctx context.Context, planSlug, phase string, taskIdx int) error { - const op = "agentic.Client.ToggleTask" - - if planSlug == "" || phase == "" { - return log.E(op, "plan slug and phase are required", nil) - } - - endpoint := fmt.Sprintf("%s/v1/plans/%s/phases/%s/tasks/%d/toggle", - c.BaseURL, url.PathEscape(planSlug), url.PathEscape(phase), taskIdx) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil) - if err != nil { - return log.E(op, "failed to create request", err) - } - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - return c.checkResponse(resp) -} diff --git a/pkg/lifecycle/prompts/commit.md b/pkg/lifecycle/prompts/commit.md deleted file mode 100644 index 58d6d7b..0000000 --- a/pkg/lifecycle/prompts/commit.md +++ /dev/null @@ -1,44 +0,0 @@ -# Commit Instructions - -## Context - -When asked to commit changes in these repositories: - -1. **All changes are pre-approved** - Both tracked (modified) and untracked files have been reviewed and approved for commit -2. **Include everything** - Commit ALL modified files AND all untracked files (including new directories) -3. **Don't skip files** - Do not leave out untracked files assuming they need separate review -4. **Single commit preferred** - Combine all changes into one cohesive commit unless explicitly told otherwise - -The user has already validated these changes. Proceed with confidence. - -## Handling Ignored Files - -If you see untracked directories that should typically be ignored (like `node_modules/`, `vendor/`, `.cache/`, `dist/`, `build/`): - -1. **Fix the .gitignore** - Create or update `.gitignore` to exclude these directories -2. **Commit the .gitignore** - Include this fix in your commit -3. **Don't ask** - Just fix it and commit - -Common patterns to add to .gitignore: -``` -node_modules/ -vendor/ -.cache/ -dist/ -build/ -*.log -.env -.DS_Store -``` - -## Commit Message Style - -- Use conventional commit format: `type(scope): description` -- Common types: `refactor`, `feat`, `fix`, `docs`, `chore` -- Keep the first line under 72 characters -- Add body for complex changes explaining the "why" -- Include `Co-Authored-By: Claude Opus 4.5 ` - -## Task - -Review the uncommitted changes and create an appropriate commit. Be concise. diff --git a/pkg/lifecycle/registry.go b/pkg/lifecycle/registry.go deleted file mode 100644 index d9de2a4..0000000 --- a/pkg/lifecycle/registry.go +++ /dev/null @@ -1,168 +0,0 @@ -package lifecycle - -import ( - "iter" - "slices" - "sync" - "time" -) - -// AgentStatus represents the availability state of an agent. -type AgentStatus string - -const ( - // AgentAvailable indicates the agent is ready to accept tasks. - AgentAvailable AgentStatus = "available" - // AgentBusy indicates the agent is working but may accept more tasks. - AgentBusy AgentStatus = "busy" - // AgentOffline indicates the agent has not sent a heartbeat recently. - AgentOffline AgentStatus = "offline" -) - -// AgentInfo describes a registered agent and its current state. -type AgentInfo struct { - // ID is the unique identifier for the agent. - ID string `json:"id"` - // Name is the human-readable name of the agent. - Name string `json:"name"` - // Capabilities lists what the agent can handle (e.g. "go", "testing", "frontend"). - Capabilities []string `json:"capabilities,omitempty"` - // Status is the current availability state. - Status AgentStatus `json:"status"` - // LastHeartbeat is the last time the agent reported in. - LastHeartbeat time.Time `json:"last_heartbeat"` - // CurrentLoad is the number of active jobs the agent is running. - CurrentLoad int `json:"current_load"` - // MaxLoad is the maximum concurrent jobs the agent supports. 0 means unlimited. - MaxLoad int `json:"max_load"` -} - -// AgentRegistry manages the set of known agents and their health. -type AgentRegistry interface { - // Register adds or updates an agent in the registry. - Register(agent AgentInfo) error - // Deregister removes an agent from the registry. - Deregister(id string) error - // Get returns a copy of the agent info for the given ID. - Get(id string) (AgentInfo, error) - // List returns a copy of all registered agents. - List() []AgentInfo - // All returns an iterator over all registered agents. - All() iter.Seq[AgentInfo] - // Heartbeat updates the agent's LastHeartbeat timestamp and sets status - // to Available if the agent was previously Offline. - Heartbeat(id string) error - // Reap marks agents as Offline if their last heartbeat is older than ttl. - // Returns the IDs of agents that were reaped. - Reap(ttl time.Duration) []string - // Reaped returns an iterator over the IDs of agents that were reaped. - Reaped(ttl time.Duration) iter.Seq[string] -} - -// MemoryRegistry is an in-memory AgentRegistry implementation guarded by a -// read-write mutex. It uses copy-on-read semantics consistent with MemoryStore. -type MemoryRegistry struct { - mu sync.RWMutex - agents map[string]*AgentInfo -} - -// NewMemoryRegistry creates a new in-memory agent registry. -func NewMemoryRegistry() *MemoryRegistry { - return &MemoryRegistry{ - agents: make(map[string]*AgentInfo), - } -} - -// Register adds or updates an agent in the registry. Returns an error if the -// agent ID is empty. -func (r *MemoryRegistry) Register(agent AgentInfo) error { - if agent.ID == "" { - return &APIError{Code: 400, Message: "agent ID is required"} - } - r.mu.Lock() - defer r.mu.Unlock() - cp := agent - r.agents[agent.ID] = &cp - return nil -} - -// Deregister removes an agent from the registry. Returns an error if the agent -// is not found. -func (r *MemoryRegistry) Deregister(id string) error { - r.mu.Lock() - defer r.mu.Unlock() - if _, ok := r.agents[id]; !ok { - return &APIError{Code: 404, Message: "agent not found: " + id} - } - delete(r.agents, id) - return nil -} - -// Get returns a copy of the agent info for the given ID. Returns an error if -// the agent is not found. -func (r *MemoryRegistry) Get(id string) (AgentInfo, error) { - r.mu.RLock() - defer r.mu.RUnlock() - a, ok := r.agents[id] - if !ok { - return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + id} - } - return *a, nil -} - -// List returns a copy of all registered agents. -func (r *MemoryRegistry) List() []AgentInfo { - return slices.Collect(r.All()) -} - -// All returns an iterator over all registered agents. -func (r *MemoryRegistry) All() iter.Seq[AgentInfo] { - return func(yield func(AgentInfo) bool) { - r.mu.RLock() - defer r.mu.RUnlock() - for _, a := range r.agents { - if !yield(*a) { - return - } - } - } -} - -// Heartbeat updates the agent's LastHeartbeat timestamp. If the agent was -// Offline, it transitions to Available. -func (r *MemoryRegistry) Heartbeat(id string) error { - r.mu.Lock() - defer r.mu.Unlock() - a, ok := r.agents[id] - if !ok { - return &APIError{Code: 404, Message: "agent not found: " + id} - } - a.LastHeartbeat = time.Now().UTC() - if a.Status == AgentOffline { - a.Status = AgentAvailable - } - return nil -} - -// Reap marks agents as Offline if their last heartbeat is older than ttl. -// Returns the IDs of agents that were reaped. -func (r *MemoryRegistry) Reap(ttl time.Duration) []string { - return slices.Collect(r.Reaped(ttl)) -} - -// Reaped returns an iterator over the IDs of agents that were reaped. -func (r *MemoryRegistry) Reaped(ttl time.Duration) iter.Seq[string] { - return func(yield func(string) bool) { - r.mu.Lock() - defer r.mu.Unlock() - now := time.Now().UTC() - for id, a := range r.agents { - if a.Status != AgentOffline && now.Sub(a.LastHeartbeat) > ttl { - a.Status = AgentOffline - if !yield(id) { - return - } - } - } - } -} diff --git a/pkg/lifecycle/registry_redis.go b/pkg/lifecycle/registry_redis.go deleted file mode 100644 index 1a00ed8..0000000 --- a/pkg/lifecycle/registry_redis.go +++ /dev/null @@ -1,280 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "errors" - "iter" - "slices" - "time" - - "github.com/redis/go-redis/v9" -) - -// RedisRegistry implements AgentRegistry using Redis as the backing store. -// It provides persistent, network-accessible agent registration suitable for -// multi-node deployments. Heartbeat refreshes key TTL for natural reaping via -// expiry. -type RedisRegistry struct { - client *redis.Client - prefix string - defaultTTL time.Duration -} - -// redisRegistryConfig holds the configuration for a RedisRegistry. -type redisRegistryConfig struct { - password string - db int - prefix string - ttl time.Duration -} - -// RedisRegistryOption is a functional option for configuring a RedisRegistry. -type RedisRegistryOption func(*redisRegistryConfig) - -// WithRegistryRedisPassword sets the password for authenticating with Redis. -func WithRegistryRedisPassword(pw string) RedisRegistryOption { - return func(c *redisRegistryConfig) { - c.password = pw - } -} - -// WithRegistryRedisDB selects the Redis database number. -func WithRegistryRedisDB(db int) RedisRegistryOption { - return func(c *redisRegistryConfig) { - c.db = db - } -} - -// WithRegistryRedisPrefix sets the key prefix for all Redis keys. -// Default: "agentic". -func WithRegistryRedisPrefix(prefix string) RedisRegistryOption { - return func(c *redisRegistryConfig) { - c.prefix = prefix - } -} - -// WithRegistryTTL sets the default TTL for agent keys. Default: 5 minutes. -// Heartbeat refreshes this TTL. Agents whose keys expire are naturally reaped. -func WithRegistryTTL(ttl time.Duration) RedisRegistryOption { - return func(c *redisRegistryConfig) { - c.ttl = ttl - } -} - -// NewRedisRegistry creates a new Redis-backed agent registry connecting to the -// given address (host:port). It pings the server to verify connectivity. -func NewRedisRegistry(addr string, opts ...RedisRegistryOption) (*RedisRegistry, error) { - cfg := &redisRegistryConfig{ - prefix: "agentic", - ttl: 5 * time.Minute, - } - for _, opt := range opts { - opt(cfg) - } - - client := redis.NewClient(&redis.Options{ - Addr: addr, - Password: cfg.password, - DB: cfg.db, - }) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := client.Ping(ctx).Err(); err != nil { - _ = client.Close() - return nil, &APIError{Code: 500, Message: "failed to connect to Redis: " + err.Error()} - } - - return &RedisRegistry{ - client: client, - prefix: cfg.prefix, - defaultTTL: cfg.ttl, - }, nil -} - -// Close releases the underlying Redis connection. -func (r *RedisRegistry) Close() error { - return r.client.Close() -} - -// --- key helpers --- - -func (r *RedisRegistry) agentKey(id string) string { - return r.prefix + ":agent:" + id -} - -func (r *RedisRegistry) agentPattern() string { - return r.prefix + ":agent:*" -} - -// --- AgentRegistry interface --- - -// Register adds or updates an agent in the registry. -func (r *RedisRegistry) Register(agent AgentInfo) error { - if agent.ID == "" { - return &APIError{Code: 400, Message: "agent ID is required"} - } - ctx := context.Background() - data, err := json.Marshal(agent) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal agent: " + err.Error()} - } - if err := r.client.Set(ctx, r.agentKey(agent.ID), data, r.defaultTTL).Err(); err != nil { - return &APIError{Code: 500, Message: "failed to register agent: " + err.Error()} - } - return nil -} - -// Deregister removes an agent from the registry. Returns an error if the agent -// is not found. -func (r *RedisRegistry) Deregister(id string) error { - ctx := context.Background() - n, err := r.client.Del(ctx, r.agentKey(id)).Result() - if err != nil { - return &APIError{Code: 500, Message: "failed to deregister agent: " + err.Error()} - } - if n == 0 { - return &APIError{Code: 404, Message: "agent not found: " + id} - } - return nil -} - -// Get returns a copy of the agent info for the given ID. Returns an error if -// the agent is not found. -func (r *RedisRegistry) Get(id string) (AgentInfo, error) { - ctx := context.Background() - val, err := r.client.Get(ctx, r.agentKey(id)).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + id} - } - return AgentInfo{}, &APIError{Code: 500, Message: "failed to get agent: " + err.Error()} - } - var a AgentInfo - if err := json.Unmarshal([]byte(val), &a); err != nil { - return AgentInfo{}, &APIError{Code: 500, Message: "failed to unmarshal agent: " + err.Error()} - } - return a, nil -} - -// List returns a copy of all registered agents by scanning all agent keys. -func (r *RedisRegistry) List() []AgentInfo { - return slices.Collect(r.All()) -} - -// All returns an iterator over all registered agents. -func (r *RedisRegistry) All() iter.Seq[AgentInfo] { - return func(yield func(AgentInfo) bool) { - ctx := context.Background() - iter := r.client.Scan(ctx, 0, r.agentPattern(), 100).Iterator() - for iter.Next(ctx) { - val, err := r.client.Get(ctx, iter.Val()).Result() - if err != nil { - continue - } - var a AgentInfo - if err := json.Unmarshal([]byte(val), &a); err != nil { - continue - } - if !yield(a) { - return - } - } - } -} - -// Heartbeat updates the agent's LastHeartbeat timestamp and refreshes the key -// TTL. If the agent was Offline, it transitions to Available. -func (r *RedisRegistry) Heartbeat(id string) error { - ctx := context.Background() - key := r.agentKey(id) - - val, err := r.client.Get(ctx, key).Result() - if err != nil { - if errors.Is(err, redis.Nil) { - return &APIError{Code: 404, Message: "agent not found: " + id} - } - return &APIError{Code: 500, Message: "failed to get agent for heartbeat: " + err.Error()} - } - - var a AgentInfo - if err := json.Unmarshal([]byte(val), &a); err != nil { - return &APIError{Code: 500, Message: "failed to unmarshal agent: " + err.Error()} - } - - a.LastHeartbeat = time.Now().UTC() - if a.Status == AgentOffline { - a.Status = AgentAvailable - } - - data, err := json.Marshal(a) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal agent: " + err.Error()} - } - - if err := r.client.Set(ctx, key, data, r.defaultTTL).Err(); err != nil { - return &APIError{Code: 500, Message: "failed to update agent heartbeat: " + err.Error()} - } - return nil -} - -// Reap scans all agent keys and marks agents as Offline if their last heartbeat -// is older than ttl. This is a backup to natural TTL expiry. Returns the IDs -// of agents that were reaped. -func (r *RedisRegistry) Reap(ttl time.Duration) []string { - return slices.Collect(r.Reaped(ttl)) -} - -// Reaped returns an iterator over the IDs of agents that were reaped. -func (r *RedisRegistry) Reaped(ttl time.Duration) iter.Seq[string] { - return func(yield func(string) bool) { - ctx := context.Background() - cutoff := time.Now().UTC().Add(-ttl) - - iter := r.client.Scan(ctx, 0, r.agentPattern(), 100).Iterator() - for iter.Next(ctx) { - key := iter.Val() - val, err := r.client.Get(ctx, key).Result() - if err != nil { - continue - } - var a AgentInfo - if err := json.Unmarshal([]byte(val), &a); err != nil { - continue - } - - if a.Status != AgentOffline && a.LastHeartbeat.Before(cutoff) { - a.Status = AgentOffline - data, err := json.Marshal(a) - if err != nil { - continue - } - // Preserve remaining TTL (or use default if none). - remainingTTL, err := r.client.TTL(ctx, key).Result() - if err != nil || remainingTTL <= 0 { - remainingTTL = r.defaultTTL - } - if err := r.client.Set(ctx, key, data, remainingTTL).Err(); err != nil { - continue - } - if !yield(a.ID) { - return - } - } - } - } -} - -// FlushPrefix deletes all keys matching the registry's prefix. Useful for -// testing cleanup. -func (r *RedisRegistry) FlushPrefix(ctx context.Context) error { - iter := r.client.Scan(ctx, 0, r.prefix+":*", 100).Iterator() - for iter.Next(ctx) { - if err := r.client.Del(ctx, iter.Val()).Err(); err != nil { - return err - } - } - return iter.Err() -} diff --git a/pkg/lifecycle/registry_redis_test.go b/pkg/lifecycle/registry_redis_test.go deleted file mode 100644 index 76baf82..0000000 --- a/pkg/lifecycle/registry_redis_test.go +++ /dev/null @@ -1,327 +0,0 @@ -package lifecycle - -import ( - "context" - "fmt" - "sort" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// newTestRedisRegistry creates a RedisRegistry with a unique prefix for test isolation. -// Skips the test if Redis is unreachable. -func newTestRedisRegistry(t *testing.T) *RedisRegistry { - t.Helper() - prefix := fmt.Sprintf("test_reg_%d", time.Now().UnixNano()) - reg, err := NewRedisRegistry(testRedisAddr, - WithRegistryRedisPrefix(prefix), - WithRegistryTTL(5*time.Minute), - ) - if err != nil { - t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) - } - t.Cleanup(func() { - ctx := context.Background() - _ = reg.FlushPrefix(ctx) - _ = reg.Close() - }) - return reg -} - -// --- Register tests --- - -func TestRedisRegistry_Register_Good(t *testing.T) { - reg := newTestRedisRegistry(t) - err := reg.Register(AgentInfo{ - ID: "agent-1", - Name: "Test Agent", - Capabilities: []string{"go", "testing"}, - Status: AgentAvailable, - MaxLoad: 5, - }) - require.NoError(t, err) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "agent-1", got.ID) - assert.Equal(t, "Test Agent", got.Name) - assert.Equal(t, []string{"go", "testing"}, got.Capabilities) - assert.Equal(t, AgentAvailable, got.Status) - assert.Equal(t, 5, got.MaxLoad) -} - -func TestRedisRegistry_Register_Good_Overwrite(t *testing.T) { - reg := newTestRedisRegistry(t) - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3}) - err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10}) - require.NoError(t, err) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "Updated", got.Name) - assert.Equal(t, 10, got.MaxLoad) -} - -func TestRedisRegistry_Register_Bad_EmptyID(t *testing.T) { - reg := newTestRedisRegistry(t) - err := reg.Register(AgentInfo{ID: "", Name: "No ID"}) - require.Error(t, err) - assert.Contains(t, err.Error(), "agent ID is required") -} - -// --- Deregister tests --- - -func TestRedisRegistry_Deregister_Good(t *testing.T) { - reg := newTestRedisRegistry(t) - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"}) - - err := reg.Deregister("agent-1") - require.NoError(t, err) - - _, err = reg.Get("agent-1") - require.Error(t, err) -} - -func TestRedisRegistry_Deregister_Bad_NotFound(t *testing.T) { - reg := newTestRedisRegistry(t) - err := reg.Deregister("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -// --- Get tests --- - -func TestRedisRegistry_Get_Good(t *testing.T) { - reg := newTestRedisRegistry(t) - now := time.Now().UTC().Truncate(time.Millisecond) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Name: "Getter", - Status: AgentBusy, - CurrentLoad: 2, - MaxLoad: 5, - LastHeartbeat: now, - }) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, AgentBusy, got.Status) - assert.Equal(t, 2, got.CurrentLoad) - assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond) -} - -func TestRedisRegistry_Get_Bad_NotFound(t *testing.T) { - reg := newTestRedisRegistry(t) - _, err := reg.Get("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -func TestRedisRegistry_Get_Good_ReturnsCopy(t *testing.T) { - reg := newTestRedisRegistry(t) - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1}) - - got, _ := reg.Get("agent-1") - got.CurrentLoad = 99 - got.Name = "Tampered" - - // Re-read — should be unchanged (deserialized from Redis). - again, _ := reg.Get("agent-1") - assert.Equal(t, "Original", again.Name) - assert.Equal(t, 1, again.CurrentLoad) -} - -// --- List tests --- - -func TestRedisRegistry_List_Good_Empty(t *testing.T) { - reg := newTestRedisRegistry(t) - agents := reg.List() - assert.Empty(t, agents) -} - -func TestRedisRegistry_List_Good_Multiple(t *testing.T) { - reg := newTestRedisRegistry(t) - _ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"}) - _ = reg.Register(AgentInfo{ID: "b", Name: "Beta"}) - _ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"}) - - agents := reg.List() - assert.Len(t, agents, 3) - - // Sort by ID for deterministic assertion. - sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID }) - assert.Equal(t, "a", agents[0].ID) - assert.Equal(t, "b", agents[1].ID) - assert.Equal(t, "c", agents[2].ID) -} - -// --- Heartbeat tests --- - -func TestRedisRegistry_Heartbeat_Good(t *testing.T) { - reg := newTestRedisRegistry(t) - past := time.Now().UTC().Add(-5 * time.Minute) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentAvailable, - LastHeartbeat: past, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.True(t, got.LastHeartbeat.After(past)) - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestRedisRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) { - reg := newTestRedisRegistry(t) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentOffline, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestRedisRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) { - reg := newTestRedisRegistry(t) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentBusy, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.Equal(t, AgentBusy, got.Status) -} - -func TestRedisRegistry_Heartbeat_Bad_NotFound(t *testing.T) { - reg := newTestRedisRegistry(t) - err := reg.Heartbeat("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -// --- Reap tests --- - -func TestRedisRegistry_Reap_Good_StaleAgent(t *testing.T) { - reg := newTestRedisRegistry(t) - stale := time.Now().UTC().Add(-10 * time.Minute) - fresh := time.Now().UTC() - - _ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale}) - _ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh}) - - reaped := reg.Reap(5 * time.Minute) - assert.Len(t, reaped, 1) - assert.Contains(t, reaped, "stale-1") - - got, _ := reg.Get("stale-1") - assert.Equal(t, AgentOffline, got.Status) - - got, _ = reg.Get("fresh-1") - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestRedisRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) { - reg := newTestRedisRegistry(t) - stale := time.Now().UTC().Add(-10 * time.Minute) - - _ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale}) - - reaped := reg.Reap(5 * time.Minute) - assert.Empty(t, reaped) -} - -func TestRedisRegistry_Reap_Good_NoStaleAgents(t *testing.T) { - reg := newTestRedisRegistry(t) - now := time.Now().UTC() - - _ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now}) - _ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now}) - - reaped := reg.Reap(5 * time.Minute) - assert.Empty(t, reaped) -} - -func TestRedisRegistry_Reap_Good_BusyAgentReaped(t *testing.T) { - reg := newTestRedisRegistry(t) - stale := time.Now().UTC().Add(-10 * time.Minute) - - _ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale}) - - reaped := reg.Reap(5 * time.Minute) - assert.Len(t, reaped, 1) - assert.Contains(t, reaped, "busy-stale") - - got, _ := reg.Get("busy-stale") - assert.Equal(t, AgentOffline, got.Status) -} - -// --- Concurrent access --- - -func TestRedisRegistry_Concurrent_Good(t *testing.T) { - reg := newTestRedisRegistry(t) - - var wg sync.WaitGroup - for i := range 20 { - wg.Add(1) - go func(n int) { - defer wg.Done() - id := "agent-" + string(rune('a'+n%5)) - _ = reg.Register(AgentInfo{ - ID: id, - Name: "Concurrent", - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - }) - _, _ = reg.Get(id) - _ = reg.Heartbeat(id) - _ = reg.List() - _ = reg.Reap(1 * time.Minute) - }(i) - } - wg.Wait() - - // No race conditions — test passes under -race. - agents := reg.List() - assert.True(t, len(agents) > 0) -} - -// --- Constructor error case --- - -func TestNewRedisRegistry_Bad_Unreachable(t *testing.T) { - _, err := NewRedisRegistry("127.0.0.1:1") // almost certainly unreachable - require.Error(t, err) - apiErr, ok := err.(*APIError) - require.True(t, ok, "expected *APIError") - assert.Equal(t, 500, apiErr.Code) - assert.Contains(t, err.Error(), "failed to connect to Redis") -} - -// --- Config-based factory with redis backend --- - -func TestNewAgentRegistryFromConfig_Good_Redis(t *testing.T) { - cfg := RegistryConfig{ - RegistryBackend: "redis", - RegistryRedisAddr: testRedisAddr, - } - reg, err := NewAgentRegistryFromConfig(cfg) - if err != nil { - t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) - } - rr, ok := reg.(*RedisRegistry) - assert.True(t, ok, "expected RedisRegistry") - _ = rr.Close() -} diff --git a/pkg/lifecycle/registry_sqlite.go b/pkg/lifecycle/registry_sqlite.go deleted file mode 100644 index 2692b8c..0000000 --- a/pkg/lifecycle/registry_sqlite.go +++ /dev/null @@ -1,267 +0,0 @@ -package lifecycle - -import ( - "database/sql" - "encoding/json" - "iter" - "slices" - "strings" - "sync" - "time" - - _ "modernc.org/sqlite" -) - -// SQLiteRegistry implements AgentRegistry using a SQLite database. -// It provides persistent storage that survives process restarts. -type SQLiteRegistry struct { - db *sql.DB - mu sync.Mutex // serialises read-modify-write operations -} - -// NewSQLiteRegistry creates a new SQLite-backed agent registry at the given path. -// Use ":memory:" for tests that do not need persistence. -func NewSQLiteRegistry(dbPath string) (*SQLiteRegistry, error) { - db, err := sql.Open("sqlite", dbPath) - if err != nil { - return nil, &APIError{Code: 500, Message: "failed to open SQLite registry: " + err.Error()} - } - db.SetMaxOpenConns(1) - if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { - db.Close() - return nil, &APIError{Code: 500, Message: "failed to set WAL mode: " + err.Error()} - } - if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil { - db.Close() - return nil, &APIError{Code: 500, Message: "failed to set busy_timeout: " + err.Error()} - } - if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS agents ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL DEFAULT '', - capabilities TEXT NOT NULL DEFAULT '[]', - status TEXT NOT NULL DEFAULT 'available', - last_heartbeat DATETIME NOT NULL DEFAULT (datetime('now')), - current_load INTEGER NOT NULL DEFAULT 0, - max_load INTEGER NOT NULL DEFAULT 0, - registered_at DATETIME NOT NULL DEFAULT (datetime('now')) - )`); err != nil { - db.Close() - return nil, &APIError{Code: 500, Message: "failed to create agents table: " + err.Error()} - } - return &SQLiteRegistry{db: db}, nil -} - -// Close releases the underlying SQLite database. -func (r *SQLiteRegistry) Close() error { - return r.db.Close() -} - -// Register adds or updates an agent in the registry. Returns an error if the -// agent ID is empty. -func (r *SQLiteRegistry) Register(agent AgentInfo) error { - if agent.ID == "" { - return &APIError{Code: 400, Message: "agent ID is required"} - } - caps, err := json.Marshal(agent.Capabilities) - if err != nil { - return &APIError{Code: 500, Message: "failed to marshal capabilities: " + err.Error()} - } - hb := agent.LastHeartbeat - if hb.IsZero() { - hb = time.Now().UTC() - } - r.mu.Lock() - defer r.mu.Unlock() - _, err = r.db.Exec(`INSERT INTO agents (id, name, capabilities, status, last_heartbeat, current_load, max_load, registered_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(id) DO UPDATE SET - name = excluded.name, - capabilities = excluded.capabilities, - status = excluded.status, - last_heartbeat = excluded.last_heartbeat, - current_load = excluded.current_load, - max_load = excluded.max_load`, - agent.ID, agent.Name, string(caps), string(agent.Status), hb.Format(time.RFC3339Nano), - agent.CurrentLoad, agent.MaxLoad, hb.Format(time.RFC3339Nano)) - if err != nil { - return &APIError{Code: 500, Message: "failed to register agent: " + err.Error()} - } - return nil -} - -// Deregister removes an agent from the registry. Returns an error if the agent -// is not found. -func (r *SQLiteRegistry) Deregister(id string) error { - r.mu.Lock() - defer r.mu.Unlock() - res, err := r.db.Exec("DELETE FROM agents WHERE id = ?", id) - if err != nil { - return &APIError{Code: 500, Message: "failed to deregister agent: " + err.Error()} - } - n, err := res.RowsAffected() - if err != nil { - return &APIError{Code: 500, Message: "failed to check delete result: " + err.Error()} - } - if n == 0 { - return &APIError{Code: 404, Message: "agent not found: " + id} - } - return nil -} - -// Get returns a copy of the agent info for the given ID. Returns an error if -// the agent is not found. -func (r *SQLiteRegistry) Get(id string) (AgentInfo, error) { - return r.scanAgent("SELECT id, name, capabilities, status, last_heartbeat, current_load, max_load FROM agents WHERE id = ?", id) -} - -// List returns a copy of all registered agents. -func (r *SQLiteRegistry) List() []AgentInfo { - return slices.Collect(r.All()) -} - -// All returns an iterator over all registered agents. -func (r *SQLiteRegistry) All() iter.Seq[AgentInfo] { - return func(yield func(AgentInfo) bool) { - rows, err := r.db.Query("SELECT id, name, capabilities, status, last_heartbeat, current_load, max_load FROM agents") - if err != nil { - return - } - defer rows.Close() - - for rows.Next() { - a, err := r.scanAgentRow(rows) - if err != nil { - continue - } - if !yield(a) { - return - } - } - } -} - -// Heartbeat updates the agent's LastHeartbeat timestamp. If the agent was -// Offline, it transitions to Available. -func (r *SQLiteRegistry) Heartbeat(id string) error { - r.mu.Lock() - defer r.mu.Unlock() - - now := time.Now().UTC().Format(time.RFC3339Nano) - - // Update heartbeat for all agents, and transition offline agents to available. - res, err := r.db.Exec(`UPDATE agents SET - last_heartbeat = ?, - status = CASE WHEN status = ? THEN ? ELSE status END - WHERE id = ?`, - now, string(AgentOffline), string(AgentAvailable), id) - if err != nil { - return &APIError{Code: 500, Message: "failed to heartbeat agent: " + err.Error()} - } - n, err := res.RowsAffected() - if err != nil { - return &APIError{Code: 500, Message: "failed to check heartbeat result: " + err.Error()} - } - if n == 0 { - return &APIError{Code: 404, Message: "agent not found: " + id} - } - return nil -} - -// Reap marks agents as Offline if their last heartbeat is older than ttl. -// Returns the IDs of agents that were reaped. -func (r *SQLiteRegistry) Reap(ttl time.Duration) []string { - return slices.Collect(r.Reaped(ttl)) -} - -// Reaped returns an iterator over the IDs of agents that were reaped. -func (r *SQLiteRegistry) Reaped(ttl time.Duration) iter.Seq[string] { - return func(yield func(string) bool) { - r.mu.Lock() - defer r.mu.Unlock() - - cutoff := time.Now().UTC().Add(-ttl).Format(time.RFC3339Nano) - - // Select agents that will be reaped before updating. - rows, err := r.db.Query( - "SELECT id FROM agents WHERE status != ? AND last_heartbeat < ?", - string(AgentOffline), cutoff) - if err != nil { - return - } - defer rows.Close() - - var reaped []string - for rows.Next() { - var id string - if err := rows.Scan(&id); err != nil { - continue - } - reaped = append(reaped, id) - } - if err := rows.Err(); err != nil { - return - } - rows.Close() - - if len(reaped) > 0 { - // Build placeholders for IN clause. - placeholders := make([]string, len(reaped)) - args := make([]any, len(reaped)) - for i, id := range reaped { - placeholders[i] = "?" - args[i] = id - } - query := "UPDATE agents SET status = ? WHERE id IN (" + strings.Join(placeholders, ",") + ")" - allArgs := append([]any{string(AgentOffline)}, args...) - _, _ = r.db.Exec(query, allArgs...) - - for _, id := range reaped { - if !yield(id) { - return - } - } - } - } -} - -// --- internal helpers --- - -// scanAgent executes a query that returns a single agent row. -func (r *SQLiteRegistry) scanAgent(query string, args ...any) (AgentInfo, error) { - row := r.db.QueryRow(query, args...) - var a AgentInfo - var capsJSON string - var statusStr string - var hbStr string - err := row.Scan(&a.ID, &a.Name, &capsJSON, &statusStr, &hbStr, &a.CurrentLoad, &a.MaxLoad) - if err == sql.ErrNoRows { - return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + args[0].(string)} - } - if err != nil { - return AgentInfo{}, &APIError{Code: 500, Message: "failed to scan agent: " + err.Error()} - } - if err := json.Unmarshal([]byte(capsJSON), &a.Capabilities); err != nil { - return AgentInfo{}, &APIError{Code: 500, Message: "failed to unmarshal capabilities: " + err.Error()} - } - a.Status = AgentStatus(statusStr) - a.LastHeartbeat, _ = time.Parse(time.RFC3339Nano, hbStr) - return a, nil -} - -// scanAgentRow scans a single row from a rows iterator. -func (r *SQLiteRegistry) scanAgentRow(rows *sql.Rows) (AgentInfo, error) { - var a AgentInfo - var capsJSON string - var statusStr string - var hbStr string - err := rows.Scan(&a.ID, &a.Name, &capsJSON, &statusStr, &hbStr, &a.CurrentLoad, &a.MaxLoad) - if err != nil { - return AgentInfo{}, err - } - if err := json.Unmarshal([]byte(capsJSON), &a.Capabilities); err != nil { - return AgentInfo{}, err - } - a.Status = AgentStatus(statusStr) - a.LastHeartbeat, _ = time.Parse(time.RFC3339Nano, hbStr) - return a, nil -} diff --git a/pkg/lifecycle/registry_sqlite_test.go b/pkg/lifecycle/registry_sqlite_test.go deleted file mode 100644 index 2b2f594..0000000 --- a/pkg/lifecycle/registry_sqlite_test.go +++ /dev/null @@ -1,386 +0,0 @@ -package lifecycle - -import ( - "path/filepath" - "sort" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// newTestSQLiteRegistry creates a SQLiteRegistry backed by :memory: for testing. -func newTestSQLiteRegistry(t *testing.T) *SQLiteRegistry { - t.Helper() - reg, err := NewSQLiteRegistry(":memory:") - require.NoError(t, err) - t.Cleanup(func() { _ = reg.Close() }) - return reg -} - -// --- Register tests --- - -func TestSQLiteRegistry_Register_Good(t *testing.T) { - reg := newTestSQLiteRegistry(t) - err := reg.Register(AgentInfo{ - ID: "agent-1", - Name: "Test Agent", - Capabilities: []string{"go", "testing"}, - Status: AgentAvailable, - MaxLoad: 5, - }) - require.NoError(t, err) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "agent-1", got.ID) - assert.Equal(t, "Test Agent", got.Name) - assert.Equal(t, []string{"go", "testing"}, got.Capabilities) - assert.Equal(t, AgentAvailable, got.Status) - assert.Equal(t, 5, got.MaxLoad) -} - -func TestSQLiteRegistry_Register_Good_Overwrite(t *testing.T) { - reg := newTestSQLiteRegistry(t) - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3}) - err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10}) - require.NoError(t, err) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "Updated", got.Name) - assert.Equal(t, 10, got.MaxLoad) -} - -func TestSQLiteRegistry_Register_Bad_EmptyID(t *testing.T) { - reg := newTestSQLiteRegistry(t) - err := reg.Register(AgentInfo{ID: "", Name: "No ID"}) - require.Error(t, err) - assert.Contains(t, err.Error(), "agent ID is required") -} - -func TestSQLiteRegistry_Register_Good_NilCapabilities(t *testing.T) { - reg := newTestSQLiteRegistry(t) - err := reg.Register(AgentInfo{ - ID: "agent-1", - Name: "No Caps", - Capabilities: nil, - Status: AgentAvailable, - }) - require.NoError(t, err) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "No Caps", got.Name) - // nil capabilities serialised as JSON null, deserialised back to nil. -} - -// --- Deregister tests --- - -func TestSQLiteRegistry_Deregister_Good(t *testing.T) { - reg := newTestSQLiteRegistry(t) - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"}) - - err := reg.Deregister("agent-1") - require.NoError(t, err) - - _, err = reg.Get("agent-1") - require.Error(t, err) -} - -func TestSQLiteRegistry_Deregister_Bad_NotFound(t *testing.T) { - reg := newTestSQLiteRegistry(t) - err := reg.Deregister("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -// --- Get tests --- - -func TestSQLiteRegistry_Get_Good(t *testing.T) { - reg := newTestSQLiteRegistry(t) - now := time.Now().UTC().Truncate(time.Microsecond) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Name: "Getter", - Status: AgentBusy, - CurrentLoad: 2, - MaxLoad: 5, - LastHeartbeat: now, - }) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, AgentBusy, got.Status) - assert.Equal(t, 2, got.CurrentLoad) - // Heartbeat stored via RFC3339Nano — allow small time difference from serialisation. - assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond) -} - -func TestSQLiteRegistry_Get_Bad_NotFound(t *testing.T) { - reg := newTestSQLiteRegistry(t) - _, err := reg.Get("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -func TestSQLiteRegistry_Get_Good_ReturnsCopy(t *testing.T) { - reg := newTestSQLiteRegistry(t) - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1}) - - got, _ := reg.Get("agent-1") - got.CurrentLoad = 99 - got.Name = "Tampered" - - // Re-read — should be unchanged. - again, _ := reg.Get("agent-1") - assert.Equal(t, "Original", again.Name) - assert.Equal(t, 1, again.CurrentLoad) -} - -// --- List tests --- - -func TestSQLiteRegistry_List_Good_Empty(t *testing.T) { - reg := newTestSQLiteRegistry(t) - agents := reg.List() - assert.Empty(t, agents) -} - -func TestSQLiteRegistry_List_Good_Multiple(t *testing.T) { - reg := newTestSQLiteRegistry(t) - _ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"}) - _ = reg.Register(AgentInfo{ID: "b", Name: "Beta"}) - _ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"}) - - agents := reg.List() - assert.Len(t, agents, 3) - - // Sort by ID for deterministic assertion. - sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID }) - assert.Equal(t, "a", agents[0].ID) - assert.Equal(t, "b", agents[1].ID) - assert.Equal(t, "c", agents[2].ID) -} - -// --- Heartbeat tests --- - -func TestSQLiteRegistry_Heartbeat_Good(t *testing.T) { - reg := newTestSQLiteRegistry(t) - past := time.Now().UTC().Add(-5 * time.Minute) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentAvailable, - LastHeartbeat: past, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.True(t, got.LastHeartbeat.After(past)) - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestSQLiteRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) { - reg := newTestSQLiteRegistry(t) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentOffline, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestSQLiteRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) { - reg := newTestSQLiteRegistry(t) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentBusy, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.Equal(t, AgentBusy, got.Status) -} - -func TestSQLiteRegistry_Heartbeat_Bad_NotFound(t *testing.T) { - reg := newTestSQLiteRegistry(t) - err := reg.Heartbeat("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -// --- Reap tests --- - -func TestSQLiteRegistry_Reap_Good_StaleAgent(t *testing.T) { - reg := newTestSQLiteRegistry(t) - stale := time.Now().UTC().Add(-10 * time.Minute) - fresh := time.Now().UTC() - - _ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale}) - _ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh}) - - reaped := reg.Reap(5 * time.Minute) - assert.Len(t, reaped, 1) - assert.Contains(t, reaped, "stale-1") - - got, _ := reg.Get("stale-1") - assert.Equal(t, AgentOffline, got.Status) - - got, _ = reg.Get("fresh-1") - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestSQLiteRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) { - reg := newTestSQLiteRegistry(t) - stale := time.Now().UTC().Add(-10 * time.Minute) - - _ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale}) - - reaped := reg.Reap(5 * time.Minute) - assert.Empty(t, reaped) -} - -func TestSQLiteRegistry_Reap_Good_NoStaleAgents(t *testing.T) { - reg := newTestSQLiteRegistry(t) - now := time.Now().UTC() - - _ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now}) - _ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now}) - - reaped := reg.Reap(5 * time.Minute) - assert.Empty(t, reaped) -} - -func TestSQLiteRegistry_Reap_Good_BusyAgentReaped(t *testing.T) { - reg := newTestSQLiteRegistry(t) - stale := time.Now().UTC().Add(-10 * time.Minute) - - _ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale}) - - reaped := reg.Reap(5 * time.Minute) - assert.Len(t, reaped, 1) - assert.Contains(t, reaped, "busy-stale") - - got, _ := reg.Get("busy-stale") - assert.Equal(t, AgentOffline, got.Status) -} - -// --- Concurrent access --- - -func TestSQLiteRegistry_Concurrent_Good(t *testing.T) { - reg := newTestSQLiteRegistry(t) - - var wg sync.WaitGroup - for i := range 20 { - wg.Add(1) - go func(n int) { - defer wg.Done() - id := "agent-" + string(rune('a'+n%5)) - _ = reg.Register(AgentInfo{ - ID: id, - Name: "Concurrent", - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - }) - _, _ = reg.Get(id) - _ = reg.Heartbeat(id) - _ = reg.List() - _ = reg.Reap(1 * time.Minute) - }(i) - } - wg.Wait() - - // No race conditions — test passes under -race. - agents := reg.List() - assert.True(t, len(agents) > 0) -} - -// --- Persistence: close and reopen --- - -func TestSQLiteRegistry_Persistence_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "registry.db") - - // Phase 1: write data - r1, err := NewSQLiteRegistry(dbPath) - require.NoError(t, err) - - now := time.Now().UTC().Truncate(time.Microsecond) - _ = r1.Register(AgentInfo{ - ID: "agent-1", - Name: "Persistent", - Capabilities: []string{"go", "rust"}, - Status: AgentBusy, - LastHeartbeat: now, - CurrentLoad: 3, - MaxLoad: 10, - }) - require.NoError(t, r1.Close()) - - // Phase 2: reopen and verify - r2, err := NewSQLiteRegistry(dbPath) - require.NoError(t, err) - defer func() { _ = r2.Close() }() - - got, err := r2.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "Persistent", got.Name) - assert.Equal(t, []string{"go", "rust"}, got.Capabilities) - assert.Equal(t, AgentBusy, got.Status) - assert.Equal(t, 3, got.CurrentLoad) - assert.Equal(t, 10, got.MaxLoad) - assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond) -} - -// --- Constructor error case --- - -func TestNewSQLiteRegistry_Bad_InvalidPath(t *testing.T) { - _, err := NewSQLiteRegistry("/nonexistent/deeply/nested/dir/registry.db") - require.Error(t, err) -} - -// --- Config-based factory --- - -func TestNewAgentRegistryFromConfig_Good_Memory(t *testing.T) { - cfg := RegistryConfig{RegistryBackend: "memory"} - reg, err := NewAgentRegistryFromConfig(cfg) - require.NoError(t, err) - _, ok := reg.(*MemoryRegistry) - assert.True(t, ok, "expected MemoryRegistry") -} - -func TestNewAgentRegistryFromConfig_Good_Default(t *testing.T) { - cfg := RegistryConfig{} // empty defaults to memory - reg, err := NewAgentRegistryFromConfig(cfg) - require.NoError(t, err) - _, ok := reg.(*MemoryRegistry) - assert.True(t, ok, "expected MemoryRegistry for empty config") -} - -func TestNewAgentRegistryFromConfig_Good_SQLite(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "factory-registry.db") - cfg := RegistryConfig{ - RegistryBackend: "sqlite", - RegistryPath: dbPath, - } - reg, err := NewAgentRegistryFromConfig(cfg) - require.NoError(t, err) - sr, ok := reg.(*SQLiteRegistry) - assert.True(t, ok, "expected SQLiteRegistry") - _ = sr.Close() -} - -func TestNewAgentRegistryFromConfig_Bad_UnknownBackend(t *testing.T) { - cfg := RegistryConfig{RegistryBackend: "cassandra"} - _, err := NewAgentRegistryFromConfig(cfg) - require.Error(t, err) - assert.Contains(t, err.Error(), "unsupported registry backend") -} diff --git a/pkg/lifecycle/registry_test.go b/pkg/lifecycle/registry_test.go deleted file mode 100644 index 5318520..0000000 --- a/pkg/lifecycle/registry_test.go +++ /dev/null @@ -1,298 +0,0 @@ -package lifecycle - -import ( - "sort" - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- Register tests --- - -func TestMemoryRegistry_Register_Good(t *testing.T) { - reg := NewMemoryRegistry() - err := reg.Register(AgentInfo{ - ID: "agent-1", - Name: "Test Agent", - Capabilities: []string{"go", "testing"}, - Status: AgentAvailable, - MaxLoad: 5, - }) - require.NoError(t, err) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "agent-1", got.ID) - assert.Equal(t, "Test Agent", got.Name) - assert.Equal(t, []string{"go", "testing"}, got.Capabilities) - assert.Equal(t, AgentAvailable, got.Status) - assert.Equal(t, 5, got.MaxLoad) -} - -func TestMemoryRegistry_Register_Good_Overwrite(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3}) - err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10}) - require.NoError(t, err) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, "Updated", got.Name) - assert.Equal(t, 10, got.MaxLoad) -} - -func TestMemoryRegistry_Register_Bad_EmptyID(t *testing.T) { - reg := NewMemoryRegistry() - err := reg.Register(AgentInfo{ID: "", Name: "No ID"}) - require.Error(t, err) - assert.Contains(t, err.Error(), "agent ID is required") -} - -func TestMemoryRegistry_Register_Good_CopySemantics(t *testing.T) { - reg := NewMemoryRegistry() - agent := AgentInfo{ - ID: "agent-1", - Name: "Copy Test", - Capabilities: []string{"go"}, - Status: AgentAvailable, - } - _ = reg.Register(agent) - - // Mutate the original — should not affect the stored copy. - agent.Name = "Mutated" - agent.Capabilities[0] = "rust" - - got, _ := reg.Get("agent-1") - assert.Equal(t, "Copy Test", got.Name) - // Note: slice header is copied, but underlying array is shared. - // This is consistent with the MemoryStore pattern in allowance.go. -} - -// --- Deregister tests --- - -func TestMemoryRegistry_Deregister_Good(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"}) - - err := reg.Deregister("agent-1") - require.NoError(t, err) - - _, err = reg.Get("agent-1") - require.Error(t, err) -} - -func TestMemoryRegistry_Deregister_Bad_NotFound(t *testing.T) { - reg := NewMemoryRegistry() - err := reg.Deregister("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -// --- Get tests --- - -func TestMemoryRegistry_Get_Good(t *testing.T) { - reg := NewMemoryRegistry() - now := time.Now().UTC() - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Name: "Getter", - Status: AgentBusy, - CurrentLoad: 2, - MaxLoad: 5, - LastHeartbeat: now, - }) - - got, err := reg.Get("agent-1") - require.NoError(t, err) - assert.Equal(t, AgentBusy, got.Status) - assert.Equal(t, 2, got.CurrentLoad) - assert.Equal(t, now, got.LastHeartbeat) -} - -func TestMemoryRegistry_Get_Bad_NotFound(t *testing.T) { - reg := NewMemoryRegistry() - _, err := reg.Get("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -func TestMemoryRegistry_Get_Good_ReturnsCopy(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1}) - - got, _ := reg.Get("agent-1") - got.CurrentLoad = 99 - got.Name = "Tampered" - - // Re-read — should be unchanged. - again, _ := reg.Get("agent-1") - assert.Equal(t, "Original", again.Name) - assert.Equal(t, 1, again.CurrentLoad) -} - -// --- List tests --- - -func TestMemoryRegistry_List_Good_Empty(t *testing.T) { - reg := NewMemoryRegistry() - agents := reg.List() - assert.Empty(t, agents) -} - -func TestMemoryRegistry_List_Good_Multiple(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"}) - _ = reg.Register(AgentInfo{ID: "b", Name: "Beta"}) - _ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"}) - - agents := reg.List() - assert.Len(t, agents, 3) - - // Sort by ID for deterministic assertion. - sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID }) - assert.Equal(t, "a", agents[0].ID) - assert.Equal(t, "b", agents[1].ID) - assert.Equal(t, "c", agents[2].ID) -} - -// --- Heartbeat tests --- - -func TestMemoryRegistry_Heartbeat_Good(t *testing.T) { - reg := NewMemoryRegistry() - past := time.Now().UTC().Add(-5 * time.Minute) - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentAvailable, - LastHeartbeat: past, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.True(t, got.LastHeartbeat.After(past)) - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestMemoryRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentOffline, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestMemoryRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ - ID: "agent-1", - Status: AgentBusy, - }) - - err := reg.Heartbeat("agent-1") - require.NoError(t, err) - - got, _ := reg.Get("agent-1") - assert.Equal(t, AgentBusy, got.Status) -} - -func TestMemoryRegistry_Heartbeat_Bad_NotFound(t *testing.T) { - reg := NewMemoryRegistry() - err := reg.Heartbeat("nonexistent") - require.Error(t, err) - assert.Contains(t, err.Error(), "agent not found") -} - -// --- Reap tests --- - -func TestMemoryRegistry_Reap_Good_StaleAgent(t *testing.T) { - reg := NewMemoryRegistry() - stale := time.Now().UTC().Add(-10 * time.Minute) - fresh := time.Now().UTC() - - _ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale}) - _ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh}) - - reaped := reg.Reap(5 * time.Minute) - assert.Len(t, reaped, 1) - assert.Contains(t, reaped, "stale-1") - - got, _ := reg.Get("stale-1") - assert.Equal(t, AgentOffline, got.Status) - - got, _ = reg.Get("fresh-1") - assert.Equal(t, AgentAvailable, got.Status) -} - -func TestMemoryRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) { - reg := NewMemoryRegistry() - stale := time.Now().UTC().Add(-10 * time.Minute) - - _ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale}) - - reaped := reg.Reap(5 * time.Minute) - assert.Empty(t, reaped) -} - -func TestMemoryRegistry_Reap_Good_NoStaleAgents(t *testing.T) { - reg := NewMemoryRegistry() - now := time.Now().UTC() - - _ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now}) - _ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now}) - - reaped := reg.Reap(5 * time.Minute) - assert.Empty(t, reaped) -} - -func TestMemoryRegistry_Reap_Good_BusyAgentReaped(t *testing.T) { - reg := NewMemoryRegistry() - stale := time.Now().UTC().Add(-10 * time.Minute) - - _ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale}) - - reaped := reg.Reap(5 * time.Minute) - assert.Len(t, reaped, 1) - assert.Contains(t, reaped, "busy-stale") - - got, _ := reg.Get("busy-stale") - assert.Equal(t, AgentOffline, got.Status) -} - -// --- Concurrent access --- - -func TestMemoryRegistry_Concurrent_Good(t *testing.T) { - reg := NewMemoryRegistry() - - var wg sync.WaitGroup - for i := range 20 { - wg.Add(1) - go func(n int) { - defer wg.Done() - id := "agent-" + string(rune('a'+n%5)) - _ = reg.Register(AgentInfo{ - ID: id, - Name: "Concurrent", - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - }) - _, _ = reg.Get(id) - _ = reg.Heartbeat(id) - _ = reg.List() - _ = reg.Reap(1 * time.Minute) - }(i) - } - wg.Wait() - - // No race conditions — test passes under -race. - agents := reg.List() - assert.True(t, len(agents) > 0) -} diff --git a/pkg/lifecycle/router.go b/pkg/lifecycle/router.go deleted file mode 100644 index b7bc86c..0000000 --- a/pkg/lifecycle/router.go +++ /dev/null @@ -1,131 +0,0 @@ -package lifecycle - -import ( - "cmp" - "slices" - - coreerr "forge.lthn.ai/core/go-log" -) - -// ErrNoEligibleAgent is returned when no agent matches the task requirements. -var ErrNoEligibleAgent = coreerr.E("TaskRouter", "no eligible agent for task", nil) - -// TaskRouter selects an agent for a given task from a list of candidates. -type TaskRouter interface { - // Route picks the best agent for the task and returns its ID. - // Returns ErrNoEligibleAgent if no agent qualifies. - Route(task *Task, agents []AgentInfo) (string, error) -} - -// DefaultRouter implements TaskRouter with capability matching and load-based -// scoring. For critical priority tasks it picks the least-loaded agent directly. -type DefaultRouter struct{} - -// NewDefaultRouter creates a new DefaultRouter. -func NewDefaultRouter() *DefaultRouter { - return &DefaultRouter{} -} - -// Route selects the best agent for the task: -// 1. Filter by availability (Available, or Busy with capacity). -// 2. Filter by capabilities (task.Labels must be a subset of agent.Capabilities). -// 3. For critical tasks, pick the least-loaded agent. -// 4. For other tasks, score by load ratio and pick the highest-scored agent. -// 5. Ties are broken by agent ID (alphabetical) for determinism. -func (r *DefaultRouter) Route(task *Task, agents []AgentInfo) (string, error) { - eligible := r.filterEligible(task, agents) - if len(eligible) == 0 { - return "", ErrNoEligibleAgent - } - - if task.Priority == PriorityCritical { - return r.leastLoaded(eligible), nil - } - - return r.highestScored(eligible), nil -} - -// filterEligible returns agents that are available (or busy with room) and -// possess all required capabilities. -func (r *DefaultRouter) filterEligible(task *Task, agents []AgentInfo) []AgentInfo { - var result []AgentInfo - for _, a := range agents { - if !r.isAvailable(a) { - continue - } - if !r.hasCapabilities(a, task.Labels) { - continue - } - result = append(result, a) - } - return result -} - -// isAvailable returns true if the agent can accept work. -func (r *DefaultRouter) isAvailable(a AgentInfo) bool { - switch a.Status { - case AgentAvailable: - return true - case AgentBusy: - // Busy agents can still accept work if they have capacity. - return a.MaxLoad == 0 || a.CurrentLoad < a.MaxLoad - default: - return false - } -} - -// hasCapabilities checks that the agent has all required labels. If the task -// has no labels, any agent qualifies. -func (r *DefaultRouter) hasCapabilities(a AgentInfo, labels []string) bool { - if len(labels) == 0 { - return true - } - for _, label := range labels { - if !slices.Contains(a.Capabilities, label) { - return false - } - } - return true -} - -// leastLoaded picks the agent with the lowest CurrentLoad. Ties are broken by -// agent ID (alphabetical). -func (r *DefaultRouter) leastLoaded(agents []AgentInfo) string { - // Sort: lowest load first, then by ID for determinism. - slices.SortFunc(agents, func(a, b AgentInfo) int { - if a.CurrentLoad != b.CurrentLoad { - return cmp.Compare(a.CurrentLoad, b.CurrentLoad) - } - return cmp.Compare(a.ID, b.ID) - }) - return agents[0].ID -} - -// highestScored picks the agent with the highest availability score. -// Score = 1.0 - (CurrentLoad / MaxLoad). If MaxLoad is 0, score is 1.0. -// Ties are broken by agent ID (alphabetical). -func (r *DefaultRouter) highestScored(agents []AgentInfo) string { - type scored struct { - id string - score float64 - } - - scores := make([]scored, len(agents)) - for i, a := range agents { - s := 1.0 - if a.MaxLoad > 0 { - s = 1.0 - float64(a.CurrentLoad)/float64(a.MaxLoad) - } - scores[i] = scored{id: a.ID, score: s} - } - - // Sort: highest score first, then by ID for determinism. - slices.SortFunc(scores, func(a, b scored) int { - if a.score != b.score { - return cmp.Compare(b.score, a.score) // highest first - } - return cmp.Compare(a.id, b.id) - }) - - return scores[0].id -} diff --git a/pkg/lifecycle/router_test.go b/pkg/lifecycle/router_test.go deleted file mode 100644 index f4e07f7..0000000 --- a/pkg/lifecycle/router_test.go +++ /dev/null @@ -1,239 +0,0 @@ -package lifecycle - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func makeAgent(id string, status AgentStatus, caps []string, load, maxLoad int) AgentInfo { - return AgentInfo{ - ID: id, - Name: id, - Capabilities: caps, - Status: status, - LastHeartbeat: time.Now().UTC(), - CurrentLoad: load, - MaxLoad: maxLoad, - } -} - -// --- Capability matching --- - -func TestDefaultRouter_Route_Good_MatchesCapabilities(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Labels: []string{"go", "testing"}} - agents := []AgentInfo{ - makeAgent("agent-a", AgentAvailable, []string{"go", "testing", "frontend"}, 0, 5), - makeAgent("agent-b", AgentAvailable, []string{"python"}, 0, 5), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "agent-a", id) -} - -func TestDefaultRouter_Route_Good_NoLabelsMatchesAll(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Labels: nil} - agents := []AgentInfo{ - makeAgent("agent-a", AgentAvailable, []string{"go"}, 0, 5), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "agent-a", id) -} - -func TestDefaultRouter_Route_Good_EmptyLabelsMatchesAll(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Labels: []string{}} - agents := []AgentInfo{ - makeAgent("agent-a", AgentAvailable, nil, 0, 5), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "agent-a", id) -} - -// --- Availability filtering --- - -func TestDefaultRouter_Route_Good_SkipsOfflineAgents(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1"} - agents := []AgentInfo{ - makeAgent("offline-1", AgentOffline, nil, 0, 5), - makeAgent("online-1", AgentAvailable, nil, 0, 5), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "online-1", id) -} - -func TestDefaultRouter_Route_Good_BusyWithCapacity(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1"} - agents := []AgentInfo{ - makeAgent("busy-1", AgentBusy, nil, 2, 5), // has capacity - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "busy-1", id) -} - -func TestDefaultRouter_Route_Good_BusyUnlimited(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1"} - agents := []AgentInfo{ - makeAgent("busy-unlimited", AgentBusy, nil, 10, 0), // MaxLoad 0 = unlimited - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "busy-unlimited", id) -} - -func TestDefaultRouter_Route_Bad_BusyAtCapacity(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1"} - agents := []AgentInfo{ - makeAgent("full-1", AgentBusy, nil, 5, 5), // at capacity - } - - _, err := router.Route(task, agents) - require.ErrorIs(t, err, ErrNoEligibleAgent) -} - -func TestDefaultRouter_Route_Bad_NoAgents(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1"} - - _, err := router.Route(task, nil) - require.ErrorIs(t, err, ErrNoEligibleAgent) -} - -func TestDefaultRouter_Route_Bad_NoCapableAgent(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Labels: []string{"rust"}} - agents := []AgentInfo{ - makeAgent("go-agent", AgentAvailable, []string{"go"}, 0, 5), - makeAgent("py-agent", AgentAvailable, []string{"python"}, 0, 5), - } - - _, err := router.Route(task, agents) - require.ErrorIs(t, err, ErrNoEligibleAgent) -} - -func TestDefaultRouter_Route_Bad_AllOffline(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1"} - agents := []AgentInfo{ - makeAgent("off-1", AgentOffline, nil, 0, 5), - makeAgent("off-2", AgentOffline, nil, 0, 5), - } - - _, err := router.Route(task, agents) - require.ErrorIs(t, err, ErrNoEligibleAgent) -} - -// --- Load balancing --- - -func TestDefaultRouter_Route_Good_LeastLoaded(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Priority: PriorityMedium} - agents := []AgentInfo{ - makeAgent("agent-a", AgentAvailable, nil, 3, 10), - makeAgent("agent-b", AgentAvailable, nil, 1, 10), - makeAgent("agent-c", AgentAvailable, nil, 5, 10), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - // agent-b has score 0.9, agent-a has 0.7, agent-c has 0.5 - assert.Equal(t, "agent-b", id) -} - -func TestDefaultRouter_Route_Good_UnlimitedGetsMaxScore(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Priority: PriorityLow} - agents := []AgentInfo{ - makeAgent("limited", AgentAvailable, nil, 1, 10), // score 0.9 - makeAgent("unlimited", AgentAvailable, nil, 5, 0), // score 1.0 - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "unlimited", id) -} - -// --- Critical priority --- - -func TestDefaultRouter_Route_Good_CriticalPicksLeastLoaded(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Priority: PriorityCritical} - agents := []AgentInfo{ - makeAgent("agent-a", AgentAvailable, nil, 4, 10), - makeAgent("agent-b", AgentAvailable, nil, 1, 5), // lowest absolute load - makeAgent("agent-c", AgentAvailable, nil, 2, 10), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - // Critical: picks least loaded by CurrentLoad, not by ratio. - assert.Equal(t, "agent-b", id) -} - -// --- Tie-breaking --- - -func TestDefaultRouter_Route_Good_TieBreakByID(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Priority: PriorityMedium} - agents := []AgentInfo{ - makeAgent("charlie", AgentAvailable, nil, 0, 5), - makeAgent("alpha", AgentAvailable, nil, 0, 5), - makeAgent("bravo", AgentAvailable, nil, 0, 5), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "alpha", id) -} - -func TestDefaultRouter_Route_Good_CriticalTieBreakByID(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Priority: PriorityCritical} - agents := []AgentInfo{ - makeAgent("charlie", AgentAvailable, nil, 0, 5), - makeAgent("alpha", AgentAvailable, nil, 0, 5), - makeAgent("bravo", AgentAvailable, nil, 0, 5), - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - assert.Equal(t, "alpha", id) -} - -// --- Mixed scenarios --- - -func TestDefaultRouter_Route_Good_MixedStatusAndCapabilities(t *testing.T) { - router := NewDefaultRouter() - task := &Task{ID: "t1", Labels: []string{"go"}, Priority: PriorityHigh} - agents := []AgentInfo{ - makeAgent("offline-go", AgentOffline, []string{"go"}, 0, 5), - makeAgent("busy-py", AgentBusy, []string{"python"}, 1, 5), - makeAgent("busy-go-full", AgentBusy, []string{"go"}, 5, 5), // at capacity - makeAgent("busy-go-room", AgentBusy, []string{"go"}, 2, 5), // has room - makeAgent("avail-go", AgentAvailable, []string{"go"}, 1, 5), // available - } - - id, err := router.Route(task, agents) - require.NoError(t, err) - // avail-go: score = 1.0 - 1/5 = 0.8 - // busy-go-room: score = 1.0 - 2/5 = 0.6 - assert.Equal(t, "avail-go", id) -} diff --git a/pkg/lifecycle/score.go b/pkg/lifecycle/score.go deleted file mode 100644 index 7a09673..0000000 --- a/pkg/lifecycle/score.go +++ /dev/null @@ -1,147 +0,0 @@ -package lifecycle - -import ( - "bytes" - "context" - "encoding/json" - "net/http" - - "forge.lthn.ai/core/go-log" -) - -// ScoreContentRequest is the payload for content scoring. -type ScoreContentRequest struct { - Text string `json:"text"` - Prompt string `json:"prompt,omitempty"` -} - -// ScoreImprintRequest is the payload for linguistic imprint analysis. -type ScoreImprintRequest struct { - Text string `json:"text"` -} - -// ScoreResult holds the response from the scoring engine. -// The shape is proxied from the EaaS Go binary, so fields are dynamic. -type ScoreResult map[string]any - -// ScoreHealthResponse holds the health check result. -type ScoreHealthResponse struct { - Status string `json:"status"` - UpstreamStatus int `json:"upstream_status,omitempty"` -} - -// ScoreContent scores text for AI patterns via POST /v1/score/content. -func (c *Client) ScoreContent(ctx context.Context, req ScoreContentRequest) (ScoreResult, error) { - const op = "agentic.Client.ScoreContent" - - if req.Text == "" { - return nil, log.E(op, "text is required", nil) - } - - data, err := json.Marshal(req) - if err != nil { - return nil, log.E(op, "failed to marshal request", err) - } - - endpoint := c.BaseURL + "/v1/score/content" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(httpReq) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result ScoreResult - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return result, nil -} - -// ScoreImprint performs linguistic imprint analysis via POST /v1/score/imprint. -func (c *Client) ScoreImprint(ctx context.Context, req ScoreImprintRequest) (ScoreResult, error) { - const op = "agentic.Client.ScoreImprint" - - if req.Text == "" { - return nil, log.E(op, "text is required", nil) - } - - data, err := json.Marshal(req) - if err != nil { - return nil, log.E(op, "failed to marshal request", err) - } - - endpoint := c.BaseURL + "/v1/score/imprint" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - c.setHeaders(httpReq) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result ScoreResult - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return result, nil -} - -// ScoreHealth checks the scoring engine health via GET /v1/score/health. -// This endpoint does not require authentication. -func (c *Client) ScoreHealth(ctx context.Context) (*ScoreHealthResponse, error) { - const op = "agentic.Client.ScoreHealth" - - endpoint := c.BaseURL + "/v1/score/health" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - - // Health endpoint is unauthenticated but we still set headers for consistency. - httpReq.Header.Set("Accept", "application/json") - httpReq.Header.Set("User-Agent", "core-agentic-client/1.0") - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result ScoreHealthResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &result, nil -} diff --git a/pkg/lifecycle/score_test.go b/pkg/lifecycle/score_test.go deleted file mode 100644 index c9305ec..0000000 --- a/pkg/lifecycle/score_test.go +++ /dev/null @@ -1,166 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestClient_ScoreContent_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/v1/score/content", r.URL.Path) - assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) - - var req ScoreContentRequest - err := json.NewDecoder(r.Body).Decode(&req) - require.NoError(t, err) - assert.Contains(t, req.Text, "sample text for scoring") - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "score": 0.23, - "confidence": 0.91, - "label": "human", - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.ScoreContent(context.Background(), ScoreContentRequest{ - Text: "This is some sample text for scoring that is at least twenty characters", - }) - - require.NoError(t, err) - assert.InDelta(t, 0.23, result["score"], 0.001) - assert.Equal(t, "human", result["label"]) -} - -func TestClient_ScoreContent_Good_WithPrompt(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var req ScoreContentRequest - err := json.NewDecoder(r.Body).Decode(&req) - require.NoError(t, err) - assert.Equal(t, "Check for formality", req.Prompt) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{"score": 0.5}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.ScoreContent(context.Background(), ScoreContentRequest{ - Text: "This text should be checked for formality and style patterns", - Prompt: "Check for formality", - }) - - require.NoError(t, err) - assert.InDelta(t, 0.5, result["score"], 0.001) -} - -func TestClient_ScoreContent_Bad_EmptyText(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - result, err := client.ScoreContent(context.Background(), ScoreContentRequest{}) - - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "text is required") -} - -func TestClient_ScoreContent_Bad_ServiceDown(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) - _ = json.NewEncoder(w).Encode(map[string]any{ - "error": "scoring_unavailable", - "message": "Could not reach the scoring service.", - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.ScoreContent(context.Background(), ScoreContentRequest{ - Text: "This text needs at least twenty characters to validate", - }) - - assert.Error(t, err) - assert.Nil(t, result) -} - -func TestClient_ScoreImprint_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/v1/score/imprint", r.URL.Path) - - var req ScoreImprintRequest - err := json.NewDecoder(r.Body).Decode(&req) - require.NoError(t, err) - assert.NotEmpty(t, req.Text) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{ - "imprint": "abc123def456", - "confidence": 0.88, - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.ScoreImprint(context.Background(), ScoreImprintRequest{ - Text: "This text has a distinct linguistic imprint pattern to analyse", - }) - - require.NoError(t, err) - assert.Equal(t, "abc123def456", result["imprint"]) -} - -func TestClient_ScoreImprint_Bad_EmptyText(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - result, err := client.ScoreImprint(context.Background(), ScoreImprintRequest{}) - - assert.Error(t, err) - assert.Nil(t, result) - assert.Contains(t, err.Error(), "text is required") -} - -func TestClient_ScoreHealth_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodGet, r.Method) - assert.Equal(t, "/v1/score/health", r.URL.Path) - // Health endpoint should not require auth token - assert.Empty(t, r.Header.Get("Authorization")) - - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(ScoreHealthResponse{ - Status: "healthy", - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.ScoreHealth(context.Background()) - - require.NoError(t, err) - assert.Equal(t, "healthy", result.Status) -} - -func TestClient_ScoreHealth_Bad_Unhealthy(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadGateway) - _ = json.NewEncoder(w).Encode(ScoreHealthResponse{ - Status: "unhealthy", - UpstreamStatus: 503, - }) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - result, err := client.ScoreHealth(context.Background()) - - assert.Error(t, err) - assert.Nil(t, result) -} diff --git a/pkg/lifecycle/service.go b/pkg/lifecycle/service.go deleted file mode 100644 index 9aa8fc5..0000000 --- a/pkg/lifecycle/service.go +++ /dev/null @@ -1,142 +0,0 @@ -package lifecycle - -import ( - "context" - "os" - "os/exec" - "strings" - - "forge.lthn.ai/core/go/pkg/core" - "forge.lthn.ai/core/go-log" -) - -// Tasks for AI service - -// TaskCommit requests Claude to create a commit. -type TaskCommit struct { - Path string - Name string - CanEdit bool // allow Write/Edit tools -} - -// TaskPrompt sends a custom prompt to Claude. -type TaskPrompt struct { - Prompt string - WorkDir string - AllowedTools []string - - taskID string -} - -func (t *TaskPrompt) SetTaskID(id string) { t.taskID = id } -func (t *TaskPrompt) GetTaskID() string { return t.taskID } - -// ServiceOptions for configuring the AI service. -type ServiceOptions struct { - DefaultTools []string - AllowEdit bool // global permission for Write/Edit tools -} - -// DefaultServiceOptions returns sensible defaults. -func DefaultServiceOptions() ServiceOptions { - return ServiceOptions{ - DefaultTools: []string{"Bash", "Read", "Glob", "Grep"}, - AllowEdit: false, - } -} - -// Service provides AI/Claude operations as a Core service. -type Service struct { - *core.ServiceRuntime[ServiceOptions] -} - -// NewService creates an AI service factory. -func NewService(opts ServiceOptions) func(*core.Core) (any, error) { - return func(c *core.Core) (any, error) { - return &Service{ - ServiceRuntime: core.NewServiceRuntime(c, opts), - }, nil - } -} - -// OnStartup registers task handlers. -func (s *Service) OnStartup(ctx context.Context) error { - s.Core().RegisterTask(s.handleTask) - return nil -} - -func (s *Service) handleTask(c *core.Core, t core.Task) (any, bool, error) { - switch m := t.(type) { - case TaskCommit: - err := s.doCommit(m) - if err != nil { - log.Error("agentic: commit task failed", "err", err, "path", m.Path) - } - return nil, true, err - - case TaskPrompt: - err := s.doPrompt(m) - if err != nil { - log.Error("agentic: prompt task failed", "err", err) - } - return nil, true, err - } - return nil, false, nil -} - -func (s *Service) doCommit(task TaskCommit) error { - prompt := Prompt("commit") - - tools := []string{"Bash", "Read", "Glob", "Grep"} - if task.CanEdit { - tools = []string{"Bash", "Read", "Write", "Edit", "Glob", "Grep"} - } - - cmd := exec.CommandContext(context.Background(), "claude", "-p", prompt, "--allowedTools", strings.Join(tools, ",")) - cmd.Dir = task.Path - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Stdin = os.Stdin - - return cmd.Run() -} - -func (s *Service) doPrompt(task TaskPrompt) error { - if task.taskID != "" { - s.Core().Progress(task.taskID, 0.1, "Starting Claude...", &task) - } - - opts := s.Opts() - tools := opts.DefaultTools - if len(tools) == 0 { - tools = []string{"Bash", "Read", "Glob", "Grep"} - } - - if len(task.AllowedTools) > 0 { - tools = task.AllowedTools - } - - cmd := exec.CommandContext(context.Background(), "claude", "-p", task.Prompt, "--allowedTools", strings.Join(tools, ",")) - if task.WorkDir != "" { - cmd.Dir = task.WorkDir - } - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Stdin = os.Stdin - - if task.taskID != "" { - s.Core().Progress(task.taskID, 0.5, "Running Claude prompt...", &task) - } - - err := cmd.Run() - - if task.taskID != "" { - if err != nil { - s.Core().Progress(task.taskID, 1.0, "Failed: "+err.Error(), &task) - } else { - s.Core().Progress(task.taskID, 1.0, "Completed", &task) - } - } - - return err -} diff --git a/pkg/lifecycle/service_test.go b/pkg/lifecycle/service_test.go deleted file mode 100644 index 9f0f571..0000000 --- a/pkg/lifecycle/service_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package lifecycle - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDefaultServiceOptions_Good(t *testing.T) { - opts := DefaultServiceOptions() - - assert.Equal(t, []string{"Bash", "Read", "Glob", "Grep"}, opts.DefaultTools) - assert.False(t, opts.AllowEdit, "default should not allow edit") -} - -func TestTaskPrompt_SetGetTaskID(t *testing.T) { - tp := &TaskPrompt{ - Prompt: "test prompt", - WorkDir: "/tmp", - } - - assert.Empty(t, tp.GetTaskID(), "should start empty") - - tp.SetTaskID("task-abc-123") - assert.Equal(t, "task-abc-123", tp.GetTaskID()) - - tp.SetTaskID("task-def-456") - assert.Equal(t, "task-def-456", tp.GetTaskID(), "should allow overwriting") -} - -func TestTaskPrompt_SetGetTaskID_Empty(t *testing.T) { - tp := &TaskPrompt{} - - tp.SetTaskID("") - assert.Empty(t, tp.GetTaskID()) -} - -func TestTaskCommit_Fields(t *testing.T) { - tc := TaskCommit{ - Path: "/home/user/project", - Name: "test-commit", - CanEdit: true, - } - - assert.Equal(t, "/home/user/project", tc.Path) - assert.Equal(t, "test-commit", tc.Name) - assert.True(t, tc.CanEdit) -} - -func TestTaskCommit_DefaultCanEdit(t *testing.T) { - tc := TaskCommit{ - Path: "/tmp", - Name: "no-edit", - } - - assert.False(t, tc.CanEdit, "default should be false") -} - -func TestServiceOptions_CustomTools(t *testing.T) { - opts := ServiceOptions{ - DefaultTools: []string{"Bash", "Read", "Write", "Edit"}, - AllowEdit: true, - } - - assert.Len(t, opts.DefaultTools, 4) - assert.True(t, opts.AllowEdit) -} - -func TestTaskPrompt_AllFields(t *testing.T) { - tp := TaskPrompt{ - Prompt: "Refactor the authentication module", - WorkDir: "/home/user/project", - AllowedTools: []string{"Bash", "Read", "Edit"}, - } - - assert.Equal(t, "Refactor the authentication module", tp.Prompt) - assert.Equal(t, "/home/user/project", tp.WorkDir) - assert.Equal(t, []string{"Bash", "Read", "Edit"}, tp.AllowedTools) -} diff --git a/pkg/lifecycle/sessions.go b/pkg/lifecycle/sessions.go deleted file mode 100644 index 341f0a1..0000000 --- a/pkg/lifecycle/sessions.go +++ /dev/null @@ -1,287 +0,0 @@ -package lifecycle - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/url" - "strconv" - - "forge.lthn.ai/core/go-log" -) - -// SessionStatus represents the state of a session. -type SessionStatus string - -const ( - SessionActive SessionStatus = "active" - SessionPaused SessionStatus = "paused" - SessionCompleted SessionStatus = "completed" - SessionFailed SessionStatus = "failed" -) - -// Session represents an agent session from the PHP API. -type Session struct { - SessionID string `json:"session_id"` - AgentType string `json:"agent_type"` - Status SessionStatus `json:"status"` - PlanSlug string `json:"plan_slug,omitempty"` - Plan string `json:"plan,omitempty"` - Duration string `json:"duration,omitempty"` - StartedAt string `json:"started_at,omitempty"` - LastActiveAt string `json:"last_active_at,omitempty"` - EndedAt string `json:"ended_at,omitempty"` - ActionCount int `json:"action_count,omitempty"` - ArtifactCount int `json:"artifact_count,omitempty"` - ContextSummary map[string]any `json:"context_summary,omitempty"` - HandoffNotes string `json:"handoff_notes,omitempty"` - ContinuedFrom string `json:"continued_from,omitempty"` -} - -// StartSessionRequest is the payload for starting a new session. -type StartSessionRequest struct { - AgentType string `json:"agent_type"` - PlanSlug string `json:"plan_slug,omitempty"` - Context map[string]any `json:"context,omitempty"` -} - -// EndSessionRequest is the payload for ending a session. -type EndSessionRequest struct { - Status string `json:"status"` - Summary string `json:"summary,omitempty"` -} - -// ListSessionOptions specifies filters for listing sessions. -type ListSessionOptions struct { - Status SessionStatus `json:"status,omitempty"` - PlanSlug string `json:"plan_slug,omitempty"` - Limit int `json:"limit,omitempty"` -} - -// sessionListResponse wraps the list endpoint response. -type sessionListResponse struct { - Sessions []Session `json:"sessions"` - Total int `json:"total"` -} - -// sessionStartResponse wraps the session create endpoint response. -type sessionStartResponse struct { - SessionID string `json:"session_id"` - AgentType string `json:"agent_type"` - Plan string `json:"plan,omitempty"` - Status string `json:"status"` -} - -// sessionEndResponse wraps the session end endpoint response. -type sessionEndResponse struct { - SessionID string `json:"session_id"` - Status string `json:"status"` - Duration string `json:"duration,omitempty"` -} - -// sessionContinueResponse wraps the session continue endpoint response. -type sessionContinueResponse struct { - SessionID string `json:"session_id"` - AgentType string `json:"agent_type"` - Plan string `json:"plan,omitempty"` - Status string `json:"status"` - ContinuedFrom string `json:"continued_from,omitempty"` -} - -// ListSessions retrieves sessions matching the given options. -func (c *Client) ListSessions(ctx context.Context, opts ListSessionOptions) ([]Session, error) { - const op = "agentic.Client.ListSessions" - - params := url.Values{} - if opts.Status != "" { - params.Set("status", string(opts.Status)) - } - if opts.PlanSlug != "" { - params.Set("plan_slug", opts.PlanSlug) - } - if opts.Limit > 0 { - params.Set("limit", strconv.Itoa(opts.Limit)) - } - - endpoint := c.BaseURL + "/v1/sessions" - if len(params) > 0 { - endpoint += "?" + params.Encode() - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result sessionListResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return result.Sessions, nil -} - -// GetSession retrieves a session by ID. -func (c *Client) GetSession(ctx context.Context, sessionID string) (*Session, error) { - const op = "agentic.Client.GetSession" - - if sessionID == "" { - return nil, log.E(op, "session ID is required", nil) - } - - endpoint := fmt.Sprintf("%s/v1/sessions/%s", c.BaseURL, url.PathEscape(sessionID)) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(req) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var session Session - if err := json.NewDecoder(resp.Body).Decode(&session); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &session, nil -} - -// StartSession starts a new agent session. -func (c *Client) StartSession(ctx context.Context, req StartSessionRequest) (*sessionStartResponse, error) { - const op = "agentic.Client.StartSession" - - if req.AgentType == "" { - return nil, log.E(op, "agent_type is required", nil) - } - - data, err := json.Marshal(req) - if err != nil { - return nil, log.E(op, "failed to marshal request", err) - } - - endpoint := c.BaseURL + "/v1/sessions" - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(httpReq) - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(httpReq) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result sessionStartResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &result, nil -} - -// EndSession ends a session with a final status and optional summary. -func (c *Client) EndSession(ctx context.Context, sessionID string, status string, summary string) error { - const op = "agentic.Client.EndSession" - - if sessionID == "" { - return log.E(op, "session ID is required", nil) - } - if status == "" { - return log.E(op, "status is required", nil) - } - - payload := EndSessionRequest{Status: status, Summary: summary} - data, err := json.Marshal(payload) - if err != nil { - return log.E(op, "failed to marshal request", err) - } - - endpoint := fmt.Sprintf("%s/v1/sessions/%s/end", c.BaseURL, url.PathEscape(sessionID)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return log.E(op, "failed to create request", err) - } - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - return c.checkResponse(resp) -} - -// ContinueSession creates a new session continuing from a previous one (multi-agent handoff). -func (c *Client) ContinueSession(ctx context.Context, previousSessionID, agentType string) (*sessionContinueResponse, error) { - const op = "agentic.Client.ContinueSession" - - if previousSessionID == "" { - return nil, log.E(op, "previous session ID is required", nil) - } - if agentType == "" { - return nil, log.E(op, "agent_type is required", nil) - } - - data, err := json.Marshal(map[string]string{"agent_type": agentType}) - if err != nil { - return nil, log.E(op, "failed to marshal request", err) - } - - endpoint := fmt.Sprintf("%s/v1/sessions/%s/continue", c.BaseURL, url.PathEscape(previousSessionID)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(data)) - if err != nil { - return nil, log.E(op, "failed to create request", err) - } - c.setHeaders(req) - req.Header.Set("Content-Type", "application/json") - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, log.E(op, "request failed", err) - } - defer func() { _ = resp.Body.Close() }() - - if err := c.checkResponse(resp); err != nil { - return nil, log.E(op, "API error", err) - } - - var result sessionContinueResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, log.E(op, "failed to decode response", err) - } - - return &result, nil -} diff --git a/pkg/lifecycle/status.go b/pkg/lifecycle/status.go deleted file mode 100644 index bd317f3..0000000 --- a/pkg/lifecycle/status.go +++ /dev/null @@ -1,137 +0,0 @@ -package lifecycle - -import ( - "cmp" - "context" - "fmt" - "slices" - "strings" - - "forge.lthn.ai/core/go-log" -) - -// StatusSummary aggregates status from the agent registry, task client, and -// allowance service for CLI display. -type StatusSummary struct { - // Agents is the list of registered agents. - Agents []AgentInfo - // PendingTasks is the count of tasks with StatusPending. - PendingTasks int - // InProgressTasks is the count of tasks with StatusInProgress. - InProgressTasks int - // AllowanceRemaining maps agent ID to remaining daily tokens. -1 means unlimited. - AllowanceRemaining map[string]int64 -} - -// GetStatus aggregates status from the registry, client, and allowance service. -// Any of registry, client, or allowanceSvc can be nil -- those sections are -// simply skipped. Returns what we can collect without failing on nil components. -func GetStatus(ctx context.Context, registry AgentRegistry, client *Client, allowanceSvc *AllowanceService) (*StatusSummary, error) { - const op = "agentic.GetStatus" - - summary := &StatusSummary{ - AllowanceRemaining: make(map[string]int64), - } - - // Collect agents from registry. - if registry != nil { - summary.Agents = registry.List() - } - - // Count tasks by status via client. - if client != nil { - pending, err := client.ListTasks(ctx, ListOptions{Status: StatusPending}) - if err != nil { - return nil, log.E(op, "failed to list pending tasks", err) - } - summary.PendingTasks = len(pending) - - inProgress, err := client.ListTasks(ctx, ListOptions{Status: StatusInProgress}) - if err != nil { - return nil, log.E(op, "failed to list in-progress tasks", err) - } - summary.InProgressTasks = len(inProgress) - } - - // Collect allowance remaining per agent. - if allowanceSvc != nil { - for _, agent := range summary.Agents { - check, err := allowanceSvc.Check(agent.ID, "") - if err != nil { - // Skip agents whose allowance cannot be resolved. - continue - } - summary.AllowanceRemaining[agent.ID] = check.RemainingTokens - } - } - - return summary, nil -} - -// FormatStatus renders the summary as a human-readable table string suitable -// for CLI output. -func FormatStatus(s *StatusSummary) string { - var b strings.Builder - - // Count agents by status. - available := 0 - busy := 0 - for _, a := range s.Agents { - switch a.Status { - case AgentAvailable: - available++ - case AgentBusy: - busy++ - } - } - - total := len(s.Agents) - statusParts := make([]string, 0, 2) - if available > 0 { - statusParts = append(statusParts, fmt.Sprintf("%d available", available)) - } - if busy > 0 { - statusParts = append(statusParts, fmt.Sprintf("%d busy", busy)) - } - offline := total - available - busy - if offline > 0 { - statusParts = append(statusParts, fmt.Sprintf("%d offline", offline)) - } - - if len(statusParts) > 0 { - fmt.Fprintf(&b, "Agents: %d (%s)\n", total, strings.Join(statusParts, ", ")) - } else { - fmt.Fprintf(&b, "Agents: %d\n", total) - } - - fmt.Fprintf(&b, "Tasks: %d pending, %d in progress\n", s.PendingTasks, s.InProgressTasks) - - if len(s.Agents) > 0 { - // Sort agents by ID for deterministic output. - agents := slices.Clone(s.Agents) - slices.SortFunc(agents, func(a, b AgentInfo) int { - return cmp.Compare(a.ID, b.ID) - }) - - fmt.Fprintf(&b, "%-16s%-12s%-8s%s\n", "Agent", "Status", "Load", "Remaining") - for _, a := range agents { - load := fmt.Sprintf("%d/%d", a.CurrentLoad, a.MaxLoad) - if a.MaxLoad == 0 { - load = fmt.Sprintf("%d/-", a.CurrentLoad) - } - - remaining := "unknown" - if tokens, ok := s.AllowanceRemaining[a.ID]; ok { - if tokens < 0 { - remaining = "unlimited" - } else { - remaining = fmt.Sprintf("%d tokens", tokens) - } - } - - fmt.Fprintf(&b, "%-16s%-12s%-8s%s\n", a.ID, string(a.Status), load, remaining) - } - } - - return b.String() -} diff --git a/pkg/lifecycle/status_test.go b/pkg/lifecycle/status_test.go deleted file mode 100644 index c16f854..0000000 --- a/pkg/lifecycle/status_test.go +++ /dev/null @@ -1,270 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- GetStatus tests --- - -func TestGetStatus_Good_AllNil(t *testing.T) { - summary, err := GetStatus(context.Background(), nil, nil, nil) - require.NoError(t, err) - assert.Empty(t, summary.Agents) - assert.Equal(t, 0, summary.PendingTasks) - assert.Equal(t, 0, summary.InProgressTasks) - assert.Empty(t, summary.AllowanceRemaining) -} - -func TestGetStatus_Good_RegistryOnly(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ - ID: "virgil", - Name: "Virgil", - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - MaxLoad: 5, - }) - _ = reg.Register(AgentInfo{ - ID: "charon", - Name: "Charon", - Status: AgentBusy, - CurrentLoad: 3, - MaxLoad: 5, - LastHeartbeat: time.Now().UTC(), - }) - - summary, err := GetStatus(context.Background(), reg, nil, nil) - require.NoError(t, err) - assert.Len(t, summary.Agents, 2) - assert.Equal(t, 0, summary.PendingTasks) - assert.Equal(t, 0, summary.InProgressTasks) -} - -func TestGetStatus_Good_FullSummary(t *testing.T) { - // Set up mock server returning task counts. - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - status := r.URL.Query().Get("status") - w.Header().Set("Content-Type", "application/json") - switch status { - case "pending": - tasks := []Task{ - {ID: "t1", Status: StatusPending}, - {ID: "t2", Status: StatusPending}, - {ID: "t3", Status: StatusPending}, - } - _ = json.NewEncoder(w).Encode(tasks) - case "in_progress": - tasks := []Task{ - {ID: "t4", Status: StatusInProgress}, - } - _ = json.NewEncoder(w).Encode(tasks) - default: - _ = json.NewEncoder(w).Encode([]Task{}) - } - })) - defer server.Close() - - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ - ID: "virgil", - Name: "Virgil", - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - MaxLoad: 5, - }) - _ = reg.Register(AgentInfo{ - ID: "charon", - Name: "Charon", - Status: AgentBusy, - CurrentLoad: 3, - MaxLoad: 5, - LastHeartbeat: time.Now().UTC(), - }) - - store := NewMemoryStore() - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "virgil", - DailyTokenLimit: 50000, - }) - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "charon", - DailyTokenLimit: 50000, - }) - // Simulate charon has used 38000 tokens. - _ = store.IncrementUsage("charon", 38000, 0) - - svc := NewAllowanceService(store) - client := NewClient(server.URL, "test-token") - - summary, err := GetStatus(context.Background(), reg, client, svc) - require.NoError(t, err) - assert.Len(t, summary.Agents, 2) - assert.Equal(t, 3, summary.PendingTasks) - assert.Equal(t, 1, summary.InProgressTasks) - assert.Equal(t, int64(50000), summary.AllowanceRemaining["virgil"]) - assert.Equal(t, int64(12000), summary.AllowanceRemaining["charon"]) -} - -func TestGetStatus_Good_UnlimitedAllowance(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ - ID: "darbs", - Name: "Darbs", - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - MaxLoad: 3, - }) - - store := NewMemoryStore() - // DailyTokenLimit 0 means unlimited. - _ = store.SetAllowance(&AgentAllowance{ - AgentID: "darbs", - DailyTokenLimit: 0, - }) - svc := NewAllowanceService(store) - - summary, err := GetStatus(context.Background(), reg, nil, svc) - require.NoError(t, err) - // Unlimited: Check returns RemainingTokens = -1. - assert.Equal(t, int64(-1), summary.AllowanceRemaining["darbs"]) -} - -func TestGetStatus_Good_AllowanceSkipsUnknownAgents(t *testing.T) { - reg := NewMemoryRegistry() - _ = reg.Register(AgentInfo{ - ID: "unknown-agent", - Name: "Unknown", - Status: AgentAvailable, - LastHeartbeat: time.Now().UTC(), - }) - - store := NewMemoryStore() - // No allowance set for "unknown-agent" -- GetAllowance will error. - svc := NewAllowanceService(store) - - summary, err := GetStatus(context.Background(), reg, nil, svc) - require.NoError(t, err) - // AllowanceRemaining should not have an entry for unknown-agent. - _, exists := summary.AllowanceRemaining["unknown-agent"] - assert.False(t, exists) -} - -func TestGetStatus_Bad_ClientError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Message: "server error"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - summary, err := GetStatus(context.Background(), nil, client, nil) - assert.Error(t, err) - assert.Nil(t, summary) - assert.Contains(t, err.Error(), "pending tasks") -} - -// --- FormatStatus tests --- - -func TestFormatStatus_Good_Empty(t *testing.T) { - s := &StatusSummary{ - AllowanceRemaining: make(map[string]int64), - } - output := FormatStatus(s) - assert.Contains(t, output, "Agents: 0") - assert.Contains(t, output, "Tasks: 0 pending, 0 in progress") - // No agent table rows when there are no agents — only the summary lines. - assert.NotContains(t, output, "Status") -} - -func TestFormatStatus_Good_FullTable(t *testing.T) { - s := &StatusSummary{ - Agents: []AgentInfo{ - {ID: "virgil", Status: AgentAvailable, CurrentLoad: 0, MaxLoad: 5}, - {ID: "charon", Status: AgentBusy, CurrentLoad: 3, MaxLoad: 5}, - {ID: "darbs", Status: AgentAvailable, CurrentLoad: 0, MaxLoad: 3}, - }, - PendingTasks: 5, - InProgressTasks: 2, - AllowanceRemaining: map[string]int64{ - "virgil": 45000, - "charon": 12000, - "darbs": -1, - }, - } - - output := FormatStatus(s) - assert.Contains(t, output, "Agents: 3 (2 available, 1 busy)") - assert.Contains(t, output, "Tasks: 5 pending, 2 in progress") - assert.Contains(t, output, "virgil") - assert.Contains(t, output, "available") - assert.Contains(t, output, "45000 tokens") - assert.Contains(t, output, "charon") - assert.Contains(t, output, "busy") - assert.Contains(t, output, "12000 tokens") - assert.Contains(t, output, "darbs") - assert.Contains(t, output, "unlimited") - - // Verify deterministic sort order (agents sorted by ID). - lines := strings.Split(output, "\n") - var agentLines []string - for _, line := range lines { - if strings.HasPrefix(line, "charon") || strings.HasPrefix(line, "darbs") || strings.HasPrefix(line, "virgil") { - agentLines = append(agentLines, line) - } - } - require.Len(t, agentLines, 3) - assert.True(t, strings.HasPrefix(agentLines[0], "charon")) - assert.True(t, strings.HasPrefix(agentLines[1], "darbs")) - assert.True(t, strings.HasPrefix(agentLines[2], "virgil")) -} - -func TestFormatStatus_Good_OfflineAgent(t *testing.T) { - s := &StatusSummary{ - Agents: []AgentInfo{ - {ID: "offline-bot", Status: AgentOffline, CurrentLoad: 0, MaxLoad: 5}, - }, - AllowanceRemaining: map[string]int64{ - "offline-bot": 30000, - }, - } - - output := FormatStatus(s) - assert.Contains(t, output, "1 offline") - assert.Contains(t, output, "offline-bot") -} - -func TestFormatStatus_Good_UnlimitedMaxLoad(t *testing.T) { - s := &StatusSummary{ - Agents: []AgentInfo{ - {ID: "unlimited", Status: AgentAvailable, CurrentLoad: 2, MaxLoad: 0}, - }, - AllowanceRemaining: map[string]int64{ - "unlimited": -1, - }, - } - - output := FormatStatus(s) - assert.Contains(t, output, "2/-") - assert.Contains(t, output, "unlimited") -} - -func TestFormatStatus_Good_UnknownAllowance(t *testing.T) { - s := &StatusSummary{ - Agents: []AgentInfo{ - {ID: "mystery", Status: AgentAvailable, MaxLoad: 5}, - }, - AllowanceRemaining: make(map[string]int64), - } - - output := FormatStatus(s) - assert.Contains(t, output, "unknown") -} diff --git a/pkg/lifecycle/submit.go b/pkg/lifecycle/submit.go deleted file mode 100644 index 09fb99c..0000000 --- a/pkg/lifecycle/submit.go +++ /dev/null @@ -1,35 +0,0 @@ -package lifecycle - -import ( - "context" - "time" - - "forge.lthn.ai/core/go-log" -) - -// SubmitTask creates a new task with the given parameters via the API client. -// It validates that title is non-empty, sets CreatedAt to the current time, -// and delegates creation to client.CreateTask. -func SubmitTask(ctx context.Context, client *Client, title, description string, labels []string, priority TaskPriority) (*Task, error) { - const op = "agentic.SubmitTask" - - if title == "" { - return nil, log.E(op, "title is required", nil) - } - - task := Task{ - Title: title, - Description: description, - Labels: labels, - Priority: priority, - Status: StatusPending, - CreatedAt: time.Now().UTC(), - } - - created, err := client.CreateTask(ctx, task) - if err != nil { - return nil, log.E(op, "failed to create task", err) - } - - return created, nil -} diff --git a/pkg/lifecycle/submit_test.go b/pkg/lifecycle/submit_test.go deleted file mode 100644 index 6b5676c..0000000 --- a/pkg/lifecycle/submit_test.go +++ /dev/null @@ -1,134 +0,0 @@ -package lifecycle - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- Client.CreateTask tests --- - -func TestClient_CreateTask_Good(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "/api/tasks", r.URL.Path) - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) - - var task Task - err := json.NewDecoder(r.Body).Decode(&task) - require.NoError(t, err) - assert.Equal(t, "New feature", task.Title) - assert.Equal(t, PriorityHigh, task.Priority) - - // Return the task with an assigned ID. - task.ID = "task-new-1" - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task := Task{ - Title: "New feature", - Description: "Build something great", - Priority: PriorityHigh, - Labels: []string{"feature"}, - Status: StatusPending, - } - - created, err := client.CreateTask(context.Background(), task) - require.NoError(t, err) - assert.Equal(t, "task-new-1", created.ID) - assert.Equal(t, "New feature", created.Title) - assert.Equal(t, PriorityHigh, created.Priority) -} - -func TestClient_CreateTask_Bad_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) - _ = json.NewEncoder(w).Encode(APIError{Message: "validation failed"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - task := Task{Title: "Bad task"} - - created, err := client.CreateTask(context.Background(), task) - assert.Error(t, err) - assert.Nil(t, created) - assert.Contains(t, err.Error(), "validation failed") -} - -// --- SubmitTask tests --- - -func TestSubmitTask_Good_AllFields(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var task Task - err := json.NewDecoder(r.Body).Decode(&task) - require.NoError(t, err) - assert.Equal(t, "Implement login", task.Title) - assert.Equal(t, "OAuth2 login flow", task.Description) - assert.Equal(t, []string{"auth", "frontend"}, task.Labels) - assert.Equal(t, PriorityHigh, task.Priority) - assert.Equal(t, StatusPending, task.Status) - assert.False(t, task.CreatedAt.IsZero()) - - task.ID = "task-submit-1" - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - created, err := SubmitTask(context.Background(), client, "Implement login", "OAuth2 login flow", []string{"auth", "frontend"}, PriorityHigh) - require.NoError(t, err) - assert.Equal(t, "task-submit-1", created.ID) - assert.Equal(t, "Implement login", created.Title) -} - -func TestSubmitTask_Good_MinimalFields(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - var task Task - _ = json.NewDecoder(r.Body).Decode(&task) - task.ID = "task-minimal" - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(task) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - created, err := SubmitTask(context.Background(), client, "Simple task", "", nil, PriorityLow) - require.NoError(t, err) - assert.Equal(t, "task-minimal", created.ID) -} - -func TestSubmitTask_Bad_EmptyTitle(t *testing.T) { - client := NewClient("https://api.example.com", "test-token") - created, err := SubmitTask(context.Background(), client, "", "description", nil, PriorityMedium) - assert.Error(t, err) - assert.Nil(t, created) - assert.Contains(t, err.Error(), "title is required") -} - -func TestSubmitTask_Bad_ClientError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(APIError{Message: "internal error"}) - })) - defer server.Close() - - client := NewClient(server.URL, "test-token") - created, err := SubmitTask(context.Background(), client, "Good title", "", nil, PriorityMedium) - assert.Error(t, err) - assert.Nil(t, created) - assert.Contains(t, err.Error(), "create task") -} diff --git a/pkg/lifecycle/types.go b/pkg/lifecycle/types.go deleted file mode 100644 index bb2e7bd..0000000 --- a/pkg/lifecycle/types.go +++ /dev/null @@ -1,150 +0,0 @@ -// Package agentic provides an API client for core-agentic, an AI-assisted task -// management service. It enables developers and AI agents to discover, claim, -// and complete development tasks. -package lifecycle - -import ( - "time" -) - -// TaskStatus represents the state of a task in the system. -type TaskStatus string - -const ( - // StatusPending indicates the task is available to be claimed. - StatusPending TaskStatus = "pending" - // StatusInProgress indicates the task has been claimed and is being worked on. - StatusInProgress TaskStatus = "in_progress" - // StatusCompleted indicates the task has been successfully completed. - StatusCompleted TaskStatus = "completed" - // StatusBlocked indicates the task cannot proceed due to dependencies. - StatusBlocked TaskStatus = "blocked" - // StatusFailed indicates the task has exceeded its retry limit and been dead-lettered. - StatusFailed TaskStatus = "failed" -) - -// TaskPriority represents the urgency level of a task. -type TaskPriority string - -const ( - // PriorityCritical indicates the task requires immediate attention. - PriorityCritical TaskPriority = "critical" - // PriorityHigh indicates the task is important and should be addressed soon. - PriorityHigh TaskPriority = "high" - // PriorityMedium indicates the task has normal priority. - PriorityMedium TaskPriority = "medium" - // PriorityLow indicates the task can be addressed when time permits. - PriorityLow TaskPriority = "low" -) - -// Task represents a development task in the core-agentic system. -type Task struct { - // ID is the unique identifier for the task. - ID string `json:"id"` - // Title is the short description of the task. - Title string `json:"title"` - // Description provides detailed information about what needs to be done. - Description string `json:"description"` - // Priority indicates the urgency of the task. - Priority TaskPriority `json:"priority"` - // Status indicates the current state of the task. - Status TaskStatus `json:"status"` - // Labels are tags used to categorize the task. - Labels []string `json:"labels,omitempty"` - // Files lists the files that are relevant to this task. - Files []string `json:"files,omitempty"` - // CreatedAt is when the task was created. - CreatedAt time.Time `json:"created_at"` - // UpdatedAt is when the task was last modified. - UpdatedAt time.Time `json:"updated_at"` - // ClaimedBy is the identifier of the agent or developer who claimed the task. - ClaimedBy string `json:"claimed_by,omitempty"` - // ClaimedAt is when the task was claimed. - ClaimedAt *time.Time `json:"claimed_at,omitempty"` - // Project is the project this task belongs to. - Project string `json:"project,omitempty"` - // Dependencies lists task IDs that must be completed before this task. - Dependencies []string `json:"dependencies,omitempty"` - // Blockers lists task IDs that this task is blocking. - Blockers []string `json:"blockers,omitempty"` - // MaxRetries is the maximum dispatch attempts before dead-lettering. 0 uses DefaultMaxRetries. - MaxRetries int `json:"max_retries,omitempty"` - // RetryCount is the number of failed dispatch attempts so far. - RetryCount int `json:"retry_count,omitempty"` - // LastAttempt is when the last dispatch attempt occurred. - LastAttempt *time.Time `json:"last_attempt,omitempty"` - // FailReason explains why the task was moved to failed status. - FailReason string `json:"fail_reason,omitempty"` -} - -// TaskUpdate contains fields that can be updated on a task. -type TaskUpdate struct { - // Status is the new status for the task. - Status TaskStatus `json:"status,omitempty"` - // Progress is a percentage (0-100) indicating completion. - Progress int `json:"progress,omitempty"` - // Notes are additional comments about the update. - Notes string `json:"notes,omitempty"` -} - -// TaskResult contains the outcome of a completed task. -type TaskResult struct { - // Success indicates whether the task was completed successfully. - Success bool `json:"success"` - // Output is the result or summary of the completed work. - Output string `json:"output,omitempty"` - // Artifacts are files or resources produced by the task. - Artifacts []string `json:"artifacts,omitempty"` - // ErrorMessage contains details if the task failed. - ErrorMessage string `json:"error_message,omitempty"` -} - -// ListOptions specifies filters for listing tasks. -type ListOptions struct { - // Status filters tasks by their current status. - Status TaskStatus `json:"status,omitempty"` - // Labels filters tasks that have all specified labels. - Labels []string `json:"labels,omitempty"` - // Priority filters tasks by priority level. - Priority TaskPriority `json:"priority,omitempty"` - // Limit is the maximum number of tasks to return. - Limit int `json:"limit,omitempty"` - // Project filters tasks by project. - Project string `json:"project,omitempty"` - // ClaimedBy filters tasks claimed by a specific agent. - ClaimedBy string `json:"claimed_by,omitempty"` -} - -// APIError represents an error response from the API. -type APIError struct { - // Code is the HTTP status code. - Code int `json:"code"` - // Message is the error description. - Message string `json:"message"` - // Details provides additional context about the error. - Details string `json:"details,omitempty"` -} - -// Error implements the error interface for APIError. -func (e *APIError) Error() string { - if e.Details != "" { - return e.Message + ": " + e.Details - } - return e.Message -} - -// ClaimResponse is returned when a task is successfully claimed. -type ClaimResponse struct { - // Task is the claimed task with updated fields. - Task *Task `json:"task"` - // Message provides additional context about the claim. - Message string `json:"message,omitempty"` -} - -// CompleteResponse is returned when a task is completed. -type CompleteResponse struct { - // Task is the completed task with final status. - Task *Task `json:"task"` - // Message provides additional context about the completion. - Message string `json:"message,omitempty"` -} diff --git a/pkg/loop/engine.go b/pkg/loop/engine.go deleted file mode 100644 index 563219b..0000000 --- a/pkg/loop/engine.go +++ /dev/null @@ -1,132 +0,0 @@ -package loop - -import ( - "context" - "fmt" - "strings" - - "forge.lthn.ai/core/go-inference" - coreerr "forge.lthn.ai/core/go-log" -) - -// Engine drives the agent loop: prompt the model, parse tool calls, execute -// tools, feed results back, and repeat until the model responds without tool -// blocks or the turn limit is reached. -type Engine struct { - model inference.TextModel - tools []Tool - system string - maxTurns int -} - -// Option configures an Engine. -type Option func(*Engine) - -// WithModel sets the inference backend for the engine. -func WithModel(m inference.TextModel) Option { - return func(e *Engine) { e.model = m } -} - -// WithTools registers tools that the model may invoke. -func WithTools(tools ...Tool) Option { - return func(e *Engine) { e.tools = append(e.tools, tools...) } -} - -// WithSystem overrides the default system prompt. When empty, BuildSystemPrompt -// generates one from the registered tools. -func WithSystem(prompt string) Option { - return func(e *Engine) { e.system = prompt } -} - -// WithMaxTurns caps the number of LLM calls before the loop errors out. -func WithMaxTurns(n int) Option { - return func(e *Engine) { e.maxTurns = n } -} - -// New creates an Engine with the given options. The default turn limit is 10. -func New(opts ...Option) *Engine { - e := &Engine{maxTurns: 10} - for _, o := range opts { - o(e) - } - return e -} - -// Run executes the agent loop. It sends userMessage to the model, parses any -// tool calls from the response, executes them, appends the results, and loops -// until the model produces a response with no tool blocks or maxTurns is hit. -func (e *Engine) Run(ctx context.Context, userMessage string) (*Result, error) { - if e.model == nil { - return nil, coreerr.E("loop.Run", "no model configured", nil) - } - - system := e.system - if system == "" { - system = BuildSystemPrompt(e.tools) - } - - handlers := make(map[string]func(context.Context, map[string]any) (string, error), len(e.tools)) - for _, tool := range e.tools { - handlers[tool.Name] = tool.Handler - } - - var history []Message - history = append(history, Message{Role: RoleUser, Content: userMessage}) - - for turn := 0; turn < e.maxTurns; turn++ { - if err := ctx.Err(); err != nil { - return nil, coreerr.E("loop.Run", "context cancelled", err) - } - - prompt := BuildFullPrompt(system, history, "") - var response strings.Builder - for tok := range e.model.Generate(ctx, prompt) { - response.WriteString(tok.Text) - } - if err := e.model.Err(); err != nil { - return nil, coreerr.E("loop.Run", "inference error", err) - } - - fullResponse := response.String() - calls, cleanText := ParseToolCalls(fullResponse) - - history = append(history, Message{ - Role: RoleAssistant, - Content: fullResponse, - ToolUses: calls, - }) - - // No tool calls means the model has produced a final answer. - if len(calls) == 0 { - return &Result{ - Response: cleanText, - Messages: history, - Turns: turn + 1, - }, nil - } - - // Execute each tool call and append results to the history. - for _, call := range calls { - handler, ok := handlers[call.Name] - var resultText string - if !ok { - resultText = fmt.Sprintf("error: unknown tool %q", call.Name) - } else { - out, err := handler(ctx, call.Args) - if err != nil { - resultText = fmt.Sprintf("error: %v", err) - } else { - resultText = out - } - } - - history = append(history, Message{ - Role: RoleToolResult, - Content: resultText, - ToolUses: []ToolUse{{Name: call.Name}}, - }) - } - } - - return nil, coreerr.E("loop.Run", fmt.Sprintf("max turns (%d) exceeded", e.maxTurns), nil) -} diff --git a/pkg/loop/engine_test.go b/pkg/loop/engine_test.go deleted file mode 100644 index a0520c6..0000000 --- a/pkg/loop/engine_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package loop - -import ( - "context" - "iter" - "testing" - "time" - - "forge.lthn.ai/core/go-inference" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// mockModel returns canned responses. Each call to Generate pops the next response. -type mockModel struct { - responses []string - callCount int - lastErr error -} - -func (m *mockModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] { - return func(yield func(inference.Token) bool) { - if m.callCount >= len(m.responses) { - return - } - resp := m.responses[m.callCount] - m.callCount++ - for i, ch := range resp { - if !yield(inference.Token{ID: int32(i), Text: string(ch)}) { - return - } - } - } -} - -func (m *mockModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] { - return m.Generate(ctx, "", opts...) -} - -func (m *mockModel) Err() error { return m.lastErr } -func (m *mockModel) Close() error { return nil } -func (m *mockModel) ModelType() string { return "mock" } -func (m *mockModel) Info() inference.ModelInfo { return inference.ModelInfo{} } -func (m *mockModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} } -func (m *mockModel) Classify(ctx context.Context, p []string, o ...inference.GenerateOption) ([]inference.ClassifyResult, error) { - return nil, nil -} -func (m *mockModel) BatchGenerate(ctx context.Context, p []string, o ...inference.GenerateOption) ([]inference.BatchResult, error) { - return nil, nil -} - -func TestEngine_Good_SimpleResponse(t *testing.T) { - model := &mockModel{responses: []string{"Hello, I can help you."}} - engine := New(WithModel(model), WithMaxTurns(5)) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - result, err := engine.Run(ctx, "hi") - require.NoError(t, err) - assert.Equal(t, "Hello, I can help you.", result.Response) - assert.Equal(t, 1, result.Turns) - assert.Len(t, result.Messages, 2) // user + assistant -} - -func TestEngine_Good_ToolCallAndResponse(t *testing.T) { - model := &mockModel{responses: []string{ - "Let me check.\n```tool\n{\"name\": \"test_tool\", \"args\": {\"key\": \"val\"}}\n```\n", - "The result was: tool output.", - }} - - toolCalled := false - tools := []Tool{{ - Name: "test_tool", - Description: "A test tool", - Parameters: map[string]any{"type": "object"}, - Handler: func(ctx context.Context, args map[string]any) (string, error) { - toolCalled = true - assert.Equal(t, "val", args["key"]) - return "tool output", nil - }, - }} - - engine := New(WithModel(model), WithTools(tools...), WithMaxTurns(5)) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - result, err := engine.Run(ctx, "do something") - require.NoError(t, err) - assert.True(t, toolCalled) - assert.Equal(t, 2, result.Turns) - assert.Contains(t, result.Response, "tool output") -} - -func TestEngine_Bad_MaxTurnsExceeded(t *testing.T) { - model := &mockModel{responses: []string{ - "```tool\n{\"name\": \"t\", \"args\": {}}\n```\n", - "```tool\n{\"name\": \"t\", \"args\": {}}\n```\n", - "```tool\n{\"name\": \"t\", \"args\": {}}\n```\n", - }} - - tools := []Tool{{ - Name: "t", Description: "loop forever", - Handler: func(ctx context.Context, args map[string]any) (string, error) { - return "ok", nil - }, - }} - - engine := New(WithModel(model), WithTools(tools...), WithMaxTurns(2)) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - _, err := engine.Run(ctx, "go") - require.Error(t, err) - assert.Contains(t, err.Error(), "max turns") -} - -func TestEngine_Bad_ContextCancelled(t *testing.T) { - model := &mockModel{responses: []string{"thinking..."}} - engine := New(WithModel(model), WithMaxTurns(5)) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() // cancel immediately - - _, err := engine.Run(ctx, "hi") - require.Error(t, err) -} diff --git a/pkg/loop/parse.go b/pkg/loop/parse.go deleted file mode 100644 index 28af2a4..0000000 --- a/pkg/loop/parse.go +++ /dev/null @@ -1,49 +0,0 @@ -package loop - -import ( - "encoding/json" - "regexp" - "strings" -) - -var toolBlockRe = regexp.MustCompile("(?s)```tool\\s*\n(.*?)\\s*```") - -// ParseToolCalls extracts tool invocations from fenced ```tool blocks in -// model output. Only blocks tagged "tool" are matched; other fenced blocks -// (```go, ```json, etc.) pass through untouched. Malformed JSON is silently -// skipped. Returns the parsed calls and the cleaned text with tool blocks -// removed. -func ParseToolCalls(output string) ([]ToolUse, string) { - matches := toolBlockRe.FindAllStringSubmatchIndex(output, -1) - if len(matches) == 0 { - return nil, output - } - - var calls []ToolUse - cleaned := output - - // Walk matches in reverse so index arithmetic stays valid after each splice. - for i := len(matches) - 1; i >= 0; i-- { - m := matches[i] - fullStart, fullEnd := m[0], m[1] - bodyStart, bodyEnd := m[2], m[3] - - body := strings.TrimSpace(output[bodyStart:bodyEnd]) - if body == "" { - cleaned = cleaned[:fullStart] + cleaned[fullEnd:] - continue - } - - var call ToolUse - if err := json.Unmarshal([]byte(body), &call); err != nil { - cleaned = cleaned[:fullStart] + cleaned[fullEnd:] - continue - } - - calls = append([]ToolUse{call}, calls...) - cleaned = cleaned[:fullStart] + cleaned[fullEnd:] - } - - cleaned = strings.TrimSpace(cleaned) - return calls, cleaned -} diff --git a/pkg/loop/parse_test.go b/pkg/loop/parse_test.go deleted file mode 100644 index d8444e7..0000000 --- a/pkg/loop/parse_test.go +++ /dev/null @@ -1,70 +0,0 @@ -package loop - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseTool_Good_SingleCall(t *testing.T) { - input := "Let me read that file.\n```tool\n{\"name\": \"file_read\", \"args\": {\"path\": \"/tmp/test.txt\"}}\n```\n" - calls, text := ParseToolCalls(input) - require.Len(t, calls, 1) - assert.Equal(t, "file_read", calls[0].Name) - assert.Equal(t, "/tmp/test.txt", calls[0].Args["path"]) - assert.Contains(t, text, "Let me read that file.") - assert.NotContains(t, text, "```tool") -} - -func TestParseTool_Good_MultipleCalls(t *testing.T) { - input := "I'll check both.\n```tool\n{\"name\": \"file_read\", \"args\": {\"path\": \"a.txt\"}}\n```\nAnd also:\n```tool\n{\"name\": \"file_read\", \"args\": {\"path\": \"b.txt\"}}\n```\n" - calls, _ := ParseToolCalls(input) - require.Len(t, calls, 2) - assert.Equal(t, "a.txt", calls[0].Args["path"]) - assert.Equal(t, "b.txt", calls[1].Args["path"]) -} - -func TestParseTool_Good_NoToolCalls(t *testing.T) { - input := "Here is a normal response with no tool calls." - calls, text := ParseToolCalls(input) - assert.Empty(t, calls) - assert.Equal(t, input, text) -} - -func TestParseTool_Bad_MalformedJSON(t *testing.T) { - input := "```tool\n{not valid json}\n```\n" - calls, _ := ParseToolCalls(input) - assert.Empty(t, calls) -} - -func TestParseTool_Good_WithSurroundingText(t *testing.T) { - input := "Before text.\n```tool\n{\"name\": \"test\", \"args\": {}}\n```\nAfter text." - calls, text := ParseToolCalls(input) - require.Len(t, calls, 1) - assert.Contains(t, text, "Before text.") - assert.Contains(t, text, "After text.") -} - -func TestParseTool_Ugly_NestedBackticks(t *testing.T) { - input := "```go\nfmt.Println(\"hello\")\n```\n```tool\n{\"name\": \"test\", \"args\": {}}\n```\n" - calls, text := ParseToolCalls(input) - require.Len(t, calls, 1) - assert.Equal(t, "test", calls[0].Name) - assert.Contains(t, text, "```go") -} - -func TestParseTool_Bad_EmptyToolBlock(t *testing.T) { - input := "```tool\n\n```\n" - calls, _ := ParseToolCalls(input) - assert.Empty(t, calls) -} - -func TestParseTool_Good_ArgsWithNestedObject(t *testing.T) { - input := "```tool\n{\"name\": \"complex\", \"args\": {\"config\": {\"key\": \"value\", \"num\": 42}}}\n```\n" - calls, _ := ParseToolCalls(input) - require.Len(t, calls, 1) - config, ok := calls[0].Args["config"].(map[string]any) - require.True(t, ok) - assert.Equal(t, "value", config["key"]) -} diff --git a/pkg/loop/prompt.go b/pkg/loop/prompt.go deleted file mode 100644 index d83163e..0000000 --- a/pkg/loop/prompt.go +++ /dev/null @@ -1,71 +0,0 @@ -package loop - -import ( - "encoding/json" - "fmt" - "strings" -) - -// BuildSystemPrompt constructs the system prompt that instructs the model how -// to use the available tools. When no tools are registered it returns a plain -// assistant preamble without tool-calling instructions. -func BuildSystemPrompt(tools []Tool) string { - if len(tools) == 0 { - return "You are a helpful assistant." - } - - var b strings.Builder - b.WriteString("You are a helpful assistant with access to the following tools:\n\n") - - for _, tool := range tools { - b.WriteString(fmt.Sprintf("### %s\n", tool.Name)) - b.WriteString(fmt.Sprintf("%s\n", tool.Description)) - if tool.Parameters != nil { - schema, _ := json.MarshalIndent(tool.Parameters, "", " ") - b.WriteString(fmt.Sprintf("Parameters: %s\n", schema)) - } - b.WriteString("\n") - } - - b.WriteString("To use a tool, output a fenced block:\n") - b.WriteString("```tool\n") - b.WriteString("{\"name\": \"tool_name\", \"args\": {\"key\": \"value\"}}\n") - b.WriteString("```\n\n") - b.WriteString("You may call multiple tools in one response. After tool results are provided, continue reasoning. When you have a final answer, respond normally without tool blocks.\n") - - return b.String() -} - -// BuildFullPrompt assembles the complete prompt string from the system prompt, -// conversation history, and current user message. Each message is tagged with -// its role so the model can distinguish turns. Tool results are annotated with -// the tool name for traceability. -func BuildFullPrompt(system string, history []Message, userMessage string) string { - var b strings.Builder - - if system != "" { - b.WriteString(system) - b.WriteString("\n\n") - } - - for _, msg := range history { - switch msg.Role { - case RoleUser: - b.WriteString(fmt.Sprintf("[user]\n%s\n\n", msg.Content)) - case RoleAssistant: - b.WriteString(fmt.Sprintf("[assistant]\n%s\n\n", msg.Content)) - case RoleToolResult: - toolName := "unknown" - if len(msg.ToolUses) > 0 { - toolName = msg.ToolUses[0].Name - } - b.WriteString(fmt.Sprintf("[tool_result: %s]\n%s\n\n", toolName, msg.Content)) - } - } - - if userMessage != "" { - b.WriteString(fmt.Sprintf("[user]\n%s\n\n", userMessage)) - } - - return b.String() -} diff --git a/pkg/loop/prompt_test.go b/pkg/loop/prompt_test.go deleted file mode 100644 index d14b320..0000000 --- a/pkg/loop/prompt_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package loop - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestBuildSystemPrompt_Good_WithTools(t *testing.T) { - tools := []Tool{ - { - Name: "file_read", - Description: "Read a file", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{"type": "string"}, - }, - "required": []any{"path"}, - }, - }, - { - Name: "eaas_score", - Description: "Score text for AI content", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "text": map[string]any{"type": "string"}, - }, - }, - }, - } - - prompt := BuildSystemPrompt(tools) - assert.Contains(t, prompt, "file_read") - assert.Contains(t, prompt, "Read a file") - assert.Contains(t, prompt, "eaas_score") - assert.Contains(t, prompt, "```tool") -} - -func TestBuildSystemPrompt_Good_NoTools(t *testing.T) { - prompt := BuildSystemPrompt(nil) - assert.NotEmpty(t, prompt) - assert.NotContains(t, prompt, "```tool") -} - -func TestBuildFullPrompt_Good(t *testing.T) { - history := []Message{ - {Role: RoleUser, Content: "hello"}, - {Role: RoleAssistant, Content: "hi there"}, - } - prompt := BuildFullPrompt("system prompt", history, "what next?") - assert.Contains(t, prompt, "system prompt") - assert.Contains(t, prompt, "hello") - assert.Contains(t, prompt, "hi there") - assert.Contains(t, prompt, "what next?") -} - -func TestBuildFullPrompt_Good_IncludesToolResults(t *testing.T) { - history := []Message{ - {Role: RoleUser, Content: "read test.txt"}, - {Role: RoleAssistant, Content: "I'll read it.", ToolUses: []ToolUse{{Name: "file_read", Args: map[string]any{"path": "test.txt"}}}}, - {Role: RoleToolResult, Content: "file contents here", ToolUses: []ToolUse{{Name: "file_read"}}}, - } - prompt := BuildFullPrompt("", history, "") - assert.Contains(t, prompt, "[tool_result: file_read]") - assert.Contains(t, prompt, "file contents here") -} diff --git a/pkg/loop/tools_eaas.go b/pkg/loop/tools_eaas.go deleted file mode 100644 index b5042d1..0000000 --- a/pkg/loop/tools_eaas.go +++ /dev/null @@ -1,90 +0,0 @@ -package loop - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "time" - - coreerr "forge.lthn.ai/core/go-log" -) - -var eaasClient = &http.Client{Timeout: 30 * time.Second} - -// EaaSTools returns the three EaaS tool wrappers: score, imprint, and atlas similar. -func EaaSTools(baseURL string) []Tool { - return []Tool{ - { - Name: "eaas_score", - Description: "Score text for AI-generated content, sycophancy, and compliance markers. Returns verdict, LEK score, heuristic breakdown, and detected flags.", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "text": map[string]any{"type": "string", "description": "Text to analyse"}, - }, - "required": []any{"text"}, - }, - Handler: eaasPostHandler(baseURL, "/v1/score/content"), - }, - { - Name: "eaas_imprint", - Description: "Analyse the linguistic imprint of text. Returns stylistic fingerprint metrics.", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "text": map[string]any{"type": "string", "description": "Text to analyse"}, - }, - "required": []any{"text"}, - }, - Handler: eaasPostHandler(baseURL, "/v1/score/imprint"), - }, - { - Name: "eaas_similar", - Description: "Find similar previously scored content via atlas vector search.", - Parameters: map[string]any{ - "type": "object", - "properties": map[string]any{ - "id": map[string]any{"type": "string", "description": "Scoring ID to search from"}, - "limit": map[string]any{"type": "integer", "description": "Max results (default 5)"}, - }, - "required": []any{"id"}, - }, - Handler: eaasPostHandler(baseURL, "/v1/atlas/similar"), - }, - } -} - -func eaasPostHandler(baseURL, path string) func(context.Context, map[string]any) (string, error) { - return func(ctx context.Context, args map[string]any) (string, error) { - body, err := json.Marshal(args) - if err != nil { - return "", coreerr.E("eaas.handler", "marshal args", err) - } - - req, err := http.NewRequestWithContext(ctx, "POST", baseURL+path, bytes.NewReader(body)) - if err != nil { - return "", coreerr.E("eaas.handler", "create request", err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := eaasClient.Do(req) - if err != nil { - return "", coreerr.E("eaas.handler", "eaas request", err) - } - defer resp.Body.Close() - - result, err := io.ReadAll(resp.Body) - if err != nil { - return "", coreerr.E("eaas.handler", "read response", err) - } - - if resp.StatusCode != http.StatusOK { - return "", coreerr.E("eaas.handler", fmt.Sprintf("eaas returned %d: %s", resp.StatusCode, string(result)), nil) - } - - return string(result), nil - } -} diff --git a/pkg/loop/tools_eaas_test.go b/pkg/loop/tools_eaas_test.go deleted file mode 100644 index fba6d53..0000000 --- a/pkg/loop/tools_eaas_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package loop - -import ( - "context" - "encoding/json" - "net/http" - "net/http/httptest" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestEaaSTools_Good_ReturnsThreeTools(t *testing.T) { - tools := EaaSTools("http://localhost:8009") - assert.Len(t, tools, 3) - - names := make([]string, len(tools)) - for i, tool := range tools { - names[i] = tool.Name - } - assert.Contains(t, names, "eaas_score") - assert.Contains(t, names, "eaas_imprint") - assert.Contains(t, names, "eaas_similar") -} - -func TestEaaSScore_Good_CallsAPI(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/v1/score/content", r.URL.Path) - assert.Equal(t, "POST", r.Method) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]any{ - "verdict": "likely_human", - "lek": 85.5, - }) - })) - defer server.Close() - - tools := EaaSTools(server.URL) - var scoreTool Tool - for _, tool := range tools { - if tool.Name == "eaas_score" { - scoreTool = tool - break - } - } - - result, err := scoreTool.Handler(context.Background(), map[string]any{"text": "Hello world"}) - require.NoError(t, err) - assert.Contains(t, result, "likely_human") -} diff --git a/pkg/loop/tools_mcp.go b/pkg/loop/tools_mcp.go deleted file mode 100644 index 58e8f7e..0000000 --- a/pkg/loop/tools_mcp.go +++ /dev/null @@ -1,47 +0,0 @@ -package loop - -import ( - "context" - "encoding/json" - - coreerr "forge.lthn.ai/core/go-log" - aimcp "forge.lthn.ai/core/mcp/pkg/mcp" -) - -// LoadMCPTools converts all tools from a go-ai MCP Service into loop.Tool values. -func LoadMCPTools(svc *aimcp.Service) []Tool { - var tools []Tool - for _, record := range svc.Tools() { - tools = append(tools, Tool{ - Name: record.Name, - Description: record.Description, - Parameters: record.InputSchema, - Handler: WrapRESTHandler(RESTHandlerFunc(record.RESTHandler)), - }) - } - return tools -} - -// RESTHandlerFunc matches go-ai's mcp.RESTHandler signature. -type RESTHandlerFunc func(ctx context.Context, body []byte) (any, error) - -// WrapRESTHandler converts a go-ai RESTHandler into a loop.Tool handler. -func WrapRESTHandler(handler RESTHandlerFunc) func(context.Context, map[string]any) (string, error) { - return func(ctx context.Context, args map[string]any) (string, error) { - body, err := json.Marshal(args) - if err != nil { - return "", coreerr.E("mcp.handler", "marshal args", err) - } - - result, err := handler(ctx, body) - if err != nil { - return "", err - } - - out, err := json.Marshal(result) - if err != nil { - return "", coreerr.E("mcp.handler", "marshal result", err) - } - return string(out), nil - } -} diff --git a/pkg/loop/tools_mcp_test.go b/pkg/loop/tools_mcp_test.go deleted file mode 100644 index f9eb401..0000000 --- a/pkg/loop/tools_mcp_test.go +++ /dev/null @@ -1,51 +0,0 @@ -package loop - -import ( - "context" - "encoding/json" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestLoadMCPTools_Good_ConvertsRecords(t *testing.T) { - handler := func(ctx context.Context, args map[string]any) (string, error) { - return "result", nil - } - - tool := Tool{ - Name: "file_read", - Description: "Read a file", - Parameters: map[string]any{"type": "object"}, - Handler: handler, - } - - assert.Equal(t, "file_read", tool.Name) - result, err := tool.Handler(context.Background(), map[string]any{"path": "/tmp/test"}) - require.NoError(t, err) - assert.Equal(t, "result", result) -} - -func TestWrapRESTHandler_Good(t *testing.T) { - restHandler := func(ctx context.Context, body []byte) (any, error) { - var input map[string]any - json.Unmarshal(body, &input) - return map[string]string{"content": "hello from " + input["path"].(string)}, nil - } - - wrapped := WrapRESTHandler(restHandler) - result, err := wrapped(context.Background(), map[string]any{"path": "/tmp/test"}) - require.NoError(t, err) - assert.Contains(t, result, "hello from /tmp/test") -} - -func TestWrapRESTHandler_Bad_HandlerError(t *testing.T) { - restHandler := func(ctx context.Context, body []byte) (any, error) { - return nil, assert.AnError - } - - wrapped := WrapRESTHandler(restHandler) - _, err := wrapped(context.Background(), map[string]any{}) - require.Error(t, err) -} diff --git a/pkg/loop/types.go b/pkg/loop/types.go deleted file mode 100644 index 95b5a90..0000000 --- a/pkg/loop/types.go +++ /dev/null @@ -1,38 +0,0 @@ -package loop - -import "context" - -const ( - RoleUser = "user" - RoleAssistant = "assistant" - RoleToolResult = "tool_result" - RoleSystem = "system" -) - -// Message represents one turn in the conversation. -type Message struct { - Role string `json:"role"` - Content string `json:"content"` - ToolUses []ToolUse `json:"tool_uses,omitempty"` -} - -// ToolUse represents a parsed tool invocation from model output. -type ToolUse struct { - Name string `json:"name"` - Args map[string]any `json:"args"` -} - -// Tool describes an available tool the model can invoke. -type Tool struct { - Name string - Description string - Parameters map[string]any - Handler func(ctx context.Context, args map[string]any) (string, error) -} - -// Result is the final output after the loop completes. -type Result struct { - Response string // final text from the model (tool blocks stripped) - Messages []Message // full conversation history - Turns int // number of LLM calls made -} diff --git a/pkg/loop/types_test.go b/pkg/loop/types_test.go deleted file mode 100644 index 46ef905..0000000 --- a/pkg/loop/types_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package loop - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestMessage_Good_UserMessage(t *testing.T) { - m := Message{Role: RoleUser, Content: "hello"} - assert.Equal(t, RoleUser, m.Role) - assert.Equal(t, "hello", m.Content) - assert.Nil(t, m.ToolUses) -} - -func TestMessage_Good_AssistantWithTools(t *testing.T) { - m := Message{ - Role: RoleAssistant, - Content: "I'll read that file.", - ToolUses: []ToolUse{ - {Name: "file_read", Args: map[string]any{"path": "/tmp/test.txt"}}, - }, - } - assert.Len(t, m.ToolUses, 1) - assert.Equal(t, "file_read", m.ToolUses[0].Name) -} - -func TestTool_Good_HasHandler(t *testing.T) { - tool := Tool{ - Name: "test_tool", - Description: "A test tool", - Parameters: map[string]any{"type": "object"}, - } - assert.Equal(t, "test_tool", tool.Name) - assert.NotEmpty(t, tool.Description) -} - -func TestResult_Good_Fields(t *testing.T) { - r := Result{ - Response: "done", - Turns: 3, - } - assert.Equal(t, "done", r.Response) - assert.Equal(t, 3, r.Turns) -} diff --git a/pkg/orchestrator/clotho.go b/pkg/orchestrator/clotho.go deleted file mode 100644 index eddc4d0..0000000 --- a/pkg/orchestrator/clotho.go +++ /dev/null @@ -1,99 +0,0 @@ -package orchestrator - -import ( - "context" - "iter" - "strings" - - "forge.lthn.ai/core/agent/pkg/jobrunner" -) - -// RunMode determines the execution strategy for a dispatched task. -type RunMode string - -const ( - ModeStandard RunMode = "standard" - ModeDual RunMode = "dual" // The Clotho Protocol — dual-run verification -) - -// Spinner is the Clotho orchestrator that determines the fate of each task. -type Spinner struct { - Config ClothoConfig - Agents map[string]AgentConfig -} - -// NewSpinner creates a new Clotho orchestrator. -func NewSpinner(cfg ClothoConfig, agents map[string]AgentConfig) *Spinner { - return &Spinner{ - Config: cfg, - Agents: agents, - } -} - -// DeterminePlan decides if a signal requires dual-run verification based on -// the global strategy, agent configuration, and repository criticality. -func (s *Spinner) DeterminePlan(signal *jobrunner.PipelineSignal, agentName string) RunMode { - if s.Config.Strategy != "clotho-verified" { - return ModeStandard - } - - agent, ok := s.Agents[agentName] - if !ok { - return ModeStandard - } - if agent.DualRun { - return ModeDual - } - - // Protect critical repos with dual-run (Axiom 1). - if signal.RepoName == "core" || strings.Contains(signal.RepoName, "security") { - return ModeDual - } - - return ModeStandard -} - -// GetVerifierModel returns the model for the secondary "signed" verification run. -func (s *Spinner) GetVerifierModel(agentName string) string { - agent, ok := s.Agents[agentName] - if !ok || agent.VerifyModel == "" { - return "gemini-1.5-pro" - } - return agent.VerifyModel -} - -// Agents returns an iterator over the configured agents. -func (s *Spinner) AgentsSeq() iter.Seq2[string, AgentConfig] { - return func(yield func(string, AgentConfig) bool) { - for name, agent := range s.Agents { - if !yield(name, agent) { - return - } - } - } -} - -// FindByForgejoUser resolves a Forgejo username to the agent config key and config. -// This decouples agent naming (mythological roles) from Forgejo identity. -func (s *Spinner) FindByForgejoUser(forgejoUser string) (string, AgentConfig, bool) { - if forgejoUser == "" { - return "", AgentConfig{}, false - } - // Direct match on config key first. - if agent, ok := s.Agents[forgejoUser]; ok { - return forgejoUser, agent, true - } - // Search by ForgejoUser field. - for name, agent := range s.AgentsSeq() { - if agent.ForgejoUser != "" && agent.ForgejoUser == forgejoUser { - return name, agent, true - } - } - return "", AgentConfig{}, false -} - -// Weave compares primary and verifier outputs. Returns true if they converge. -// This is a placeholder for future semantic diff logic. -func (s *Spinner) Weave(ctx context.Context, primaryOutput, signedOutput []byte) (bool, error) { - return string(primaryOutput) == string(signedOutput), nil -} diff --git a/pkg/orchestrator/clotho_test.go b/pkg/orchestrator/clotho_test.go deleted file mode 100644 index 73ff354..0000000 --- a/pkg/orchestrator/clotho_test.go +++ /dev/null @@ -1,194 +0,0 @@ -package orchestrator - -import ( - "context" - "testing" - - "forge.lthn.ai/core/agent/pkg/jobrunner" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func newTestSpinner() *Spinner { - return NewSpinner( - ClothoConfig{ - Strategy: "clotho-verified", - ValidationThreshold: 0.85, - }, - map[string]AgentConfig{ - "claude-agent": { - Host: "claude@10.0.0.1", - Model: "opus", - Runner: "claude", - Active: true, - DualRun: false, - ForgejoUser: "claude-forge", - }, - "gemini-agent": { - Host: "localhost", - Model: "gemini-2.0-flash", - VerifyModel: "gemini-1.5-pro", - Runner: "gemini", - Active: true, - DualRun: true, - ForgejoUser: "gemini-forge", - }, - }, - ) -} - -func TestNewSpinner_Good(t *testing.T) { - spinner := newTestSpinner() - assert.NotNil(t, spinner) - assert.Equal(t, "clotho-verified", spinner.Config.Strategy) - assert.Len(t, spinner.Agents, 2) -} - -func TestDeterminePlan_Good_Standard(t *testing.T) { - spinner := newTestSpinner() - - signal := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "core-php", - } - - mode := spinner.DeterminePlan(signal, "claude-agent") - assert.Equal(t, ModeStandard, mode) -} - -func TestDeterminePlan_Good_DualRunByAgent(t *testing.T) { - spinner := newTestSpinner() - - signal := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: "some-repo", - } - - mode := spinner.DeterminePlan(signal, "gemini-agent") - assert.Equal(t, ModeDual, mode) -} - -func TestDeterminePlan_Good_DualRunByCriticalRepo(t *testing.T) { - spinner := newTestSpinner() - - tests := []struct { - name string - repoName string - expected RunMode - }{ - {name: "core repo", repoName: "core", expected: ModeDual}, - {name: "security repo", repoName: "auth-security", expected: ModeDual}, - {name: "security-audit", repoName: "security-audit", expected: ModeDual}, - {name: "regular repo", repoName: "docs", expected: ModeStandard}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - signal := &jobrunner.PipelineSignal{ - RepoOwner: "host-uk", - RepoName: tt.repoName, - } - mode := spinner.DeterminePlan(signal, "claude-agent") - assert.Equal(t, tt.expected, mode) - }) - } -} - -func TestDeterminePlan_Good_NonVerifiedStrategy(t *testing.T) { - spinner := NewSpinner( - ClothoConfig{Strategy: "direct"}, - map[string]AgentConfig{ - "agent": {Host: "localhost", DualRun: true, Active: true}, - }, - ) - - signal := &jobrunner.PipelineSignal{RepoName: "core"} - mode := spinner.DeterminePlan(signal, "agent") - assert.Equal(t, ModeStandard, mode, "non-verified strategy should always return standard") -} - -func TestDeterminePlan_Good_UnknownAgent(t *testing.T) { - spinner := newTestSpinner() - - signal := &jobrunner.PipelineSignal{RepoName: "some-repo"} - mode := spinner.DeterminePlan(signal, "nonexistent-agent") - assert.Equal(t, ModeStandard, mode, "unknown agent should return standard") -} - -func TestGetVerifierModel_Good(t *testing.T) { - spinner := newTestSpinner() - - model := spinner.GetVerifierModel("gemini-agent") - assert.Equal(t, "gemini-1.5-pro", model) -} - -func TestGetVerifierModel_Good_Default(t *testing.T) { - spinner := newTestSpinner() - - // claude-agent has no VerifyModel set. - model := spinner.GetVerifierModel("claude-agent") - assert.Equal(t, "gemini-1.5-pro", model, "should fall back to default") -} - -func TestGetVerifierModel_Good_UnknownAgent(t *testing.T) { - spinner := newTestSpinner() - - model := spinner.GetVerifierModel("unknown") - assert.Equal(t, "gemini-1.5-pro", model, "should fall back to default") -} - -func TestFindByForgejoUser_Good_DirectMatch(t *testing.T) { - spinner := newTestSpinner() - - // Direct match on config key. - name, agent, found := spinner.FindByForgejoUser("claude-agent") - assert.True(t, found) - assert.Equal(t, "claude-agent", name) - assert.Equal(t, "opus", agent.Model) -} - -func TestFindByForgejoUser_Good_ByField(t *testing.T) { - spinner := newTestSpinner() - - // Match by ForgejoUser field. - name, agent, found := spinner.FindByForgejoUser("claude-forge") - assert.True(t, found) - assert.Equal(t, "claude-agent", name) - assert.Equal(t, "opus", agent.Model) -} - -func TestFindByForgejoUser_Bad_NotFound(t *testing.T) { - spinner := newTestSpinner() - - _, _, found := spinner.FindByForgejoUser("nonexistent") - assert.False(t, found) -} - -func TestFindByForgejoUser_Bad_Empty(t *testing.T) { - spinner := newTestSpinner() - - _, _, found := spinner.FindByForgejoUser("") - assert.False(t, found) -} - -func TestWeave_Good_Matching(t *testing.T) { - spinner := newTestSpinner() - - converge, err := spinner.Weave(context.Background(), []byte("output"), []byte("output")) - require.NoError(t, err) - assert.True(t, converge) -} - -func TestWeave_Good_Diverging(t *testing.T) { - spinner := newTestSpinner() - - converge, err := spinner.Weave(context.Background(), []byte("primary"), []byte("different")) - require.NoError(t, err) - assert.False(t, converge) -} - -func TestRunModeConstants(t *testing.T) { - assert.Equal(t, RunMode("standard"), ModeStandard) - assert.Equal(t, RunMode("dual"), ModeDual) -} diff --git a/pkg/orchestrator/config.go b/pkg/orchestrator/config.go deleted file mode 100644 index 6c4cb88..0000000 --- a/pkg/orchestrator/config.go +++ /dev/null @@ -1,146 +0,0 @@ -// Package agentci provides configuration, security, and orchestration for AgentCI dispatch targets. -package orchestrator - -import ( - "maps" - "path/filepath" - - "forge.lthn.ai/core/agent/pkg/agentic" - "forge.lthn.ai/core/config" - coreerr "forge.lthn.ai/core/go-log" -) - -// AgentConfig represents a single agent machine in the config file. -type AgentConfig struct { - Host string `yaml:"host" mapstructure:"host"` - QueueDir string `yaml:"queue_dir" mapstructure:"queue_dir"` - ForgejoUser string `yaml:"forgejo_user" mapstructure:"forgejo_user"` - Model string `yaml:"model" mapstructure:"model"` // primary AI model - Runner string `yaml:"runner" mapstructure:"runner"` // runner binary: claude, codex, gemini - VerifyModel string `yaml:"verify_model" mapstructure:"verify_model"` // secondary model for dual-run - SecurityLevel string `yaml:"security_level" mapstructure:"security_level"` // low, high - Roles []string `yaml:"roles" mapstructure:"roles"` - DualRun bool `yaml:"dual_run" mapstructure:"dual_run"` - Active bool `yaml:"active" mapstructure:"active"` - ApiURL string `yaml:"api_url" mapstructure:"api_url"` // PHP agentic API base URL - ApiKey string `yaml:"api_key" mapstructure:"api_key"` // PHP agentic API key -} - -// ClothoConfig controls the orchestration strategy. -type ClothoConfig struct { - Strategy string `yaml:"strategy" mapstructure:"strategy"` // direct, clotho-verified - ValidationThreshold float64 `yaml:"validation_threshold" mapstructure:"validation_threshold"` // divergence limit (0.0-1.0) - SigningKeyPath string `yaml:"signing_key_path" mapstructure:"signing_key_path"` -} - -// LoadAgents reads agent targets from config and returns a map of AgentConfig. -// Returns an empty map (not an error) if no agents are configured. -func LoadAgents(cfg *config.Config) (map[string]AgentConfig, error) { - var agents map[string]AgentConfig - if err := cfg.Get("agentci.agents", &agents); err != nil { - return map[string]AgentConfig{}, nil - } - - // Validate and apply defaults. - for name, ac := range agents { - if !ac.Active { - continue - } - if ac.Host == "" { - return nil, coreerr.E("agentci.LoadAgents", "agent "+name+": host is required", nil) - } - if ac.QueueDir == "" { - ac.QueueDir = filepath.Join(agentic.CoreRoot(), "queue") - } - if ac.Model == "" { - ac.Model = "sonnet" - } - if ac.Runner == "" { - ac.Runner = "claude" - } - agents[name] = ac - } - - return agents, nil -} - -// LoadActiveAgents returns only active agents. -func LoadActiveAgents(cfg *config.Config) (map[string]AgentConfig, error) { - active, err := LoadAgents(cfg) - if err != nil { - return nil, err - } - maps.DeleteFunc(active, func(_ string, ac AgentConfig) bool { - return !ac.Active - }) - return active, nil -} - -// LoadClothoConfig loads the Clotho orchestrator settings. -// Returns sensible defaults if no config is present. -func LoadClothoConfig(cfg *config.Config) (ClothoConfig, error) { - var cc ClothoConfig - if err := cfg.Get("agentci.clotho", &cc); err != nil { - return ClothoConfig{ - Strategy: "direct", - ValidationThreshold: 0.85, - }, nil - } - if cc.Strategy == "" { - cc.Strategy = "direct" - } - if cc.ValidationThreshold == 0 { - cc.ValidationThreshold = 0.85 - } - return cc, nil -} - -// SaveAgent writes an agent config entry to the config file. -func SaveAgent(cfg *config.Config, name string, ac AgentConfig) error { - key := "agentci.agents." + name - data := map[string]any{ - "host": ac.Host, - "queue_dir": ac.QueueDir, - "forgejo_user": ac.ForgejoUser, - "active": ac.Active, - "dual_run": ac.DualRun, - } - if ac.Model != "" { - data["model"] = ac.Model - } - if ac.Runner != "" { - data["runner"] = ac.Runner - } - if ac.VerifyModel != "" { - data["verify_model"] = ac.VerifyModel - } - if ac.SecurityLevel != "" { - data["security_level"] = ac.SecurityLevel - } - if len(ac.Roles) > 0 { - data["roles"] = ac.Roles - } - return cfg.Set(key, data) -} - -// RemoveAgent removes an agent from the config file. -func RemoveAgent(cfg *config.Config, name string) error { - var agents map[string]AgentConfig - if err := cfg.Get("agentci.agents", &agents); err != nil { - return coreerr.E("agentci.RemoveAgent", "no agents configured", nil) - } - if _, ok := agents[name]; !ok { - return coreerr.E("agentci.RemoveAgent", "agent "+name+" not found", nil) - } - delete(agents, name) - return cfg.Set("agentci.agents", agents) -} - -// ListAgents returns all configured agents (active and inactive). -func ListAgents(cfg *config.Config) (map[string]AgentConfig, error) { - var agents map[string]AgentConfig - if err := cfg.Get("agentci.agents", &agents); err != nil { - return map[string]AgentConfig{}, nil - } - return agents, nil -} diff --git a/pkg/orchestrator/config_test.go b/pkg/orchestrator/config_test.go deleted file mode 100644 index 6ac5e44..0000000 --- a/pkg/orchestrator/config_test.go +++ /dev/null @@ -1,329 +0,0 @@ -package orchestrator - -import ( - "testing" - - "forge.lthn.ai/core/config" - "forge.lthn.ai/core/go-io" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func newTestConfig(t *testing.T, yaml string) *config.Config { - t.Helper() - m := io.NewMockMedium() - if yaml != "" { - m.Files["/tmp/test/config.yaml"] = yaml - } - cfg, err := config.New(config.WithMedium(m), config.WithPath("/tmp/test/config.yaml")) - require.NoError(t, err) - return cfg -} - -func TestLoadAgents_Good(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - darbs-claude: - host: claude@192.168.0.201 - queue_dir: /home/claude/ai-work/queue - forgejo_user: darbs-claude - model: sonnet - runner: claude - active: true -`) - agents, err := LoadAgents(cfg) - require.NoError(t, err) - require.Len(t, agents, 1) - - agent := agents["darbs-claude"] - assert.Equal(t, "claude@192.168.0.201", agent.Host) - assert.Equal(t, "/home/claude/ai-work/queue", agent.QueueDir) - assert.Equal(t, "sonnet", agent.Model) - assert.Equal(t, "claude", agent.Runner) -} - -func TestLoadAgents_Good_MultipleAgents(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - darbs-claude: - host: claude@192.168.0.201 - queue_dir: /home/claude/ai-work/queue - active: true - local-codex: - host: localhost - queue_dir: /home/claude/ai-work/queue - runner: codex - active: true -`) - agents, err := LoadAgents(cfg) - require.NoError(t, err) - assert.Len(t, agents, 2) - assert.Contains(t, agents, "darbs-claude") - assert.Contains(t, agents, "local-codex") -} - -func TestLoadAgents_Good_SkipsInactive(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - active-agent: - host: claude@10.0.0.1 - active: true - offline-agent: - host: claude@10.0.0.2 - active: false -`) - agents, err := LoadAgents(cfg) - require.NoError(t, err) - // Both are returned, but only active-agent has defaults applied. - assert.Len(t, agents, 2) - assert.Contains(t, agents, "active-agent") -} - -func TestLoadActiveAgents_Good(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - active-agent: - host: claude@10.0.0.1 - active: true - offline-agent: - host: claude@10.0.0.2 - active: false -`) - active, err := LoadActiveAgents(cfg) - require.NoError(t, err) - assert.Len(t, active, 1) - assert.Contains(t, active, "active-agent") -} - -func TestLoadAgents_Good_Defaults(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - minimal: - host: claude@10.0.0.1 - active: true -`) - agents, err := LoadAgents(cfg) - require.NoError(t, err) - require.Len(t, agents, 1) - - agent := agents["minimal"] - assert.Equal(t, "/home/claude/ai-work/queue", agent.QueueDir) - assert.Equal(t, "sonnet", agent.Model) - assert.Equal(t, "claude", agent.Runner) -} - -func TestLoadAgents_Good_NoConfig(t *testing.T) { - cfg := newTestConfig(t, "") - agents, err := LoadAgents(cfg) - require.NoError(t, err) - assert.Empty(t, agents) -} - -func TestLoadAgents_Bad_MissingHost(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - broken: - queue_dir: /tmp - active: true -`) - _, err := LoadAgents(cfg) - assert.Error(t, err) - assert.Contains(t, err.Error(), "host is required") -} - -func TestLoadAgents_Good_WithDualRun(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - gemini-agent: - host: localhost - runner: gemini - model: gemini-2.0-flash - verify_model: gemini-1.5-pro - dual_run: true - active: true -`) - agents, err := LoadAgents(cfg) - require.NoError(t, err) - - agent := agents["gemini-agent"] - assert.Equal(t, "gemini", agent.Runner) - assert.Equal(t, "gemini-2.0-flash", agent.Model) - assert.Equal(t, "gemini-1.5-pro", agent.VerifyModel) - assert.True(t, agent.DualRun) -} - -func TestLoadClothoConfig_Good(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - clotho: - strategy: clotho-verified - validation_threshold: 0.9 - signing_key_path: /etc/core/keys/clotho.pub -`) - cc, err := LoadClothoConfig(cfg) - require.NoError(t, err) - assert.Equal(t, "clotho-verified", cc.Strategy) - assert.Equal(t, 0.9, cc.ValidationThreshold) - assert.Equal(t, "/etc/core/keys/clotho.pub", cc.SigningKeyPath) -} - -func TestLoadClothoConfig_Good_Defaults(t *testing.T) { - cfg := newTestConfig(t, "") - cc, err := LoadClothoConfig(cfg) - require.NoError(t, err) - assert.Equal(t, "direct", cc.Strategy) - assert.Equal(t, 0.85, cc.ValidationThreshold) -} - -func TestSaveAgent_Good(t *testing.T) { - cfg := newTestConfig(t, "") - - err := SaveAgent(cfg, "new-agent", AgentConfig{ - Host: "claude@10.0.0.5", - QueueDir: "/home/claude/ai-work/queue", - ForgejoUser: "new-agent", - Model: "haiku", - Runner: "claude", - Active: true, - }) - require.NoError(t, err) - - agents, err := ListAgents(cfg) - require.NoError(t, err) - require.Contains(t, agents, "new-agent") - assert.Equal(t, "claude@10.0.0.5", agents["new-agent"].Host) - assert.Equal(t, "haiku", agents["new-agent"].Model) -} - -func TestSaveAgent_Good_WithDualRun(t *testing.T) { - cfg := newTestConfig(t, "") - - err := SaveAgent(cfg, "verified-agent", AgentConfig{ - Host: "claude@10.0.0.5", - Model: "gemini-2.0-flash", - VerifyModel: "gemini-1.5-pro", - DualRun: true, - Active: true, - }) - require.NoError(t, err) - - agents, err := ListAgents(cfg) - require.NoError(t, err) - require.Contains(t, agents, "verified-agent") - assert.True(t, agents["verified-agent"].DualRun) -} - -func TestSaveAgent_Good_OmitsEmptyOptionals(t *testing.T) { - cfg := newTestConfig(t, "") - - err := SaveAgent(cfg, "minimal", AgentConfig{ - Host: "claude@10.0.0.1", - Active: true, - }) - require.NoError(t, err) - - agents, err := ListAgents(cfg) - require.NoError(t, err) - assert.Contains(t, agents, "minimal") -} - -func TestRemoveAgent_Good(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - to-remove: - host: claude@10.0.0.1 - active: true - to-keep: - host: claude@10.0.0.2 - active: true -`) - err := RemoveAgent(cfg, "to-remove") - require.NoError(t, err) - - agents, err := ListAgents(cfg) - require.NoError(t, err) - assert.NotContains(t, agents, "to-remove") - assert.Contains(t, agents, "to-keep") -} - -func TestRemoveAgent_Bad_NotFound(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - existing: - host: claude@10.0.0.1 - active: true -`) - err := RemoveAgent(cfg, "nonexistent") - assert.Error(t, err) - assert.Contains(t, err.Error(), "not found") -} - -func TestRemoveAgent_Bad_NoAgents(t *testing.T) { - cfg := newTestConfig(t, "") - err := RemoveAgent(cfg, "anything") - assert.Error(t, err) - assert.Contains(t, err.Error(), "no agents configured") -} - -func TestListAgents_Good(t *testing.T) { - cfg := newTestConfig(t, ` -agentci: - agents: - agent-a: - host: claude@10.0.0.1 - active: true - agent-b: - host: claude@10.0.0.2 - active: false -`) - agents, err := ListAgents(cfg) - require.NoError(t, err) - assert.Len(t, agents, 2) - assert.True(t, agents["agent-a"].Active) - assert.False(t, agents["agent-b"].Active) -} - -func TestListAgents_Good_Empty(t *testing.T) { - cfg := newTestConfig(t, "") - agents, err := ListAgents(cfg) - require.NoError(t, err) - assert.Empty(t, agents) -} - -func TestRoundTrip_SaveThenLoad(t *testing.T) { - cfg := newTestConfig(t, "") - - err := SaveAgent(cfg, "alpha", AgentConfig{ - Host: "claude@alpha", - QueueDir: "/home/claude/work/queue", - ForgejoUser: "alpha-bot", - Model: "opus", - Runner: "claude", - Active: true, - }) - require.NoError(t, err) - - err = SaveAgent(cfg, "beta", AgentConfig{ - Host: "claude@beta", - ForgejoUser: "beta-bot", - Runner: "codex", - Active: true, - }) - require.NoError(t, err) - - agents, err := LoadActiveAgents(cfg) - require.NoError(t, err) - assert.Len(t, agents, 2) - assert.Equal(t, "claude@alpha", agents["alpha"].Host) - assert.Equal(t, "opus", agents["alpha"].Model) - assert.Equal(t, "codex", agents["beta"].Runner) -} diff --git a/pkg/orchestrator/security.go b/pkg/orchestrator/security.go deleted file mode 100644 index be51056..0000000 --- a/pkg/orchestrator/security.go +++ /dev/null @@ -1,58 +0,0 @@ -package orchestrator - -import ( - "context" - "os/exec" - "path/filepath" - "regexp" - "strings" - - coreerr "forge.lthn.ai/core/go-log" -) - -var safeNameRegex = regexp.MustCompile(`^[a-zA-Z0-9\-\_\.]+$`) - -// SanitizePath ensures a filename or directory name is safe and prevents path traversal. -// Returns filepath.Base of the input after validation. -func SanitizePath(input string) (string, error) { - base := filepath.Base(input) - if !safeNameRegex.MatchString(base) { - return "", coreerr.E("agentci.SanitizePath", "invalid characters in path element: "+input, nil) - } - if base == "." || base == ".." || base == "/" { - return "", coreerr.E("agentci.SanitizePath", "invalid path element: "+base, nil) - } - return base, nil -} - -// EscapeShellArg wraps a string in single quotes for safe remote shell insertion. -// Prefer exec.Command arguments over constructing shell strings where possible. -func EscapeShellArg(arg string) string { - return "'" + strings.ReplaceAll(arg, "'", "'\\''") + "'" -} - -// SecureSSHCommand creates an SSH exec.Cmd with strict host key checking and batch mode. -// Deprecated: Use SecureSSHCommandContext for context-aware cancellation. -func SecureSSHCommand(host string, remoteCmd string) *exec.Cmd { - return SecureSSHCommandContext(context.Background(), host, remoteCmd) -} - -// SecureSSHCommandContext creates an SSH exec.Cmd with context support for cancellation, -// strict host key checking, and batch mode. -func SecureSSHCommandContext(ctx context.Context, host string, remoteCmd string) *exec.Cmd { - return exec.CommandContext(ctx, "ssh", - "-o", "StrictHostKeyChecking=yes", - "-o", "BatchMode=yes", - "-o", "ConnectTimeout=10", - host, - remoteCmd, - ) -} - -// MaskToken returns a masked version of a token for safe logging. -func MaskToken(token string) string { - if len(token) < 8 { - return "*****" - } - return token[:4] + "****" + token[len(token)-4:] -} diff --git a/pkg/orchestrator/security_test.go b/pkg/orchestrator/security_test.go deleted file mode 100644 index 9844135..0000000 --- a/pkg/orchestrator/security_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package orchestrator - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSanitizePath_Good(t *testing.T) { - tests := []struct { - name string - input string - expected string - }{ - {name: "simple name", input: "myfile.txt", expected: "myfile.txt"}, - {name: "with hyphen", input: "my-file", expected: "my-file"}, - {name: "with underscore", input: "my_file", expected: "my_file"}, - {name: "with dots", input: "file.tar.gz", expected: "file.tar.gz"}, - {name: "strips directory", input: "/path/to/file.txt", expected: "file.txt"}, - {name: "alphanumeric", input: "abc123", expected: "abc123"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := SanitizePath(tt.input) - require.NoError(t, err) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestSanitizePath_Good_StripsDirTraversal(t *testing.T) { - // filepath.Base("../secret") returns "secret" which is safe. - result, err := SanitizePath("../secret") - require.NoError(t, err) - assert.Equal(t, "secret", result, "directory traversal component stripped by filepath.Base") -} - -func TestSanitizePath_Bad(t *testing.T) { - tests := []struct { - name string - input string - }{ - {name: "spaces", input: "my file"}, - {name: "special chars", input: "file;rm -rf"}, - {name: "pipe", input: "file|cmd"}, - {name: "backtick", input: "file`cmd`"}, - {name: "dollar", input: "file$var"}, - {name: "single dot", input: "."}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := SanitizePath(tt.input) - assert.Error(t, err) - }) - } -} - -func TestEscapeShellArg_Good(t *testing.T) { - tests := []struct { - name string - input string - expected string - }{ - {name: "simple string", input: "hello", expected: "'hello'"}, - {name: "with spaces", input: "hello world", expected: "'hello world'"}, - {name: "empty string", input: "", expected: "''"}, - {name: "with single quote", input: "it's", expected: "'it'\\''s'"}, - {name: "multiple single quotes", input: "a'b'c", expected: "'a'\\''b'\\''c'"}, - {name: "with special chars", input: "$(rm -rf /)", expected: "'$(rm -rf /)'"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := EscapeShellArg(tt.input) - assert.Equal(t, tt.expected, result) - }) - } -} - -func TestSecureSSHCommand_Good(t *testing.T) { - cmd := SecureSSHCommand("claude@10.0.0.1", "ls -la /tmp") - - assert.Equal(t, "ssh", cmd.Path[len(cmd.Path)-3:]) - args := cmd.Args - assert.Contains(t, args, "-o") - assert.Contains(t, args, "StrictHostKeyChecking=yes") - assert.Contains(t, args, "BatchMode=yes") - assert.Contains(t, args, "ConnectTimeout=10") - assert.Contains(t, args, "claude@10.0.0.1") - assert.Contains(t, args, "ls -la /tmp") -} - -func TestMaskToken_Good(t *testing.T) { - tests := []struct { - name string - token string - expected string - }{ - {name: "normal token", token: "abcdefghijkl", expected: "abcd****ijkl"}, - {name: "exactly 8 chars", token: "12345678", expected: "1234****5678"}, - {name: "short token", token: "abc", expected: "*****"}, - {name: "empty token", token: "", expected: "*****"}, - {name: "7 chars", token: "1234567", expected: "*****"}, - {name: "long token", token: "ghp_1234567890abcdef", expected: "ghp_****cdef"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := MaskToken(tt.token) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/pkg/plugin/contract_test.go b/pkg/plugin/contract_test.go deleted file mode 100644 index b2975b9..0000000 --- a/pkg/plugin/contract_test.go +++ /dev/null @@ -1,488 +0,0 @@ -// Package plugin verifies the Claude Code plugin contract. Every plugin in the -// marketplace must satisfy the structural rules Claude Code expects: valid JSON -// manifests, commands with YAML frontmatter, executable scripts, and well-formed -// hooks. These tests catch breakage before a tag ships. -package plugin - -import ( - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// ── types ────────────────────────────────────────────────────────── - -type marketplace struct { - Name string `json:"name"` - Description string `json:"description"` - Owner owner `json:"owner"` - Plugins []plugin `json:"plugins"` -} - -type owner struct { - Name string `json:"name"` - Email string `json:"email"` -} - -type plugin struct { - Name string `json:"name"` - Source string `json:"source"` - Description string `json:"description"` - Version string `json:"version"` -} - -type pluginManifest struct { - Name string `json:"name"` - Description string `json:"description"` - Version string `json:"version"` -} - -type hooksFile struct { - Schema string `json:"$schema"` - Hooks map[string]json.RawMessage `json:"hooks"` -} - -type hookEntry struct { - Matcher string `json:"matcher"` - Hooks []hookDef `json:"hooks"` - Description string `json:"description"` -} - -type hookDef struct { - Type string `json:"type"` - Command string `json:"command"` -} - -// ── helpers ──────────────────────────────────────────────────────── - -func repoRoot(t *testing.T) string { - t.Helper() - dir, err := os.Getwd() - require.NoError(t, err) - for { - if _, err := os.Stat(filepath.Join(dir, ".claude-plugin", "marketplace.json")); err == nil { - return dir - } - parent := filepath.Dir(dir) - require.NotEqual(t, parent, dir, "marketplace.json not found") - dir = parent - } -} - -func loadMarketplace(t *testing.T) (marketplace, string) { - t.Helper() - root := repoRoot(t) - data, err := os.ReadFile(filepath.Join(root, ".claude-plugin", "marketplace.json")) - require.NoError(t, err) - var mp marketplace - require.NoError(t, json.Unmarshal(data, &mp)) - return mp, root -} - -// validHookEvents are the hook events Claude Code supports. -var validHookEvents = map[string]bool{ - "PreToolUse": true, - "PostToolUse": true, - "Stop": true, - "SubagentStop": true, - "SessionStart": true, - "SessionEnd": true, - "UserPromptSubmit": true, - "PreCompact": true, - "Notification": true, -} - -// ── Marketplace contract ─────────────────────────────────────────── - -func TestMarketplace_Valid(t *testing.T) { - mp, _ := loadMarketplace(t) - assert.NotEmpty(t, mp.Name, "marketplace must have a name") - assert.NotEmpty(t, mp.Description, "marketplace must have a description") - assert.NotEmpty(t, mp.Owner.Name, "marketplace must have an owner name") - assert.NotEmpty(t, mp.Plugins, "marketplace must list at least one plugin") -} - -func TestMarketplace_PluginsHaveRequiredFields(t *testing.T) { - mp, _ := loadMarketplace(t) - for _, p := range mp.Plugins { - assert.NotEmpty(t, p.Name, "plugin must have a name") - assert.NotEmpty(t, p.Source, "plugin %s must have a source path", p.Name) - assert.NotEmpty(t, p.Description, "plugin %s must have a description", p.Name) - assert.NotEmpty(t, p.Version, "plugin %s must have a version", p.Name) - } -} - -func TestMarketplace_UniquePluginNames(t *testing.T) { - mp, _ := loadMarketplace(t) - seen := map[string]bool{} - for _, p := range mp.Plugins { - assert.False(t, seen[p.Name], "duplicate plugin name: %s", p.Name) - seen[p.Name] = true - } -} - -// ── Plugin directory structure ───────────────────────────────────── - -func TestPlugin_DirectoryExists(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - pluginDir := filepath.Join(root, p.Source) - info, err := os.Stat(pluginDir) - require.NoError(t, err, "plugin %s: source dir %s must exist", p.Name, p.Source) - assert.True(t, info.IsDir(), "plugin %s: source must be a directory", p.Name) - } -} - -func TestPlugin_HasManifest(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - manifestPath := filepath.Join(root, p.Source, ".claude-plugin", "plugin.json") - data, err := os.ReadFile(manifestPath) - require.NoError(t, err, "plugin %s: must have .claude-plugin/plugin.json", p.Name) - - var manifest pluginManifest - require.NoError(t, json.Unmarshal(data, &manifest), "plugin %s: invalid plugin.json", p.Name) - assert.NotEmpty(t, manifest.Name, "plugin %s: manifest must have a name", p.Name) - assert.NotEmpty(t, manifest.Description, "plugin %s: manifest must have a description", p.Name) - assert.NotEmpty(t, manifest.Version, "plugin %s: manifest must have a version", p.Name) - } -} - -// ── Commands contract ────────────────────────────────────────────── - -func TestPlugin_CommandsAreMarkdown(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - cmdDir := filepath.Join(root, p.Source, "commands") - entries, err := os.ReadDir(cmdDir) - if os.IsNotExist(err) { - continue // commands dir is optional - } - require.NoError(t, err, "plugin %s: failed to read commands dir", p.Name) - - for _, entry := range entries { - if entry.IsDir() { - continue - } - assert.True(t, strings.HasSuffix(entry.Name(), ".md"), - "plugin %s: command %s must be a .md file", p.Name, entry.Name()) - } - } -} - -func TestPlugin_CommandsHaveFrontmatter(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - cmdDir := filepath.Join(root, p.Source, "commands") - entries, err := os.ReadDir(cmdDir) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - for _, entry := range entries { - if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".md") { - continue - } - data, err := os.ReadFile(filepath.Join(cmdDir, entry.Name())) - require.NoError(t, err) - - content := string(data) - assert.True(t, strings.HasPrefix(content, "---"), - "plugin %s: command %s must start with YAML frontmatter (---)", p.Name, entry.Name()) - - // Must have closing frontmatter - parts := strings.SplitN(content[3:], "---", 2) - assert.True(t, len(parts) >= 2, - "plugin %s: command %s must have closing frontmatter (---)", p.Name, entry.Name()) - - // Frontmatter must contain name: - assert.Contains(t, parts[0], "name:", - "plugin %s: command %s frontmatter must contain 'name:'", p.Name, entry.Name()) - } - } -} - -// ── Hooks contract ───────────────────────────────────────────────── - -func TestPlugin_HooksFileValid(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - hooksPath := filepath.Join(root, p.Source, "hooks.json") - data, err := os.ReadFile(hooksPath) - if os.IsNotExist(err) { - continue // hooks.json is optional - } - require.NoError(t, err, "plugin %s: failed to read hooks.json", p.Name) - - var hf hooksFile - require.NoError(t, json.Unmarshal(data, &hf), "plugin %s: invalid hooks.json", p.Name) - assert.NotEmpty(t, hf.Hooks, "plugin %s: hooks.json must define at least one event", p.Name) - } -} - -func TestPlugin_HooksUseValidEvents(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - hooksPath := filepath.Join(root, p.Source, "hooks.json") - data, err := os.ReadFile(hooksPath) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - var hf hooksFile - require.NoError(t, json.Unmarshal(data, &hf)) - - for event := range hf.Hooks { - assert.True(t, validHookEvents[event], - "plugin %s: unknown hook event %q (valid: %v)", p.Name, event, validHookEvents) - } - } -} - -func TestPlugin_HookScriptsExist(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - hooksPath := filepath.Join(root, p.Source, "hooks.json") - data, err := os.ReadFile(hooksPath) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - var hf hooksFile - require.NoError(t, json.Unmarshal(data, &hf)) - - pluginRoot := filepath.Join(root, p.Source) - - for event, raw := range hf.Hooks { - var entries []hookEntry - require.NoError(t, json.Unmarshal(raw, &entries), - "plugin %s: failed to parse %s entries", p.Name, event) - - for _, entry := range entries { - for _, h := range entry.Hooks { - if h.Type != "command" { - continue - } - // Resolve ${CLAUDE_PLUGIN_ROOT} to the plugin source directory - cmd := strings.ReplaceAll(h.Command, "${CLAUDE_PLUGIN_ROOT}", pluginRoot) - // Extract the script path (first arg, before any flags) - scriptPath := strings.Fields(cmd)[0] - _, err := os.Stat(scriptPath) - assert.NoError(t, err, - "plugin %s: hook script %s does not exist (event: %s)", p.Name, h.Command, event) - } - } - } - } -} - -func TestPlugin_HookScriptsExecutable(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - hooksPath := filepath.Join(root, p.Source, "hooks.json") - data, err := os.ReadFile(hooksPath) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - var hf hooksFile - require.NoError(t, json.Unmarshal(data, &hf)) - - pluginRoot := filepath.Join(root, p.Source) - - for event, raw := range hf.Hooks { - var entries []hookEntry - require.NoError(t, json.Unmarshal(raw, &entries)) - - for _, entry := range entries { - for _, h := range entry.Hooks { - if h.Type != "command" { - continue - } - cmd := strings.ReplaceAll(h.Command, "${CLAUDE_PLUGIN_ROOT}", pluginRoot) - scriptPath := strings.Fields(cmd)[0] - info, err := os.Stat(scriptPath) - if err != nil { - continue // Already caught by ScriptsExist test - } - assert.NotZero(t, info.Mode()&0111, - "plugin %s: hook script %s must be executable (event: %s)", p.Name, h.Command, event) - } - } - } - } -} - -// ── Scripts contract ─────────────────────────────────────────────── - -func TestPlugin_AllScriptsExecutable(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - scriptsDir := filepath.Join(root, p.Source, "scripts") - entries, err := os.ReadDir(scriptsDir) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - for _, entry := range entries { - if entry.IsDir() { - continue - } - if !strings.HasSuffix(entry.Name(), ".sh") { - continue - } - info, err := entry.Info() - require.NoError(t, err) - assert.NotZero(t, info.Mode()&0111, - "plugin %s: script %s must be executable", p.Name, entry.Name()) - } - } -} - -func TestPlugin_ScriptsHaveShebang(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - scriptsDir := filepath.Join(root, p.Source, "scripts") - entries, err := os.ReadDir(scriptsDir) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - for _, entry := range entries { - if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sh") { - continue - } - data, err := os.ReadFile(filepath.Join(scriptsDir, entry.Name())) - require.NoError(t, err) - assert.True(t, strings.HasPrefix(string(data), "#!"), - "plugin %s: script %s must start with a shebang (#!)", p.Name, entry.Name()) - } - } -} - -// ── Skills contract ──────────────────────────────────────────────── - -func TestPlugin_SkillsHaveSkillMd(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - skillsDir := filepath.Join(root, p.Source, "skills") - entries, err := os.ReadDir(skillsDir) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - for _, entry := range entries { - if !entry.IsDir() { - continue - } - skillMd := filepath.Join(skillsDir, entry.Name(), "SKILL.md") - _, err := os.Stat(skillMd) - assert.NoError(t, err, - "plugin %s: skill %s must have a SKILL.md", p.Name, entry.Name()) - } - } -} - -func TestPlugin_SkillScriptsExecutable(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - skillsDir := filepath.Join(root, p.Source, "skills") - entries, err := os.ReadDir(skillsDir) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - for _, entry := range entries { - if !entry.IsDir() { - continue - } - skillDir := filepath.Join(skillsDir, entry.Name()) - scripts, _ := os.ReadDir(skillDir) - for _, s := range scripts { - if s.IsDir() || !strings.HasSuffix(s.Name(), ".sh") { - continue - } - info, err := s.Info() - require.NoError(t, err) - assert.NotZero(t, info.Mode()&0111, - "plugin %s: skill script %s/%s must be executable", p.Name, entry.Name(), s.Name()) - } - } - } -} - -// ── Cross-references ─────────────────────────────────────────────── - -func TestPlugin_CollectionScriptsExecutable(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - collDir := filepath.Join(root, p.Source, "collection") - entries, err := os.ReadDir(collDir) - if os.IsNotExist(err) { - continue - } - require.NoError(t, err) - - for _, entry := range entries { - if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sh") { - continue - } - info, err := entry.Info() - require.NoError(t, err) - assert.NotZero(t, info.Mode()&0111, - "plugin %s: collection script %s must be executable", p.Name, entry.Name()) - } - } -} - -func TestMarketplace_SourcesMatchDirectories(t *testing.T) { - mp, root := loadMarketplace(t) - - // Every directory in claude/ should be listed in marketplace - claudeDir := filepath.Join(root, "claude") - entries, err := os.ReadDir(claudeDir) - require.NoError(t, err) - - pluginNames := map[string]bool{} - for _, p := range mp.Plugins { - pluginNames[p.Name] = true - } - - for _, entry := range entries { - if !entry.IsDir() { - continue - } - assert.True(t, pluginNames[entry.Name()], - "directory claude/%s exists but is not listed in marketplace.json", entry.Name()) - } -} - -func TestMarketplace_VersionConsistency(t *testing.T) { - mp, root := loadMarketplace(t) - for _, p := range mp.Plugins { - manifestPath := filepath.Join(root, p.Source, ".claude-plugin", "plugin.json") - data, err := os.ReadFile(manifestPath) - if err != nil { - continue // Already caught by HasManifest test - } - var manifest pluginManifest - if err := json.Unmarshal(data, &manifest); err != nil { - continue - } - assert.Equal(t, p.Version, manifest.Version, - "plugin %s: marketplace version %q != manifest version %q", p.Name, p.Version, manifest.Version) - } -} diff --git a/pkg/workspace/contract_test.go b/pkg/workspace/contract_test.go deleted file mode 100644 index 3bcd274..0000000 --- a/pkg/workspace/contract_test.go +++ /dev/null @@ -1,270 +0,0 @@ -// Package workspace verifies the workspace contract defined by the original -// php-devops wishlist is fully implemented. This test loads the real repos.yaml -// shipped with core/agent and validates every aspect of the specification. -package workspace - -import ( - "os" - "path/filepath" - "testing" - - "forge.lthn.ai/core/go-io" - "forge.lthn.ai/core/go-scm/repos" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gopkg.in/yaml.v3" -) - -// repoRoot returns the absolute path to the core/agent repo root. -func repoRoot(t *testing.T) string { - t.Helper() - // Walk up from this test file to find repos.yaml. - dir, err := os.Getwd() - require.NoError(t, err) - for { - if _, err := os.Stat(filepath.Join(dir, "repos.yaml")); err == nil { - return dir - } - parent := filepath.Dir(dir) - require.NotEqual(t, parent, dir, "repos.yaml not found") - dir = parent - } -} - -// ── repos.yaml contract ──────────────────────────────────────────── - -func TestContract_ReposYAML_Loads(t *testing.T) { - root := repoRoot(t) - path := filepath.Join(root, "repos.yaml") - - reg, err := repos.LoadRegistry(io.Local, path) - require.NoError(t, err) - require.NotNil(t, reg) - - assert.Equal(t, 1, reg.Version, "repos.yaml must declare version: 1") - assert.NotEmpty(t, reg.Org, "repos.yaml must declare an org") - assert.NotEmpty(t, reg.BasePath, "repos.yaml must declare base_path") -} - -func TestContract_ReposYAML_HasRequiredFields(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - require.NotEmpty(t, reg.Repos, "repos.yaml must define at least one repo") - - for name, repo := range reg.Repos { - assert.NotEmpty(t, repo.Type, "%s: must have a type", name) - assert.NotEmpty(t, repo.Description, "%s: must have a description", name) - } -} - -func TestContract_ReposYAML_ValidTypes(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - validTypes := map[string]bool{ - "foundation": true, - "module": true, - "product": true, - "template": true, - "meta": true, - } - - for name, repo := range reg.Repos { - assert.True(t, validTypes[repo.Type], "%s: invalid type %q", name, repo.Type) - } -} - -func TestContract_ReposYAML_DependenciesExist(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - for name, repo := range reg.Repos { - for _, dep := range repo.DependsOn { - _, ok := reg.Get(dep) - assert.True(t, ok, "%s: depends on %q which is not in repos.yaml", name, dep) - } - } -} - -func TestContract_ReposYAML_TopologicalOrder(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - order, err := reg.TopologicalOrder() - require.NoError(t, err, "dependency graph must be acyclic") - assert.Equal(t, len(reg.Repos), len(order), "topological order must include all repos") - - // Verify ordering: every dependency appears before its dependant. - seen := map[string]bool{} - for _, repo := range order { - for _, dep := range repo.DependsOn { - assert.True(t, seen[dep], "%s appears before its dependency %s", repo.Name, dep) - } - seen[repo.Name] = true - } -} - -func TestContract_ReposYAML_HasFoundation(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - foundations := reg.ByType("foundation") - assert.NotEmpty(t, foundations, "repos.yaml must have at least one foundation package") -} - -func TestContract_ReposYAML_Defaults(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - assert.NotEmpty(t, reg.Defaults.Branch, "defaults must specify a branch") - assert.NotEmpty(t, reg.Defaults.License, "defaults must specify a licence") -} - -func TestContract_ReposYAML_MetaDoesNotCloneSelf(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - for name, repo := range reg.Repos { - if repo.Type == "meta" && repo.Clone != nil && !*repo.Clone { - // Meta repos with clone: false are correct. - continue - } - if repo.Type == "meta" { - t.Logf("%s: meta repo should set clone: false", name) - } - } -} - -func TestContract_ReposYAML_ProductsHaveDomain(t *testing.T) { - root := repoRoot(t) - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - for name, repo := range reg.Repos { - if repo.Type == "product" && repo.Domain != "" { - // Products with domains are properly configured. - assert.Contains(t, repo.Domain, ".", "%s: domain should be a valid hostname", name) - } - } -} - -// ── workspace.yaml contract ──────────────────────────────────────── - -type workspaceConfig struct { - Version int `yaml:"version"` - Active string `yaml:"active"` - DefaultOnly []string `yaml:"default_only"` - PackagesDir string `yaml:"packages_dir"` - Settings map[string]any `yaml:"settings"` -} - -func TestContract_WorkspaceYAML_Loads(t *testing.T) { - root := repoRoot(t) - path := filepath.Join(root, ".core", "workspace.yaml") - - data, err := os.ReadFile(path) - require.NoError(t, err, ".core/workspace.yaml must exist") - - var ws workspaceConfig - require.NoError(t, yaml.Unmarshal(data, &ws)) - - assert.Equal(t, 1, ws.Version, "workspace.yaml must declare version: 1") - assert.NotEmpty(t, ws.Active, "workspace.yaml must declare an active package") - assert.NotEmpty(t, ws.PackagesDir, "workspace.yaml must declare packages_dir") -} - -func TestContract_WorkspaceYAML_ActiveInRegistry(t *testing.T) { - root := repoRoot(t) - - // Load workspace config. - data, err := os.ReadFile(filepath.Join(root, ".core", "workspace.yaml")) - require.NoError(t, err) - var ws workspaceConfig - require.NoError(t, yaml.Unmarshal(data, &ws)) - - // Load repos registry. - reg, err := repos.LoadRegistry(io.Local, filepath.Join(root, "repos.yaml")) - require.NoError(t, err) - - _, ok := reg.Get(ws.Active) - assert.True(t, ok, "workspace.yaml active package %q must exist in repos.yaml", ws.Active) -} - -// ── .core/ folder spec contract ──────────────────────────────────── - -func TestContract_CoreFolder_Exists(t *testing.T) { - root := repoRoot(t) - info, err := os.Stat(filepath.Join(root, ".core")) - require.NoError(t, err, ".core/ directory must exist") - assert.True(t, info.IsDir()) -} - -func TestContract_CoreFolder_HasSpec(t *testing.T) { - root := repoRoot(t) - _, err := os.Stat(filepath.Join(root, ".core", "docs", "core-folder-spec.md")) - assert.NoError(t, err, ".core/docs/core-folder-spec.md must exist") -} - -// ── Setup scripts contract ───────────────────────────────────────── - -func TestContract_SetupScript_Exists(t *testing.T) { - root := repoRoot(t) - _, err := os.Stat(filepath.Join(root, "setup.sh")) - assert.NoError(t, err, "setup.sh must exist at repo root") -} - -func TestContract_SetupScript_Executable(t *testing.T) { - root := repoRoot(t) - info, err := os.Stat(filepath.Join(root, "setup.sh")) - if err != nil { - t.Skip("setup.sh not found") - } - assert.NotZero(t, info.Mode()&0111, "setup.sh must be executable") -} - -func TestContract_InstallScripts_Exist(t *testing.T) { - root := repoRoot(t) - scripts := []string{ - "scripts/install-deps.sh", - "scripts/install-core.sh", - } - for _, s := range scripts { - _, err := os.Stat(filepath.Join(root, s)) - assert.NoError(t, err, "%s must exist", s) - } -} - -// ── Claude plugins contract ──────────────────────────────────────── - -func TestContract_Marketplace_Exists(t *testing.T) { - root := repoRoot(t) - _, err := os.Stat(filepath.Join(root, ".claude-plugin", "marketplace.json")) - assert.NoError(t, err, ".claude-plugin/marketplace.json must exist for plugin distribution") -} - -func TestContract_Plugins_HaveManifests(t *testing.T) { - root := repoRoot(t) - pluginDir := filepath.Join(root, "claude") - - entries, err := os.ReadDir(pluginDir) - if err != nil { - t.Skip("claude/ directory not found") - } - - for _, entry := range entries { - if !entry.IsDir() { - continue - } - manifest := filepath.Join(pluginDir, entry.Name(), ".claude-plugin", "plugin.json") - _, err := os.Stat(manifest) - assert.NoError(t, err, "claude/%s must have .claude-plugin/plugin.json", entry.Name()) - } -}