From 3ccb67bddd7f1c54183ee1a8b065d0a8836a351f Mon Sep 17 00:00:00 2001
From: Claude
Date: Tue, 10 Feb 2026 02:59:17 +0000
Subject: [PATCH] feat(agentci): rate limiting and native Go dispatch runner
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds pkg/ratelimit for Gemini API rate limiting with sliding window
(RPM/TPM/RPD), persistent state, and token counting. Replaces the bash
agent-runner.sh with a native Go implementation under
`core ai dispatch {run,watch,status}` for local queue processing.

Rate limiting:
- Per-model quotas (RPM, TPM, RPD) with 1-minute sliding window
- WaitForCapacity blocks until capacity available or context cancelled
- Persistent state in ~/.core/ratelimits.yaml
- Default quotas for Gemini 3 Pro/Flash, 2.5 Pro, 2.0 Flash/Lite
- CountTokens helper calls Google tokenizer API
- CLI: core ai ratelimits {show,reset,count,config,check}

Dispatch runner:
- core ai dispatch run — process single ticket from queue
- core ai dispatch watch — daemon mode with configurable interval
- core ai dispatch status — show queue/active/done counts
- Supports claude/codex/gemini runners with rate-limited Gemini
- File-based locking with stale PID detection
- Completion handler updates issue labels on success/failure

Closes #42

Co-Authored-By: Claude Opus 4.6
---
 internal/cmd/ai/cmd_commands.go       |   6 +
 internal/cmd/ai/cmd_dispatch.go       | 498 ++++++++++++++++++++++++++
 internal/cmd/ai/cmd_ratelimits.go     | 213 +++++++++++
 internal/cmd/ai/ratelimit_dispatch.go |  49 +++
 pkg/jobrunner/handlers/completion.go  |  87 +++++
 pkg/ratelimit/ratelimit.go            | 382 ++++++++++++++++++++
 pkg/ratelimit/ratelimit_test.go       | 176 +++++++++
 7 files changed, 1411 insertions(+)
 create mode 100644 internal/cmd/ai/cmd_dispatch.go
 create mode 100644 internal/cmd/ai/cmd_ratelimits.go
 create mode 100644 internal/cmd/ai/ratelimit_dispatch.go
 create mode 100644 pkg/jobrunner/handlers/completion.go
 create mode 100644 pkg/ratelimit/ratelimit.go
 create mode 100644 pkg/ratelimit/ratelimit_test.go

diff --git a/internal/cmd/ai/cmd_commands.go b/internal/cmd/ai/cmd_commands.go
index 68c3162..5679c57 100644
--- a/internal/cmd/ai/cmd_commands.go
+++ b/internal/cmd/ai/cmd_commands.go
@@ -69,6 +69,12 @@ func initCommands() {
 
 	// Add agent management commands (core ai agent ...)
 	AddAgentCommands(aiCmd)
+
+	// Add rate limit management commands (core ai ratelimits ...)
+	AddRateLimitCommands(aiCmd)
+
+	// Add dispatch commands (core ai dispatch run/watch/status)
+	AddDispatchCommands(aiCmd)
 }
 
 // AddAICommands registers the 'ai' command and all subcommands.
diff --git a/internal/cmd/ai/cmd_dispatch.go b/internal/cmd/ai/cmd_dispatch.go
new file mode 100644
index 0000000..dc0d74d
--- /dev/null
+++ b/internal/cmd/ai/cmd_dispatch.go
@@ -0,0 +1,498 @@
+package ai
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"os"
+	"os/exec"
+	"os/signal"
+	"path/filepath"
+	"sort"
+	"strconv"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/host-uk/core/pkg/cli"
+	"github.com/host-uk/core/pkg/log"
+)
+
+// AddDispatchCommands registers the 'dispatch' subcommand group under 'ai'.
+// These commands run ON the agent machine to process the work queue.
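+//
+// Work-dir layout (created on demand by ensureDispatchDirs; see getPaths):
+//
+//	queue/   incoming ticket-*.json files
+//	active/  the ticket currently being processed
+//	done/    finished tickets
+//	logs/    per-ticket agent output
+//	jobs/    per-ticket git checkouts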
+func AddDispatchCommands(parent *cli.Command) {
+	dispatchCmd := &cli.Command{
+		Use:   "dispatch",
+		Short: "Agent work queue processor (runs on agent machine)",
+	}
+
+	dispatchCmd.AddCommand(dispatchRunCmd())
+	dispatchCmd.AddCommand(dispatchWatchCmd())
+	dispatchCmd.AddCommand(dispatchStatusCmd())
+
+	parent.AddCommand(dispatchCmd)
+}
+
+// dispatchTicket represents the work item JSON structure.
+type dispatchTicket struct {
+	ID           string `json:"id"`
+	RepoOwner    string `json:"repo_owner"`
+	RepoName     string `json:"repo_name"`
+	IssueNumber  int    `json:"issue_number"`
+	IssueTitle   string `json:"issue_title"`
+	IssueBody    string `json:"issue_body"`
+	TargetBranch string `json:"target_branch"`
+	EpicNumber   int    `json:"epic_number"`
+	ForgeURL     string `json:"forge_url"`
+	ForgeToken   string `json:"forge_token"`
+	ForgeUser    string `json:"forgejo_user"`
+	Model        string `json:"model"`
+	Runner       string `json:"runner"`
+	Timeout      string `json:"timeout"`
+	CreatedAt    string `json:"created_at"`
+}
+
+const (
+	defaultWorkDir = "ai-work"
+	lockFileName   = ".runner.lock"
+)
+
+type runnerPaths struct {
+	root   string
+	queue  string
+	active string
+	done   string
+	logs   string
+	jobs   string
+	lock   string
+}
+
+func getPaths(baseDir string) runnerPaths {
+	if baseDir == "" {
+		home, _ := os.UserHomeDir()
+		baseDir = filepath.Join(home, defaultWorkDir)
+	}
+	return runnerPaths{
+		root:   baseDir,
+		queue:  filepath.Join(baseDir, "queue"),
+		active: filepath.Join(baseDir, "active"),
+		done:   filepath.Join(baseDir, "done"),
+		logs:   filepath.Join(baseDir, "logs"),
+		jobs:   filepath.Join(baseDir, "jobs"),
+		lock:   filepath.Join(baseDir, lockFileName),
+	}
+}
+
+func dispatchRunCmd() *cli.Command {
+	cmd := &cli.Command{
+		Use:   "run",
+		Short: "Process a single ticket from the queue",
+		RunE: func(cmd *cli.Command, args []string) error {
+			workDir, _ := cmd.Flags().GetString("work-dir")
+			paths := getPaths(workDir)
+
+			if err := ensureDispatchDirs(paths); err != nil {
+				return err
+			}
+
+			if err := acquireLock(paths.lock); err != nil {
+				log.Info("Runner locked, skipping run", "lock", paths.lock)
+				return nil
+			}
+			defer releaseLock(paths.lock)
+
+			ticketFile, err := pickOldestTicket(paths.queue)
+			if err != nil {
+				return err
+			}
+			if ticketFile == "" {
+				return nil
+			}
+
+			return processTicket(paths, ticketFile)
+		},
+	}
+	cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
+	return cmd
+}
+
+func dispatchWatchCmd() *cli.Command {
+	cmd := &cli.Command{
+		Use:   "watch",
+		Short: "Run as a daemon, polling the queue",
+		RunE: func(cmd *cli.Command, args []string) error {
+			workDir, _ := cmd.Flags().GetString("work-dir")
+			interval, _ := cmd.Flags().GetDuration("interval")
+			paths := getPaths(workDir)
+
+			if err := ensureDispatchDirs(paths); err != nil {
+				return err
+			}
+
+			log.Info("Starting dispatch watcher", "dir", paths.root, "interval", interval)
+
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+			sigChan := make(chan os.Signal, 1)
+			signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
+
+			ticker := time.NewTicker(interval)
+			defer ticker.Stop()
+
+			runCycle(paths)
+
+			for {
+				select {
+				case <-ticker.C:
+					runCycle(paths)
+				case <-sigChan:
+					log.Info("Shutting down watcher...")
+					return nil
+				case <-ctx.Done():
+					return nil
+				}
+			}
+		},
+	}
+	cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
+	cmd.Flags().Duration("interval", 5*time.Minute, "Polling interval")
+	return cmd
+}
+
+func dispatchStatusCmd() *cli.Command {
+	cmd := &cli.Command{
Use: "status", + Short: "Show runner status", + RunE: func(cmd *cli.Command, args []string) error { + workDir, _ := cmd.Flags().GetString("work-dir") + paths := getPaths(workDir) + + lockStatus := "IDLE" + if data, err := os.ReadFile(paths.lock); err == nil { + pidStr := strings.TrimSpace(string(data)) + pid, _ := strconv.Atoi(pidStr) + if isProcessAlive(pid) { + lockStatus = fmt.Sprintf("RUNNING (PID %d)", pid) + } else { + lockStatus = fmt.Sprintf("STALE (PID %d)", pid) + } + } + + countFiles := func(dir string) int { + entries, _ := os.ReadDir(dir) + count := 0 + for _, e := range entries { + if !e.IsDir() && strings.HasPrefix(e.Name(), "ticket-") { + count++ + } + } + return count + } + + fmt.Println("=== Agent Dispatch Status ===") + fmt.Printf("Work Dir: %s\n", paths.root) + fmt.Printf("Status: %s\n", lockStatus) + fmt.Printf("Queue: %d\n", countFiles(paths.queue)) + fmt.Printf("Active: %d\n", countFiles(paths.active)) + fmt.Printf("Done: %d\n", countFiles(paths.done)) + + return nil + }, + } + cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)") + return cmd +} + +func runCycle(paths runnerPaths) { + if err := acquireLock(paths.lock); err != nil { + log.Debug("Runner locked, skipping cycle") + return + } + defer releaseLock(paths.lock) + + ticketFile, err := pickOldestTicket(paths.queue) + if err != nil { + log.Error("Failed to pick ticket", "error", err) + return + } + if ticketFile == "" { + return + } + + if err := processTicket(paths, ticketFile); err != nil { + log.Error("Failed to process ticket", "file", ticketFile, "error", err) + } +} + +func processTicket(paths runnerPaths, ticketPath string) error { + fileName := filepath.Base(ticketPath) + log.Info("Processing ticket", "file", fileName) + + activePath := filepath.Join(paths.active, fileName) + if err := os.Rename(ticketPath, activePath); err != nil { + return fmt.Errorf("failed to move ticket to active: %w", err) + } + + data, err := os.ReadFile(activePath) + if err != nil { + return fmt.Errorf("failed to read ticket: %w", err) + } + var t dispatchTicket + if err := json.Unmarshal(data, &t); err != nil { + return fmt.Errorf("failed to unmarshal ticket: %w", err) + } + + jobDir := filepath.Join(paths.jobs, fmt.Sprintf("%s-%s-%d", t.RepoOwner, t.RepoName, t.IssueNumber)) + repoDir := filepath.Join(jobDir, t.RepoName) + if err := os.MkdirAll(jobDir, 0755); err != nil { + return err + } + + if err := prepareRepo(t, repoDir); err != nil { + reportToForge(t, false, fmt.Sprintf("Git setup failed: %v", err)) + moveToDone(paths, activePath, fileName) + return err + } + + prompt := buildPrompt(t) + + logFile := filepath.Join(paths.logs, fmt.Sprintf("%s-%s-%d.log", t.RepoOwner, t.RepoName, t.IssueNumber)) + success, exitCode, runErr := runAgent(t, prompt, repoDir, logFile) + + msg := fmt.Sprintf("Agent completed work on #%d. Exit code: %d.", t.IssueNumber, exitCode) + if !success { + msg = fmt.Sprintf("Agent failed on #%d (exit code: %d). 
+		msg = fmt.Sprintf("Agent failed on #%d (exit code: %d). Check logs on agent machine.", t.IssueNumber, exitCode)
+		if runErr != nil {
+			msg += fmt.Sprintf(" Error: %v", runErr)
+		}
+	}
+	reportToForge(t, success, msg)
+
+	moveToDone(paths, activePath, fileName)
+	log.Info("Ticket complete", "id", t.ID, "success", success)
+	return nil
+}
+
+func prepareRepo(t dispatchTicket, repoDir string) error {
+	user := t.ForgeUser
+	if user == "" {
+		host, _ := os.Hostname()
+		user = fmt.Sprintf("%s-%s", host, os.Getenv("USER"))
+	}
+
+	cleanURL := strings.TrimPrefix(t.ForgeURL, "https://")
+	cleanURL = strings.TrimPrefix(cleanURL, "http://")
+	cloneURL := fmt.Sprintf("https://%s:%s@%s/%s/%s.git", user, t.ForgeToken, cleanURL, t.RepoOwner, t.RepoName)
+
+	if _, err := os.Stat(filepath.Join(repoDir, ".git")); err == nil {
+		log.Info("Updating existing repo", "dir", repoDir)
+		cmds := [][]string{
+			{"git", "fetch", "origin"},
+			{"git", "checkout", t.TargetBranch},
+			{"git", "pull", "origin", t.TargetBranch},
+		}
+		for _, args := range cmds {
+			cmd := exec.Command(args[0], args[1:]...)
+			cmd.Dir = repoDir
+			if out, err := cmd.CombinedOutput(); err != nil {
+				if args[1] == "checkout" {
+					createCmd := exec.Command("git", "checkout", "-b", t.TargetBranch, "origin/"+t.TargetBranch)
+					createCmd.Dir = repoDir
+					if _, err2 := createCmd.CombinedOutput(); err2 == nil {
+						continue
+					}
+				}
+				return fmt.Errorf("git command %v failed: %s", args, string(out))
+			}
+		}
+	} else {
+		log.Info("Cloning repo", "url", t.RepoOwner+"/"+t.RepoName)
+		cmd := exec.Command("git", "clone", "-b", t.TargetBranch, cloneURL, repoDir)
+		if out, err := cmd.CombinedOutput(); err != nil {
+			return fmt.Errorf("git clone failed: %s", string(out))
+		}
+	}
+	return nil
+}
+
+func buildPrompt(t dispatchTicket) string {
+	return fmt.Sprintf(`You are working on issue #%d in %s/%s.
+
+Title: %s
+
+Description:
+%s
+
+The repo is cloned at the current directory on branch '%s'.
+Create a feature branch from '%s', make minimal targeted changes, commit referencing #%d, and push.
+Then create a PR targeting '%s' using the forgejo MCP tools or git push.`,
+		t.IssueNumber, t.RepoOwner, t.RepoName,
+		t.IssueTitle,
+		t.IssueBody,
+		t.TargetBranch,
+		t.TargetBranch, t.IssueNumber,
+		t.TargetBranch,
+	)
+}
+
+func runAgent(t dispatchTicket, prompt, dir, logPath string) (bool, int, error) {
+	timeout := 30 * time.Minute
+	if t.Timeout != "" {
+		if d, err := time.ParseDuration(t.Timeout); err == nil {
+			timeout = d
+		}
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	model := t.Model
+	if model == "" {
+		model = "sonnet"
+	}
+
+	log.Info("Running agent", "runner", t.Runner, "model", model)
+
+	// For Gemini runner, wrap with rate limiting.
+	if t.Runner == "gemini" {
+		return executeWithRateLimit(ctx, model, prompt, func() (bool, int, error) {
+			return execAgent(ctx, t.Runner, model, prompt, dir, logPath)
+		})
+	}
+
+	return execAgent(ctx, t.Runner, model, prompt, dir, logPath)
+}
+
+func execAgent(ctx context.Context, runner, model, prompt, dir, logPath string) (bool, int, error) {
+	var cmd *exec.Cmd
+
+	switch runner {
+	case "codex":
+		cmd = exec.CommandContext(ctx, "codex", "exec", "--full-auto", prompt)
+	case "gemini":
+		args := []string{"-p", "-", "-y", "-m", model}
+		cmd = exec.CommandContext(ctx, "gemini", args...)
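+		// Assumption: the gemini CLI treats "-p -" as "read the prompt from
+		// stdin" and -y as auto-approve; adjust if the CLI behaves differently.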
+		cmd.Stdin = strings.NewReader(prompt)
+	default: // claude
+		cmd = exec.CommandContext(ctx, "claude", "-p", "--model", model, "--dangerously-skip-permissions", "--output-format", "text")
+		cmd.Stdin = strings.NewReader(prompt)
+	}
+
+	cmd.Dir = dir
+
+	f, err := os.Create(logPath)
+	if err != nil {
+		return false, -1, err
+	}
+	defer f.Close()
+
+	cmd.Stdout = f
+	cmd.Stderr = f
+
+	if err := cmd.Run(); err != nil {
+		exitCode := -1
+		if exitErr, ok := err.(*exec.ExitError); ok {
+			exitCode = exitErr.ExitCode()
+		}
+		return false, exitCode, err
+	}
+
+	return true, 0, nil
+}
+
+func reportToForge(t dispatchTicket, success bool, body string) {
+	url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d/comments",
+		strings.TrimSuffix(t.ForgeURL, "/"), t.RepoOwner, t.RepoName, t.IssueNumber)
+
+	payload := map[string]string{"body": body}
+	jsonBody, _ := json.Marshal(payload)
+
+	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
+	if err != nil {
+		log.Error("Failed to create request", "err", err)
+		return
+	}
+	req.Header.Set("Authorization", "token "+t.ForgeToken)
+	req.Header.Set("Content-Type", "application/json")
+
+	client := &http.Client{Timeout: 10 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Error("Failed to report to Forge", "err", err)
+		return
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode >= 300 {
+		log.Warn("Forge reported error", "status", resp.Status)
+	}
+}
+
+func moveToDone(paths runnerPaths, activePath, fileName string) {
+	donePath := filepath.Join(paths.done, fileName)
+	if err := os.Rename(activePath, donePath); err != nil {
+		log.Error("Failed to move ticket to done", "err", err)
+	}
+}
+
+func ensureDispatchDirs(p runnerPaths) error {
+	dirs := []string{p.queue, p.active, p.done, p.logs, p.jobs}
+	for _, d := range dirs {
+		if err := os.MkdirAll(d, 0755); err != nil {
+			return fmt.Errorf("mkdir %s failed: %w", d, err)
+		}
+	}
+	return nil
+}
+
+func acquireLock(lockPath string) error {
+	if data, err := os.ReadFile(lockPath); err == nil {
+		pidStr := strings.TrimSpace(string(data))
+		pid, _ := strconv.Atoi(pidStr)
+		if isProcessAlive(pid) {
+			return fmt.Errorf("locked by PID %d", pid)
+		}
+		log.Info("Removing stale lock", "pid", pid)
+		_ = os.Remove(lockPath)
+	}
+
+	return os.WriteFile(lockPath, []byte(fmt.Sprintf("%d", os.Getpid())), 0644)
+}
+
+func releaseLock(lockPath string) {
+	_ = os.Remove(lockPath)
+}
+
+func isProcessAlive(pid int) bool {
+	if pid <= 0 {
+		return false
+	}
+	process, err := os.FindProcess(pid)
+	if err != nil {
+		return false
+	}
+	return process.Signal(syscall.Signal(0)) == nil
+}
+
+func pickOldestTicket(queueDir string) (string, error) {
+	entries, err := os.ReadDir(queueDir)
+	if err != nil {
+		return "", err
+	}
+
+	var tickets []string
+	for _, e := range entries {
+		if !e.IsDir() && strings.HasPrefix(e.Name(), "ticket-") && strings.HasSuffix(e.Name(), ".json") {
+			tickets = append(tickets, filepath.Join(queueDir, e.Name()))
+		}
+	}
+
+	if len(tickets) == 0 {
+		return "", nil
+	}
+
+	sort.Strings(tickets)
+	return tickets[0], nil
+}
diff --git a/internal/cmd/ai/cmd_ratelimits.go b/internal/cmd/ai/cmd_ratelimits.go
new file mode 100644
index 0000000..fa05a65
--- /dev/null
+++ b/internal/cmd/ai/cmd_ratelimits.go
@@ -0,0 +1,213 @@
+package ai
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+	"text/tabwriter"
+	"time"
+
+	"github.com/host-uk/core/pkg/cli"
+	"github.com/host-uk/core/pkg/config"
+	"github.com/host-uk/core/pkg/ratelimit"
+)
+
+// AddRateLimitCommands registers the 'ratelimits' subcommand group under 'ai'.
+func AddRateLimitCommands(parent *cli.Command) {
+	rlCmd := &cli.Command{
+		Use:   "ratelimits",
+		Short: "Manage Gemini API rate limits",
+	}
+
+	rlCmd.AddCommand(rlShowCmd())
+	rlCmd.AddCommand(rlResetCmd())
+	rlCmd.AddCommand(rlCountCmd())
+	rlCmd.AddCommand(rlConfigCmd())
+	rlCmd.AddCommand(rlCheckCmd())
+
+	parent.AddCommand(rlCmd)
+}
+
+func rlShowCmd() *cli.Command {
+	return &cli.Command{
+		Use:   "show",
+		Short: "Show current rate limit usage",
+		RunE: func(cmd *cli.Command, args []string) error {
+			rl, err := ratelimit.New()
+			if err != nil {
+				return err
+			}
+			if err := rl.Load(); err != nil {
+				return err
+			}
+
+			stats := rl.AllStats()
+
+			w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
+			fmt.Fprintln(w, "MODEL\tRPM\tTPM\tRPD\tSTATUS")
+
+			for model, s := range stats {
+				rpmStr := fmt.Sprintf("%d/%s", s.RPM, formatLimit(s.MaxRPM))
+				tpmStr := fmt.Sprintf("%d/%s", s.TPM, formatLimit(s.MaxTPM))
+				rpdStr := fmt.Sprintf("%d/%s", s.RPD, formatLimit(s.MaxRPD))
+
+				status := "OK"
+				if (s.MaxRPM > 0 && s.RPM >= s.MaxRPM) ||
+					(s.MaxTPM > 0 && s.TPM >= s.MaxTPM) ||
+					(s.MaxRPD > 0 && s.RPD >= s.MaxRPD) {
+					status = "LIMITED"
+				}
+
+				fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\n", model, rpmStr, tpmStr, rpdStr, status)
+			}
+			w.Flush()
+			return nil
+		},
+	}
+}
+
+func rlResetCmd() *cli.Command {
+	return &cli.Command{
+		Use:   "reset [model]",
+		Short: "Reset usage counters for a model (or all)",
+		RunE: func(cmd *cli.Command, args []string) error {
+			rl, err := ratelimit.New()
+			if err != nil {
+				return err
+			}
+			if err := rl.Load(); err != nil {
+				return err
+			}
+
+			model := ""
+			if len(args) > 0 {
+				model = args[0]
+			}
+
+			rl.Reset(model)
+			if err := rl.Persist(); err != nil {
+				return err
+			}
+
+			if model == "" {
+				fmt.Println("Reset stats for all models.")
+			} else {
+				fmt.Printf("Reset stats for model %q.\n", model)
+			}
+			return nil
+		},
+	}
+}
+
+func rlCountCmd() *cli.Command {
+	return &cli.Command{
+		Use:   "count <model> <text>",
+		Short: "Count tokens for text using Gemini API",
+		Args:  cli.ExactArgs(2),
+		RunE: func(cmd *cli.Command, args []string) error {
+			model := args[0]
+			text := args[1]
+
+			cfg, err := config.New()
+			if err != nil {
+				return err
+			}
+
+			var apiKey string
+			if err := cfg.Get("agentci.gemini_api_key", &apiKey); err != nil || apiKey == "" {
+				apiKey = os.Getenv("GEMINI_API_KEY")
+			}
+			if apiKey == "" {
+				return fmt.Errorf("GEMINI_API_KEY not found in config or env")
+			}
+
+			count, err := ratelimit.CountTokens(apiKey, model, text)
+			if err != nil {
+				return err
+			}
+
+			fmt.Printf("Model: %s\nTokens: %d\n", model, count)
+			return nil
+		},
+	}
+}
+
+func rlConfigCmd() *cli.Command {
+	return &cli.Command{
+		Use:   "config",
+		Short: "Show configured quotas",
+		RunE: func(cmd *cli.Command, args []string) error {
+			rl, err := ratelimit.New()
+			if err != nil {
+				return err
+			}
+
+			w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
+			fmt.Fprintln(w, "MODEL\tMAX RPM\tMAX TPM\tMAX RPD")
+
+			for model, q := range rl.Quotas {
+				fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
+					model,
+					formatLimit(q.MaxRPM),
+					formatLimit(q.MaxTPM),
+					formatLimit(q.MaxRPD))
+			}
+			w.Flush()
+			return nil
+		},
+	}
+}
+
+func rlCheckCmd() *cli.Command {
+	return &cli.Command{
+		Use:   "check <model> <tokens>",
+		Short: "Check rate limit capacity for a model",
+		Args:  cli.ExactArgs(2),
+		RunE: func(cmd *cli.Command, args []string) error {
+			model := args[0]
+			tokens, err := strconv.Atoi(args[1])
+			if err != nil {
+				return fmt.Errorf("invalid token count: %w", err)
+			}
+
+			rl, err := ratelimit.New()
+			if err != nil {
+				return err
+			}
+			if err := rl.Load(); err != nil {
+				fmt.Printf("Warning: could not load existing state: %v\n", err)
+			}
+
+			stats := rl.Stats(model)
+			canSend := rl.CanSend(model, tokens)
+
+			status := "RATE LIMITED"
+			if canSend {
+				status = "OK"
+			}
+
+			fmt.Printf("Model: %s\n", model)
+			fmt.Printf("Request Cost: %d tokens\n", tokens)
+			fmt.Printf("Status: %s\n", status)
+			fmt.Printf("\nCurrent Usage (1m window):\n")
+			fmt.Printf(" RPM: %d / %s\n", stats.RPM, formatLimit(stats.MaxRPM))
+			fmt.Printf(" TPM: %d / %s\n", stats.TPM, formatLimit(stats.MaxTPM))
+			fmt.Printf(" RPD: %d / %s (reset: %s)\n", stats.RPD, formatLimit(stats.MaxRPD), stats.DayStart.Format(time.RFC3339))
+
+			return nil
+		},
+	}
+}
+
+func formatLimit(limit int) string {
+	if limit == 0 {
+		return "∞"
+	}
+	if limit >= 1000000 {
+		return fmt.Sprintf("%dM", limit/1000000)
+	}
+	if limit >= 1000 {
+		return fmt.Sprintf("%dK", limit/1000)
+	}
+	return fmt.Sprintf("%d", limit)
+}
diff --git a/internal/cmd/ai/ratelimit_dispatch.go b/internal/cmd/ai/ratelimit_dispatch.go
new file mode 100644
index 0000000..20a20da
--- /dev/null
+++ b/internal/cmd/ai/ratelimit_dispatch.go
@@ -0,0 +1,49 @@
+package ai
+
+import (
+	"context"
+
+	"github.com/host-uk/core/pkg/log"
+	"github.com/host-uk/core/pkg/ratelimit"
+)
+
+// executeWithRateLimit wraps an agent execution with rate limiting logic.
+// It estimates token usage, waits for capacity, executes the runner, and records usage.
+func executeWithRateLimit(ctx context.Context, model, prompt string, runner func() (bool, int, error)) (bool, int, error) {
+	rl, err := ratelimit.New()
+	if err != nil {
+		log.Warn("Failed to initialize rate limiter, proceeding without limits", "error", err)
+		return runner()
+	}
+
+	if err := rl.Load(); err != nil {
+		log.Warn("Failed to load rate limit state", "error", err)
+	}
+
+	// Estimate tokens from prompt length (1 token ≈ 4 chars)
+	estTokens := len(prompt) / 4
+	if estTokens == 0 {
+		estTokens = 1
+	}
+
+	log.Info("Checking rate limits", "model", model, "est_tokens", estTokens)
+
+	if err := rl.WaitForCapacity(ctx, model, estTokens); err != nil {
+		return false, -1, err
+	}
+
+	success, exitCode, runErr := runner()
+
+	// Record usage with conservative output estimate (actual tokens unknown from shell runner).
+	outputEst := estTokens / 10
+	if outputEst < 50 {
+		outputEst = 50
+	}
+	rl.RecordUsage(model, estTokens, outputEst)
+
+	if err := rl.Persist(); err != nil {
+		log.Warn("Failed to persist rate limit state", "error", err)
+	}
+
+	return success, exitCode, runErr
+}
diff --git a/pkg/jobrunner/handlers/completion.go b/pkg/jobrunner/handlers/completion.go
new file mode 100644
index 0000000..8078389
--- /dev/null
+++ b/pkg/jobrunner/handlers/completion.go
@@ -0,0 +1,87 @@
+package handlers
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/host-uk/core/pkg/forge"
+	"github.com/host-uk/core/pkg/jobrunner"
+)
+
+const (
+	ColorAgentComplete = "#0e8a16" // Green
+)
+
+// CompletionHandler manages issue state when an agent finishes work.
+type CompletionHandler struct {
+	forge *forge.Client
+}
+
+// NewCompletionHandler creates a handler for agent completion events.
+func NewCompletionHandler(client *forge.Client) *CompletionHandler {
+	return &CompletionHandler{
+		forge: client,
+	}
+}
+
+// Name returns the handler identifier.
+func (h *CompletionHandler) Name() string {
+	return "completion"
+}
+
+// Match returns true if the signal indicates an agent has finished a task.
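+// Signals of any other type are ignored by this handler.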
+func (h *CompletionHandler) Match(signal *jobrunner.PipelineSignal) bool {
+	return signal.Type == "agent_completion"
+}
+
+// Execute updates the issue labels based on the completion status.
+func (h *CompletionHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
+	start := time.Now()
+
+	// Remove in-progress label.
+	if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil {
+		_ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID)
+	}
+
+	if signal.Success {
+		completeLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentComplete, ColorAgentComplete)
+		if err != nil {
+			return nil, fmt.Errorf("ensure label %s: %w", LabelAgentComplete, err)
+		}
+
+		if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{completeLabel.ID}); err != nil {
+			return nil, fmt.Errorf("add completed label: %w", err)
+		}
+
+		if signal.Message != "" {
+			_ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), signal.Message)
+		}
+	} else {
+		failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed)
+		if err != nil {
+			return nil, fmt.Errorf("ensure label %s: %w", LabelAgentFailed, err)
+		}
+
+		if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID}); err != nil {
+			return nil, fmt.Errorf("add failed label: %w", err)
+		}
+
+		msg := "Agent reported failure."
+		if signal.Error != "" {
+			msg += fmt.Sprintf("\n\nError: %s", signal.Error)
+		}
+		_ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), msg)
+	}
+
+	return &jobrunner.ActionResult{
+		Action:      "completion",
+		RepoOwner:   signal.RepoOwner,
+		RepoName:    signal.RepoName,
+		EpicNumber:  signal.EpicNumber,
+		ChildNumber: signal.ChildNumber,
+		Success:     true,
+		Timestamp:   time.Now(),
+		Duration:    time.Since(start),
+	}, nil
+}
diff --git a/pkg/ratelimit/ratelimit.go b/pkg/ratelimit/ratelimit.go
new file mode 100644
index 0000000..c02adab
--- /dev/null
+++ b/pkg/ratelimit/ratelimit.go
@@ -0,0 +1,382 @@
+package ratelimit
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"sync"
+	"time"
+
+	"gopkg.in/yaml.v3"
+)
+
+// ModelQuota defines the rate limits for a specific model.
+type ModelQuota struct {
+	MaxRPM int `yaml:"max_rpm"` // Requests per minute
+	MaxTPM int `yaml:"max_tpm"` // Tokens per minute
+	MaxRPD int `yaml:"max_rpd"` // Requests per day (0 = unlimited)
+}
+
+// TokenEntry records a token usage event.
+type TokenEntry struct {
+	Time  time.Time `yaml:"time"`
+	Count int       `yaml:"count"`
+}
+
+// UsageStats tracks usage history for a model.
+type UsageStats struct {
+	Requests []time.Time  `yaml:"requests"` // Sliding window (1m)
+	Tokens   []TokenEntry `yaml:"tokens"`   // Sliding window (1m)
+	DayStart time.Time    `yaml:"day_start"`
+	DayCount int          `yaml:"day_count"`
+}
+
+// RateLimiter manages rate limits across multiple models.
+type RateLimiter struct {
+	mu       sync.RWMutex
+	Quotas   map[string]ModelQuota  `yaml:"quotas"`
+	State    map[string]*UsageStats `yaml:"state"`
+	filePath string
+}
+
+// New creates a new RateLimiter with default quotas.
+func New() (*RateLimiter, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return nil, err
+	}
+
+	rl := &RateLimiter{
+		Quotas:   make(map[string]ModelQuota),
+		State:    make(map[string]*UsageStats),
+		filePath: filepath.Join(home, ".core", "ratelimits.yaml"),
+	}
+
+	// Default quotas based on Tier 1 observations (Feb 2026)
+	rl.Quotas["gemini-3-pro-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
+	rl.Quotas["gemini-3-flash-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
+	rl.Quotas["gemini-2.5-pro"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
+	rl.Quotas["gemini-2.0-flash"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0} // Unlimited RPD
+	rl.Quotas["gemini-2.0-flash-lite"] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}    // Unlimited
+
+	return rl, nil
+}
+
+// Load reads the state from disk.
+func (rl *RateLimiter) Load() error {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	data, err := os.ReadFile(rl.filePath)
+	if os.IsNotExist(err) {
+		return nil
+	}
+	if err != nil {
+		return err
+	}
+
+	return yaml.Unmarshal(data, rl)
+}
+
+// Persist writes the state to disk.
+func (rl *RateLimiter) Persist() error {
+	rl.mu.RLock()
+	defer rl.mu.RUnlock()
+
+	data, err := yaml.Marshal(rl)
+	if err != nil {
+		return err
+	}
+
+	dir := filepath.Dir(rl.filePath)
+	if err := os.MkdirAll(dir, 0755); err != nil {
+		return err
+	}
+
+	return os.WriteFile(rl.filePath, data, 0644)
+}
+
+// prune removes entries older than the sliding window (1 minute).
+// Caller must hold lock.
+func (rl *RateLimiter) prune(model string) {
+	stats, ok := rl.State[model]
+	if !ok {
+		return
+	}
+
+	now := time.Now()
+	window := now.Add(-1 * time.Minute)
+
+	// Prune requests
+	validReqs := 0
+	for _, t := range stats.Requests {
+		if t.After(window) {
+			stats.Requests[validReqs] = t
+			validReqs++
+		}
+	}
+	stats.Requests = stats.Requests[:validReqs]
+
+	// Prune tokens
+	validTokens := 0
+	for _, t := range stats.Tokens {
+		if t.Time.After(window) {
+			stats.Tokens[validTokens] = t
+			validTokens++
+		}
+	}
+	stats.Tokens = stats.Tokens[:validTokens]
+
+	// Reset daily counter if day has passed
+	if now.Sub(stats.DayStart) >= 24*time.Hour {
+		stats.DayStart = now
+		stats.DayCount = 0
+	}
+}
+
+// CanSend checks if a request can be sent without violating limits.
+func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	quota, ok := rl.Quotas[model]
+	if !ok {
+		return true // Unknown models are allowed
+	}
+
+	// Unlimited check
+	if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 {
+		return true
+	}
+
+	// Ensure state exists
+	if _, ok := rl.State[model]; !ok {
+		rl.State[model] = &UsageStats{
+			DayStart: time.Now(),
+		}
+	}
+
+	rl.prune(model)
+	stats := rl.State[model]
+
+	// Check RPD
+	if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD {
+		return false
+	}
+
+	// Check RPM
+	if quota.MaxRPM > 0 && len(stats.Requests) >= quota.MaxRPM {
+		return false
+	}
+
+	// Check TPM
+	if quota.MaxTPM > 0 {
+		currentTokens := 0
+		for _, t := range stats.Tokens {
+			currentTokens += t.Count
+		}
+		if currentTokens+estimatedTokens > quota.MaxTPM {
+			return false
+		}
+	}
+
+	return true
+}
+
+// RecordUsage records a successful API call.
+func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	if _, ok := rl.State[model]; !ok {
+		rl.State[model] = &UsageStats{
+			DayStart: time.Now(),
+		}
+	}
+
+	stats := rl.State[model]
+	now := time.Now()
+
+	stats.Requests = append(stats.Requests, now)
+	stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens})
+	stats.DayCount++
+}
+
+// WaitForCapacity blocks until capacity is available or context is cancelled.
+func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
+	ticker := time.NewTicker(1 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		if rl.CanSend(model, tokens) {
+			return nil
+		}
+
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+			// check again
+		}
+	}
+}
+
+// Reset clears stats for a model (or all if model is empty).
+func (rl *RateLimiter) Reset(model string) {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	if model == "" {
+		rl.State = make(map[string]*UsageStats)
+	} else {
+		delete(rl.State, model)
+	}
+}
+
+// ModelStats represents a snapshot of usage.
+type ModelStats struct {
+	RPM      int
+	MaxRPM   int
+	TPM      int
+	MaxTPM   int
+	RPD      int
+	MaxRPD   int
+	DayStart time.Time
+}
+
+// Stats returns current stats for a model.
+func (rl *RateLimiter) Stats(model string) ModelStats {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	rl.prune(model)
+
+	stats := ModelStats{}
+	quota, ok := rl.Quotas[model]
+	if ok {
+		stats.MaxRPM = quota.MaxRPM
+		stats.MaxTPM = quota.MaxTPM
+		stats.MaxRPD = quota.MaxRPD
+	}
+
+	if s, ok := rl.State[model]; ok {
+		stats.RPM = len(s.Requests)
+		stats.RPD = s.DayCount
+		stats.DayStart = s.DayStart
+		for _, t := range s.Tokens {
+			stats.TPM += t.Count
+		}
+	}
+
+	return stats
+}
+
+// AllStats returns stats for all tracked models.
+func (rl *RateLimiter) AllStats() map[string]ModelStats {
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	result := make(map[string]ModelStats)
+
+	// Collect all model names
+	for m := range rl.Quotas {
+		result[m] = ModelStats{}
+	}
+	for m := range rl.State {
+		result[m] = ModelStats{}
+	}
+
+	for m := range result {
+		// Drop expired window entries before snapshotting (lock is held).
+		rl.prune(m)
+
+		ms := ModelStats{}
+		if q, ok := rl.Quotas[m]; ok {
+			ms.MaxRPM = q.MaxRPM
+			ms.MaxTPM = q.MaxTPM
+			ms.MaxRPD = q.MaxRPD
+		}
+		if s, ok := rl.State[m]; ok {
+			ms.RPM = len(s.Requests)
+			ms.RPD = s.DayCount
+			ms.DayStart = s.DayStart
+			for _, t := range s.Tokens {
+				ms.TPM += t.Count
+			}
+		}
+		result[m] = ms
+	}
+
+	return result
+}
+
+// CountTokens calls the Google API to count tokens for a prompt.
+func CountTokens(apiKey, model, text string) (int, error) {
+	url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens?key=%s", model, apiKey)
+
+	reqBody := map[string]any{
+		"contents": []any{
+			map[string]any{
+				"parts": []any{
+					map[string]string{"text": text},
+				},
+			},
+		},
+	}
+
+	jsonBody, err := json.Marshal(reqBody)
+	if err != nil {
+		return 0, err
+	}
+
+	resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
+	if err != nil {
+		return 0, err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return 0, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body))
+	}
+
+	var result struct {
+		TotalTokens int `json:"totalTokens"`
+	}
+	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+		return 0, err
+	}
+
+	return result.TotalTokens, nil
+}
diff --git a/pkg/ratelimit/ratelimit_test.go b/pkg/ratelimit/ratelimit_test.go
new file mode 100644
index 0000000..1247960
--- /dev/null
+++ b/pkg/ratelimit/ratelimit_test.go
@@ -0,0 +1,176 @@
+package ratelimit
+
+import (
+	"context"
+	"path/filepath"
+	"testing"
+	"time"
+)
+
+func TestCanSend_Good(t *testing.T) {
+	rl, _ := New()
+	rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml")
+
+	model := "test-model"
+	rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 100}
+
+	if !rl.CanSend(model, 100) {
+		t.Errorf("Expected CanSend to return true for fresh state")
+	}
+}
+
+func TestCanSend_RPMExceeded_Bad(t *testing.T) {
+	rl, _ := New()
+	model := "test-rpm"
+	rl.Quotas[model] = ModelQuota{MaxRPM: 2, MaxTPM: 1000000, MaxRPD: 100}
+
+	rl.RecordUsage(model, 10, 10)
+	rl.RecordUsage(model, 10, 10)
+
+	if rl.CanSend(model, 10) {
+		t.Errorf("Expected CanSend to return false after exceeding RPM")
+	}
+}
+
+func TestCanSend_TPMExceeded_Bad(t *testing.T) {
+	rl, _ := New()
+	model := "test-tpm"
+	rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 100}
+
+	rl.RecordUsage(model, 50, 40) // 90 tokens used
+
+	if rl.CanSend(model, 20) { // 90 + 20 = 110 > 100
+		t.Errorf("Expected CanSend to return false when estimated tokens exceed TPM")
+	}
+}
+
+func TestCanSend_RPDExceeded_Bad(t *testing.T) {
+	rl, _ := New()
+	model := "test-rpd"
+	rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 2}
+
+	rl.RecordUsage(model, 10, 10)
+	rl.RecordUsage(model, 10, 10)
+
+	if rl.CanSend(model, 10) {
+		t.Errorf("Expected CanSend to return false after exceeding RPD")
+	}
+}
+
+func TestCanSend_UnlimitedModel_Good(t *testing.T) {
+	rl, _ := New()
+	model := "test-unlimited"
+	rl.Quotas[model] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}
+
+	// Should always be allowed
+	for i := 0; i < 1000; i++ {
+		rl.RecordUsage(model, 100, 100)
+	}
+	if !rl.CanSend(model, 999999) {
+		t.Errorf("Expected unlimited model to always allow sends")
+	}
+}
+
+func TestRecordUsage_PrunesOldEntries_Good(t *testing.T) {
+	rl, _ := New()
+	model := "test-prune"
+	rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 1000000, MaxRPD: 100}
+
+	// Manually inject old data
+	oldTime := time.Now().Add(-2 * time.Minute)
+	rl.State[model] = &UsageStats{
+		Requests: []time.Time{oldTime, oldTime, oldTime},
+		Tokens: []TokenEntry{
+			{Time: oldTime, Count: 100},
+			{Time: oldTime, Count: 100},
+		},
+		DayStart: time.Now(),
+	}
+
+	// CanSend triggers prune
+	if !rl.CanSend(model, 10) {
+		t.Errorf("Expected CanSend to return true after pruning old entries")
+	}
+
+	stats := rl.State[model]
+	if len(stats.Requests) != 0 {
t.Errorf("Expected 0 requests after pruning old entries, got %d", len(stats.Requests)) + } +} + +func TestPersistAndLoad_Good(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ratelimits.yaml") + + rl1, _ := New() + rl1.filePath = path + model := "persist-test" + rl1.Quotas[model] = ModelQuota{MaxRPM: 50, MaxTPM: 5000, MaxRPD: 500} + rl1.RecordUsage(model, 100, 100) + + if err := rl1.Persist(); err != nil { + t.Fatalf("Persist failed: %v", err) + } + + rl2, _ := New() + rl2.filePath = path + if err := rl2.Load(); err != nil { + t.Fatalf("Load failed: %v", err) + } + + stats := rl2.Stats(model) + if stats.RPM != 1 { + t.Errorf("Expected RPM 1 after load, got %d", stats.RPM) + } + if stats.TPM != 200 { + t.Errorf("Expected TPM 200 after load, got %d", stats.TPM) + } +} + +func TestWaitForCapacity_Ugly(t *testing.T) { + rl, _ := New() + model := "wait-test" + rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000000, MaxRPD: 100} + + rl.RecordUsage(model, 10, 10) // Use up the 1 RPM + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := rl.WaitForCapacity(ctx, model, 10) + if err != context.DeadlineExceeded { + t.Errorf("Expected DeadlineExceeded, got %v", err) + } +} + +func TestDefaultQuotas_Good(t *testing.T) { + rl, _ := New() + expected := []string{ + "gemini-3-pro-preview", + "gemini-3-flash-preview", + "gemini-2.0-flash", + } + for _, m := range expected { + if _, ok := rl.Quotas[m]; !ok { + t.Errorf("Expected default quota for %s", m) + } + } +} + +func TestAllStats_Good(t *testing.T) { + rl, _ := New() + rl.RecordUsage("gemini-3-pro-preview", 1000, 500) + + all := rl.AllStats() + if len(all) < 5 { + t.Errorf("Expected at least 5 models in AllStats, got %d", len(all)) + } + + pro := all["gemini-3-pro-preview"] + if pro.RPM != 1 { + t.Errorf("Expected RPM 1 for pro, got %d", pro.RPM) + } + if pro.TPM != 1500 { + t.Errorf("Expected TPM 1500 for pro, got %d", pro.TPM) + } +}