feat(agentci): rate limiting and native Go dispatch runner
Adds pkg/ratelimit for Gemini API rate limiting with sliding window
(RPM/TPM/RPD), persistent state, and token counting. Replaces the
bash agent-runner.sh with a native Go implementation under
`core ai dispatch {run,watch,status}` for local queue processing.
Rate limiting:
- Per-model quotas (RPM, TPM, RPD) with 1-minute sliding window
- WaitForCapacity blocks until capacity available or context cancelled
- Persistent state in ~/.core/ratelimits.yaml
- Default quotas for Gemini 3 Pro/Flash, 2.5 Pro, 2.0 Flash/Lite
- CountTokens helper calls Google tokenizer API
- CLI: core ai ratelimits {show,reset,count,config,check}
Dispatch runner:
- core ai dispatch run — process single ticket from queue
- core ai dispatch watch — daemon mode with configurable interval
- core ai dispatch status — show queue/active/done counts
- Supports claude/codex/gemini runners with rate-limited Gemini
- File-based locking with stale PID detection
- Completion handler updates issue labels on success/failure
Closes #42
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d92762ecdc
commit
3ccb67bddd
7 changed files with 1411 additions and 0 deletions
|
|
@ -69,6 +69,12 @@ func initCommands() {
|
||||||
|
|
||||||
// Add agent management commands (core ai agent ...)
|
// Add agent management commands (core ai agent ...)
|
||||||
AddAgentCommands(aiCmd)
|
AddAgentCommands(aiCmd)
|
||||||
|
|
||||||
|
// Add rate limit management commands (core ai ratelimits ...)
|
||||||
|
AddRateLimitCommands(aiCmd)
|
||||||
|
|
||||||
|
// Add dispatch commands (core ai dispatch run/watch/status)
|
||||||
|
AddDispatchCommands(aiCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddAICommands registers the 'ai' command and all subcommands.
|
// AddAICommands registers the 'ai' command and all subcommands.
|
||||||
|
|
|
||||||
498
internal/cmd/ai/cmd_dispatch.go
Normal file
498
internal/cmd/ai/cmd_dispatch.go
Normal file
|
|
@ -0,0 +1,498 @@
|
||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/host-uk/core/pkg/cli"
|
||||||
|
"github.com/host-uk/core/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddDispatchCommands registers the 'dispatch' subcommand group under 'ai'.
|
||||||
|
// These commands run ON the agent machine to process the work queue.
|
||||||
|
func AddDispatchCommands(parent *cli.Command) {
|
||||||
|
dispatchCmd := &cli.Command{
|
||||||
|
Use: "dispatch",
|
||||||
|
Short: "Agent work queue processor (runs on agent machine)",
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatchCmd.AddCommand(dispatchRunCmd())
|
||||||
|
dispatchCmd.AddCommand(dispatchWatchCmd())
|
||||||
|
dispatchCmd.AddCommand(dispatchStatusCmd())
|
||||||
|
|
||||||
|
parent.AddCommand(dispatchCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchTicket represents the work item JSON structure.
|
||||||
|
type dispatchTicket struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
RepoOwner string `json:"repo_owner"`
|
||||||
|
RepoName string `json:"repo_name"`
|
||||||
|
IssueNumber int `json:"issue_number"`
|
||||||
|
IssueTitle string `json:"issue_title"`
|
||||||
|
IssueBody string `json:"issue_body"`
|
||||||
|
TargetBranch string `json:"target_branch"`
|
||||||
|
EpicNumber int `json:"epic_number"`
|
||||||
|
ForgeURL string `json:"forge_url"`
|
||||||
|
ForgeToken string `json:"forge_token"`
|
||||||
|
ForgeUser string `json:"forgejo_user"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Runner string `json:"runner"`
|
||||||
|
Timeout string `json:"timeout"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultWorkDir = "ai-work"
|
||||||
|
lockFileName = ".runner.lock"
|
||||||
|
)
|
||||||
|
|
||||||
|
type runnerPaths struct {
|
||||||
|
root string
|
||||||
|
queue string
|
||||||
|
active string
|
||||||
|
done string
|
||||||
|
logs string
|
||||||
|
jobs string
|
||||||
|
lock string
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPaths(baseDir string) runnerPaths {
|
||||||
|
if baseDir == "" {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
baseDir = filepath.Join(home, defaultWorkDir)
|
||||||
|
}
|
||||||
|
return runnerPaths{
|
||||||
|
root: baseDir,
|
||||||
|
queue: filepath.Join(baseDir, "queue"),
|
||||||
|
active: filepath.Join(baseDir, "active"),
|
||||||
|
done: filepath.Join(baseDir, "done"),
|
||||||
|
logs: filepath.Join(baseDir, "logs"),
|
||||||
|
jobs: filepath.Join(baseDir, "jobs"),
|
||||||
|
lock: filepath.Join(baseDir, lockFileName),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dispatchRunCmd() *cli.Command {
|
||||||
|
cmd := &cli.Command{
|
||||||
|
Use: "run",
|
||||||
|
Short: "Process a single ticket from the queue",
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
workDir, _ := cmd.Flags().GetString("work-dir")
|
||||||
|
paths := getPaths(workDir)
|
||||||
|
|
||||||
|
if err := ensureDispatchDirs(paths); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := acquireLock(paths.lock); err != nil {
|
||||||
|
log.Info("Runner locked, skipping run", "lock", paths.lock)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer releaseLock(paths.lock)
|
||||||
|
|
||||||
|
ticketFile, err := pickOldestTicket(paths.queue)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ticketFile == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return processTicket(paths, ticketFile)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func dispatchWatchCmd() *cli.Command {
|
||||||
|
cmd := &cli.Command{
|
||||||
|
Use: "watch",
|
||||||
|
Short: "Run as a daemon, polling the queue",
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
workDir, _ := cmd.Flags().GetString("work-dir")
|
||||||
|
interval, _ := cmd.Flags().GetDuration("interval")
|
||||||
|
paths := getPaths(workDir)
|
||||||
|
|
||||||
|
if err := ensureDispatchDirs(paths); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Starting dispatch watcher", "dir", paths.root, "interval", interval)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
|
||||||
|
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
runCycle(paths)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
runCycle(paths)
|
||||||
|
case <-sigChan:
|
||||||
|
log.Info("Shutting down watcher...")
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
|
||||||
|
cmd.Flags().Duration("interval", 5*time.Minute, "Polling interval")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func dispatchStatusCmd() *cli.Command {
|
||||||
|
cmd := &cli.Command{
|
||||||
|
Use: "status",
|
||||||
|
Short: "Show runner status",
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
workDir, _ := cmd.Flags().GetString("work-dir")
|
||||||
|
paths := getPaths(workDir)
|
||||||
|
|
||||||
|
lockStatus := "IDLE"
|
||||||
|
if data, err := os.ReadFile(paths.lock); err == nil {
|
||||||
|
pidStr := strings.TrimSpace(string(data))
|
||||||
|
pid, _ := strconv.Atoi(pidStr)
|
||||||
|
if isProcessAlive(pid) {
|
||||||
|
lockStatus = fmt.Sprintf("RUNNING (PID %d)", pid)
|
||||||
|
} else {
|
||||||
|
lockStatus = fmt.Sprintf("STALE (PID %d)", pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
countFiles := func(dir string) int {
|
||||||
|
entries, _ := os.ReadDir(dir)
|
||||||
|
count := 0
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() && strings.HasPrefix(e.Name(), "ticket-") {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("=== Agent Dispatch Status ===")
|
||||||
|
fmt.Printf("Work Dir: %s\n", paths.root)
|
||||||
|
fmt.Printf("Status: %s\n", lockStatus)
|
||||||
|
fmt.Printf("Queue: %d\n", countFiles(paths.queue))
|
||||||
|
fmt.Printf("Active: %d\n", countFiles(paths.active))
|
||||||
|
fmt.Printf("Done: %d\n", countFiles(paths.done))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cmd.Flags().String("work-dir", "", "Working directory (default: ~/ai-work)")
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func runCycle(paths runnerPaths) {
|
||||||
|
if err := acquireLock(paths.lock); err != nil {
|
||||||
|
log.Debug("Runner locked, skipping cycle")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer releaseLock(paths.lock)
|
||||||
|
|
||||||
|
ticketFile, err := pickOldestTicket(paths.queue)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to pick ticket", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ticketFile == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := processTicket(paths, ticketFile); err != nil {
|
||||||
|
log.Error("Failed to process ticket", "file", ticketFile, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func processTicket(paths runnerPaths, ticketPath string) error {
|
||||||
|
fileName := filepath.Base(ticketPath)
|
||||||
|
log.Info("Processing ticket", "file", fileName)
|
||||||
|
|
||||||
|
activePath := filepath.Join(paths.active, fileName)
|
||||||
|
if err := os.Rename(ticketPath, activePath); err != nil {
|
||||||
|
return fmt.Errorf("failed to move ticket to active: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(activePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read ticket: %w", err)
|
||||||
|
}
|
||||||
|
var t dispatchTicket
|
||||||
|
if err := json.Unmarshal(data, &t); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal ticket: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
jobDir := filepath.Join(paths.jobs, fmt.Sprintf("%s-%s-%d", t.RepoOwner, t.RepoName, t.IssueNumber))
|
||||||
|
repoDir := filepath.Join(jobDir, t.RepoName)
|
||||||
|
if err := os.MkdirAll(jobDir, 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := prepareRepo(t, repoDir); err != nil {
|
||||||
|
reportToForge(t, false, fmt.Sprintf("Git setup failed: %v", err))
|
||||||
|
moveToDone(paths, activePath, fileName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := buildPrompt(t)
|
||||||
|
|
||||||
|
logFile := filepath.Join(paths.logs, fmt.Sprintf("%s-%s-%d.log", t.RepoOwner, t.RepoName, t.IssueNumber))
|
||||||
|
success, exitCode, runErr := runAgent(t, prompt, repoDir, logFile)
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Agent completed work on #%d. Exit code: %d.", t.IssueNumber, exitCode)
|
||||||
|
if !success {
|
||||||
|
msg = fmt.Sprintf("Agent failed on #%d (exit code: %d). Check logs on agent machine.", t.IssueNumber, exitCode)
|
||||||
|
if runErr != nil {
|
||||||
|
msg += fmt.Sprintf(" Error: %v", runErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reportToForge(t, success, msg)
|
||||||
|
|
||||||
|
moveToDone(paths, activePath, fileName)
|
||||||
|
log.Info("Ticket complete", "id", t.ID, "success", success)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareRepo(t dispatchTicket, repoDir string) error {
|
||||||
|
user := t.ForgeUser
|
||||||
|
if user == "" {
|
||||||
|
host, _ := os.Hostname()
|
||||||
|
user = fmt.Sprintf("%s-%s", host, os.Getenv("USER"))
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanURL := strings.TrimPrefix(t.ForgeURL, "https://")
|
||||||
|
cleanURL = strings.TrimPrefix(cleanURL, "http://")
|
||||||
|
cloneURL := fmt.Sprintf("https://%s:%s@%s/%s/%s.git", user, t.ForgeToken, cleanURL, t.RepoOwner, t.RepoName)
|
||||||
|
|
||||||
|
if _, err := os.Stat(filepath.Join(repoDir, ".git")); err == nil {
|
||||||
|
log.Info("Updating existing repo", "dir", repoDir)
|
||||||
|
cmds := [][]string{
|
||||||
|
{"git", "fetch", "origin"},
|
||||||
|
{"git", "checkout", t.TargetBranch},
|
||||||
|
{"git", "pull", "origin", t.TargetBranch},
|
||||||
|
}
|
||||||
|
for _, args := range cmds {
|
||||||
|
cmd := exec.Command(args[0], args[1:]...)
|
||||||
|
cmd.Dir = repoDir
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
if args[1] == "checkout" {
|
||||||
|
createCmd := exec.Command("git", "checkout", "-b", t.TargetBranch, "origin/"+t.TargetBranch)
|
||||||
|
createCmd.Dir = repoDir
|
||||||
|
if _, err2 := createCmd.CombinedOutput(); err2 == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("git command %v failed: %s", args, string(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Info("Cloning repo", "url", t.RepoOwner+"/"+t.RepoName)
|
||||||
|
cmd := exec.Command("git", "clone", "-b", t.TargetBranch, cloneURL, repoDir)
|
||||||
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
return fmt.Errorf("git clone failed: %s", string(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPrompt(t dispatchTicket) string {
|
||||||
|
return fmt.Sprintf(`You are working on issue #%d in %s/%s.
|
||||||
|
|
||||||
|
Title: %s
|
||||||
|
|
||||||
|
Description:
|
||||||
|
%s
|
||||||
|
|
||||||
|
The repo is cloned at the current directory on branch '%s'.
|
||||||
|
Create a feature branch from '%s', make minimal targeted changes, commit referencing #%d, and push.
|
||||||
|
Then create a PR targeting '%s' using the forgejo MCP tools or git push.`,
|
||||||
|
t.IssueNumber, t.RepoOwner, t.RepoName,
|
||||||
|
t.IssueTitle,
|
||||||
|
t.IssueBody,
|
||||||
|
t.TargetBranch,
|
||||||
|
t.TargetBranch, t.IssueNumber,
|
||||||
|
t.TargetBranch,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runAgent(t dispatchTicket, prompt, dir, logPath string) (bool, int, error) {
|
||||||
|
timeout := 30 * time.Minute
|
||||||
|
if t.Timeout != "" {
|
||||||
|
if d, err := time.ParseDuration(t.Timeout); err == nil {
|
||||||
|
timeout = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
model := t.Model
|
||||||
|
if model == "" {
|
||||||
|
model = "sonnet"
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Running agent", "runner", t.Runner, "model", model)
|
||||||
|
|
||||||
|
// For Gemini runner, wrap with rate limiting.
|
||||||
|
if t.Runner == "gemini" {
|
||||||
|
return executeWithRateLimit(ctx, model, prompt, func() (bool, int, error) {
|
||||||
|
return execAgent(ctx, t.Runner, model, prompt, dir, logPath)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return execAgent(ctx, t.Runner, model, prompt, dir, logPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func execAgent(ctx context.Context, runner, model, prompt, dir, logPath string) (bool, int, error) {
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
switch runner {
|
||||||
|
case "codex":
|
||||||
|
cmd = exec.CommandContext(ctx, "codex", "exec", "--full-auto", prompt)
|
||||||
|
case "gemini":
|
||||||
|
args := []string{"-p", "-", "-y", "-m", model}
|
||||||
|
cmd = exec.CommandContext(ctx, "gemini", args...)
|
||||||
|
cmd.Stdin = strings.NewReader(prompt)
|
||||||
|
default: // claude
|
||||||
|
cmd = exec.CommandContext(ctx, "claude", "-p", "--model", model, "--dangerously-skip-permissions", "--output-format", "text")
|
||||||
|
cmd.Stdin = strings.NewReader(prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Dir = dir
|
||||||
|
|
||||||
|
f, err := os.Create(logPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, -1, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
cmd.Stdout = f
|
||||||
|
cmd.Stderr = f
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
exitCode := -1
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
|
exitCode = exitErr.ExitCode()
|
||||||
|
}
|
||||||
|
return false, exitCode, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func reportToForge(t dispatchTicket, success bool, body string) {
|
||||||
|
url := fmt.Sprintf("%s/api/v1/repos/%s/%s/issues/%d/comments",
|
||||||
|
strings.TrimSuffix(t.ForgeURL, "/"), t.RepoOwner, t.RepoName, t.IssueNumber)
|
||||||
|
|
||||||
|
payload := map[string]string{"body": body}
|
||||||
|
jsonBody, _ := json.Marshal(payload)
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to create request", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "token "+t.ForgeToken)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 10 * time.Second}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Failed to report to Forge", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode >= 300 {
|
||||||
|
log.Warn("Forge reported error", "status", resp.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func moveToDone(paths runnerPaths, activePath, fileName string) {
|
||||||
|
donePath := filepath.Join(paths.done, fileName)
|
||||||
|
if err := os.Rename(activePath, donePath); err != nil {
|
||||||
|
log.Error("Failed to move ticket to done", "err", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureDispatchDirs(p runnerPaths) error {
|
||||||
|
dirs := []string{p.queue, p.active, p.done, p.logs, p.jobs}
|
||||||
|
for _, d := range dirs {
|
||||||
|
if err := os.MkdirAll(d, 0755); err != nil {
|
||||||
|
return fmt.Errorf("mkdir %s failed: %w", d, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func acquireLock(lockPath string) error {
|
||||||
|
if data, err := os.ReadFile(lockPath); err == nil {
|
||||||
|
pidStr := strings.TrimSpace(string(data))
|
||||||
|
pid, _ := strconv.Atoi(pidStr)
|
||||||
|
if isProcessAlive(pid) {
|
||||||
|
return fmt.Errorf("locked by PID %d", pid)
|
||||||
|
}
|
||||||
|
log.Info("Removing stale lock", "pid", pid)
|
||||||
|
_ = os.Remove(lockPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(lockPath, []byte(fmt.Sprintf("%d", os.Getpid())), 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func releaseLock(lockPath string) {
|
||||||
|
_ = os.Remove(lockPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProcessAlive(pid int) bool {
|
||||||
|
if pid <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
process, err := os.FindProcess(pid)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return process.Signal(syscall.Signal(0)) == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func pickOldestTicket(queueDir string) (string, error) {
|
||||||
|
entries, err := os.ReadDir(queueDir)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
var tickets []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if !e.IsDir() && strings.HasPrefix(e.Name(), "ticket-") && strings.HasSuffix(e.Name(), ".json") {
|
||||||
|
tickets = append(tickets, filepath.Join(queueDir, e.Name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tickets) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(tickets)
|
||||||
|
return tickets[0], nil
|
||||||
|
}
|
||||||
213
internal/cmd/ai/cmd_ratelimits.go
Normal file
213
internal/cmd/ai/cmd_ratelimits.go
Normal file
|
|
@ -0,0 +1,213 @@
|
||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"text/tabwriter"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/host-uk/core/pkg/cli"
|
||||||
|
"github.com/host-uk/core/pkg/config"
|
||||||
|
"github.com/host-uk/core/pkg/ratelimit"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AddRateLimitCommands registers the 'ratelimits' subcommand group under 'ai'.
|
||||||
|
func AddRateLimitCommands(parent *cli.Command) {
|
||||||
|
rlCmd := &cli.Command{
|
||||||
|
Use: "ratelimits",
|
||||||
|
Short: "Manage Gemini API rate limits",
|
||||||
|
}
|
||||||
|
|
||||||
|
rlCmd.AddCommand(rlShowCmd())
|
||||||
|
rlCmd.AddCommand(rlResetCmd())
|
||||||
|
rlCmd.AddCommand(rlCountCmd())
|
||||||
|
rlCmd.AddCommand(rlConfigCmd())
|
||||||
|
rlCmd.AddCommand(rlCheckCmd())
|
||||||
|
|
||||||
|
parent.AddCommand(rlCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rlShowCmd() *cli.Command {
|
||||||
|
return &cli.Command{
|
||||||
|
Use: "show",
|
||||||
|
Short: "Show current rate limit usage",
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
rl, err := ratelimit.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := rl.Load(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := rl.AllStats()
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||||
|
fmt.Fprintln(w, "MODEL\tRPM\tTPM\tRPD\tSTATUS")
|
||||||
|
|
||||||
|
for model, s := range stats {
|
||||||
|
rpmStr := fmt.Sprintf("%d/%s", s.RPM, formatLimit(s.MaxRPM))
|
||||||
|
tpmStr := fmt.Sprintf("%d/%s", s.TPM, formatLimit(s.MaxTPM))
|
||||||
|
rpdStr := fmt.Sprintf("%d/%s", s.RPD, formatLimit(s.MaxRPD))
|
||||||
|
|
||||||
|
status := "OK"
|
||||||
|
if (s.MaxRPM > 0 && s.RPM >= s.MaxRPM) ||
|
||||||
|
(s.MaxTPM > 0 && s.TPM >= s.MaxTPM) ||
|
||||||
|
(s.MaxRPD > 0 && s.RPD >= s.MaxRPD) {
|
||||||
|
status = "LIMITED"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\n", model, rpmStr, tpmStr, rpdStr, status)
|
||||||
|
}
|
||||||
|
w.Flush()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rlResetCmd() *cli.Command {
|
||||||
|
return &cli.Command{
|
||||||
|
Use: "reset [model]",
|
||||||
|
Short: "Reset usage counters for a model (or all)",
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
rl, err := ratelimit.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := rl.Load(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
model := ""
|
||||||
|
if len(args) > 0 {
|
||||||
|
model = args[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
rl.Reset(model)
|
||||||
|
if err := rl.Persist(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if model == "" {
|
||||||
|
fmt.Println("Reset stats for all models.")
|
||||||
|
} else {
|
||||||
|
fmt.Printf("Reset stats for model %q.\n", model)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rlCountCmd() *cli.Command {
|
||||||
|
return &cli.Command{
|
||||||
|
Use: "count <model> <text>",
|
||||||
|
Short: "Count tokens for text using Gemini API",
|
||||||
|
Args: cli.ExactArgs(2),
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
model := args[0]
|
||||||
|
text := args[1]
|
||||||
|
|
||||||
|
cfg, err := config.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiKey string
|
||||||
|
if err := cfg.Get("agentci.gemini_api_key", &apiKey); err != nil || apiKey == "" {
|
||||||
|
apiKey = os.Getenv("GEMINI_API_KEY")
|
||||||
|
}
|
||||||
|
if apiKey == "" {
|
||||||
|
return fmt.Errorf("GEMINI_API_KEY not found in config or env")
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := ratelimit.CountTokens(apiKey, model, text)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Model: %s\nTokens: %d\n", model, count)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rlConfigCmd() *cli.Command {
|
||||||
|
return &cli.Command{
|
||||||
|
Use: "config",
|
||||||
|
Short: "Show configured quotas",
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
rl, err := ratelimit.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
|
||||||
|
fmt.Fprintln(w, "MODEL\tMAX RPM\tMAX TPM\tMAX RPD")
|
||||||
|
|
||||||
|
for model, q := range rl.Quotas {
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t%s\t%s\n",
|
||||||
|
model,
|
||||||
|
formatLimit(q.MaxRPM),
|
||||||
|
formatLimit(q.MaxTPM),
|
||||||
|
formatLimit(q.MaxRPD))
|
||||||
|
}
|
||||||
|
w.Flush()
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rlCheckCmd() *cli.Command {
|
||||||
|
return &cli.Command{
|
||||||
|
Use: "check <model> <estimated-tokens>",
|
||||||
|
Short: "Check rate limit capacity for a model",
|
||||||
|
Args: cli.ExactArgs(2),
|
||||||
|
RunE: func(cmd *cli.Command, args []string) error {
|
||||||
|
model := args[0]
|
||||||
|
tokens, err := strconv.Atoi(args[1])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid token count: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rl, err := ratelimit.New()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := rl.Load(); err != nil {
|
||||||
|
fmt.Printf("Warning: could not load existing state: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := rl.Stats(model)
|
||||||
|
canSend := rl.CanSend(model, tokens)
|
||||||
|
|
||||||
|
status := "RATE LIMITED"
|
||||||
|
if canSend {
|
||||||
|
status = "OK"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Model: %s\n", model)
|
||||||
|
fmt.Printf("Request Cost: %d tokens\n", tokens)
|
||||||
|
fmt.Printf("Status: %s\n", status)
|
||||||
|
fmt.Printf("\nCurrent Usage (1m window):\n")
|
||||||
|
fmt.Printf(" RPM: %d / %s\n", stats.RPM, formatLimit(stats.MaxRPM))
|
||||||
|
fmt.Printf(" TPM: %d / %s\n", stats.TPM, formatLimit(stats.MaxTPM))
|
||||||
|
fmt.Printf(" RPD: %d / %s (reset: %s)\n", stats.RPD, formatLimit(stats.MaxRPD), stats.DayStart.Format(time.RFC3339))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatLimit(limit int) string {
|
||||||
|
if limit == 0 {
|
||||||
|
return "∞"
|
||||||
|
}
|
||||||
|
if limit >= 1000000 {
|
||||||
|
return fmt.Sprintf("%dM", limit/1000000)
|
||||||
|
}
|
||||||
|
if limit >= 1000 {
|
||||||
|
return fmt.Sprintf("%dK", limit/1000)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d", limit)
|
||||||
|
}
|
||||||
49
internal/cmd/ai/ratelimit_dispatch.go
Normal file
49
internal/cmd/ai/ratelimit_dispatch.go
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
package ai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/host-uk/core/pkg/log"
|
||||||
|
"github.com/host-uk/core/pkg/ratelimit"
|
||||||
|
)
|
||||||
|
|
||||||
|
// executeWithRateLimit wraps an agent execution with rate limiting logic.
|
||||||
|
// It estimates token usage, waits for capacity, executes the runner, and records usage.
|
||||||
|
func executeWithRateLimit(ctx context.Context, model, prompt string, runner func() (bool, int, error)) (bool, int, error) {
|
||||||
|
rl, err := ratelimit.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("Failed to initialize rate limiter, proceeding without limits", "error", err)
|
||||||
|
return runner()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rl.Load(); err != nil {
|
||||||
|
log.Warn("Failed to load rate limit state", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Estimate tokens from prompt length (1 token ≈ 4 chars)
|
||||||
|
estTokens := len(prompt) / 4
|
||||||
|
if estTokens == 0 {
|
||||||
|
estTokens = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Checking rate limits", "model", model, "est_tokens", estTokens)
|
||||||
|
|
||||||
|
if err := rl.WaitForCapacity(ctx, model, estTokens); err != nil {
|
||||||
|
return false, -1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
success, exitCode, runErr := runner()
|
||||||
|
|
||||||
|
// Record usage with conservative output estimate (actual tokens unknown from shell runner).
|
||||||
|
outputEst := estTokens / 10
|
||||||
|
if outputEst < 50 {
|
||||||
|
outputEst = 50
|
||||||
|
}
|
||||||
|
rl.RecordUsage(model, estTokens, outputEst)
|
||||||
|
|
||||||
|
if err := rl.Persist(); err != nil {
|
||||||
|
log.Warn("Failed to persist rate limit state", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return success, exitCode, runErr
|
||||||
|
}
|
||||||
87
pkg/jobrunner/handlers/completion.go
Normal file
87
pkg/jobrunner/handlers/completion.go
Normal file
|
|
@ -0,0 +1,87 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/host-uk/core/pkg/forge"
|
||||||
|
"github.com/host-uk/core/pkg/jobrunner"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ColorAgentComplete = "#0e8a16" // Green
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompletionHandler manages issue state when an agent finishes work.
|
||||||
|
type CompletionHandler struct {
|
||||||
|
forge *forge.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCompletionHandler creates a handler for agent completion events.
|
||||||
|
func NewCompletionHandler(client *forge.Client) *CompletionHandler {
|
||||||
|
return &CompletionHandler{
|
||||||
|
forge: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the handler identifier.
|
||||||
|
func (h *CompletionHandler) Name() string {
|
||||||
|
return "completion"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match returns true if the signal indicates an agent has finished a task.
|
||||||
|
func (h *CompletionHandler) Match(signal *jobrunner.PipelineSignal) bool {
|
||||||
|
return signal.Type == "agent_completion"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute updates the issue labels based on the completion status.
|
||||||
|
func (h *CompletionHandler) Execute(ctx context.Context, signal *jobrunner.PipelineSignal) (*jobrunner.ActionResult, error) {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// Remove in-progress label.
|
||||||
|
if inProgressLabel, err := h.forge.GetLabelByName(signal.RepoOwner, signal.RepoName, LabelInProgress); err == nil {
|
||||||
|
_ = h.forge.RemoveIssueLabel(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), inProgressLabel.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if signal.Success {
|
||||||
|
completeLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentComplete, ColorAgentComplete)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ensure label %s: %w", LabelAgentComplete, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{completeLabel.ID}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add completed label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if signal.Message != "" {
|
||||||
|
_ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), signal.Message)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
failedLabel, err := h.forge.EnsureLabel(signal.RepoOwner, signal.RepoName, LabelAgentFailed, ColorAgentFailed)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("ensure label %s: %w", LabelAgentFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.forge.AddIssueLabels(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), []int64{failedLabel.ID}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add failed label: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := "Agent reported failure."
|
||||||
|
if signal.Error != "" {
|
||||||
|
msg += fmt.Sprintf("\n\nError: %s", signal.Error)
|
||||||
|
}
|
||||||
|
_ = h.forge.CreateIssueComment(signal.RepoOwner, signal.RepoName, int64(signal.ChildNumber), msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &jobrunner.ActionResult{
|
||||||
|
Action: "completion",
|
||||||
|
RepoOwner: signal.RepoOwner,
|
||||||
|
RepoName: signal.RepoName,
|
||||||
|
EpicNumber: signal.EpicNumber,
|
||||||
|
ChildNumber: signal.ChildNumber,
|
||||||
|
Success: true,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Duration: time.Since(start),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
382
pkg/ratelimit/ratelimit.go
Normal file
382
pkg/ratelimit/ratelimit.go
Normal file
|
|
@ -0,0 +1,382 @@
|
||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelQuota defines the rate limits for a specific model.
|
||||||
|
type ModelQuota struct {
|
||||||
|
MaxRPM int `yaml:"max_rpm"` // Requests per minute
|
||||||
|
MaxTPM int `yaml:"max_tpm"` // Tokens per minute
|
||||||
|
MaxRPD int `yaml:"max_rpd"` // Requests per day (0 = unlimited)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenEntry records a token usage event.
|
||||||
|
type TokenEntry struct {
|
||||||
|
Time time.Time `yaml:"time"`
|
||||||
|
Count int `yaml:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UsageStats tracks usage history for a model.
|
||||||
|
type UsageStats struct {
|
||||||
|
Requests []time.Time `yaml:"requests"` // Sliding window (1m)
|
||||||
|
Tokens []TokenEntry `yaml:"tokens"` // Sliding window (1m)
|
||||||
|
DayStart time.Time `yaml:"day_start"`
|
||||||
|
DayCount int `yaml:"day_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimiter manages rate limits across multiple models.
|
||||||
|
type RateLimiter struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
Quotas map[string]ModelQuota `yaml:"quotas"`
|
||||||
|
State map[string]*UsageStats `yaml:"state"`
|
||||||
|
filePath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new RateLimiter with default quotas.
|
||||||
|
func New() (*RateLimiter, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rl := &RateLimiter{
|
||||||
|
Quotas: make(map[string]ModelQuota),
|
||||||
|
State: make(map[string]*UsageStats),
|
||||||
|
filePath: filepath.Join(home, ".core", "ratelimits.yaml"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default quotas based on Tier 1 observations (Feb 2026)
|
||||||
|
rl.Quotas["gemini-3-pro-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
|
||||||
|
rl.Quotas["gemini-3-flash-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
|
||||||
|
rl.Quotas["gemini-2.5-pro"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
|
||||||
|
rl.Quotas["gemini-2.0-flash"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0} // Unlimited RPD
|
||||||
|
rl.Quotas["gemini-2.0-flash-lite"] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0} // Unlimited
|
||||||
|
|
||||||
|
return rl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads the state from disk.
|
||||||
|
func (rl *RateLimiter) Load() error {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(rl.filePath)
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return yaml.Unmarshal(data, rl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist writes the state to disk.
|
||||||
|
func (rl *RateLimiter) Persist() error {
|
||||||
|
rl.mu.RLock()
|
||||||
|
defer rl.mu.RUnlock()
|
||||||
|
|
||||||
|
data, err := yaml.Marshal(rl)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := filepath.Dir(rl.filePath)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.WriteFile(rl.filePath, data, 0644)
|
||||||
|
}
|
||||||
|
|
||||||
|
// prune removes entries older than the sliding window (1 minute).
|
||||||
|
// Caller must hold lock.
|
||||||
|
func (rl *RateLimiter) prune(model string) {
|
||||||
|
stats, ok := rl.State[model]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
window := now.Add(-1 * time.Minute)
|
||||||
|
|
||||||
|
// Prune requests
|
||||||
|
validReqs := 0
|
||||||
|
for _, t := range stats.Requests {
|
||||||
|
if t.After(window) {
|
||||||
|
stats.Requests[validReqs] = t
|
||||||
|
validReqs++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stats.Requests = stats.Requests[:validReqs]
|
||||||
|
|
||||||
|
// Prune tokens
|
||||||
|
validTokens := 0
|
||||||
|
for _, t := range stats.Tokens {
|
||||||
|
if t.Time.After(window) {
|
||||||
|
stats.Tokens[validTokens] = t
|
||||||
|
validTokens++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stats.Tokens = stats.Tokens[:validTokens]
|
||||||
|
|
||||||
|
// Reset daily counter if day has passed
|
||||||
|
if now.Sub(stats.DayStart) >= 24*time.Hour {
|
||||||
|
stats.DayStart = now
|
||||||
|
stats.DayCount = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanSend checks if a request can be sent without violating limits.
|
||||||
|
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
quota, ok := rl.Quotas[model]
|
||||||
|
if !ok {
|
||||||
|
return true // Unknown models are allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlimited check
|
||||||
|
if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure state exists
|
||||||
|
if _, ok := rl.State[model]; !ok {
|
||||||
|
rl.State[model] = &UsageStats{
|
||||||
|
DayStart: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rl.prune(model)
|
||||||
|
stats := rl.State[model]
|
||||||
|
|
||||||
|
// Check RPD
|
||||||
|
if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check RPM
|
||||||
|
if quota.MaxRPM > 0 && len(stats.Requests) >= quota.MaxRPM {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check TPM
|
||||||
|
if quota.MaxTPM > 0 {
|
||||||
|
currentTokens := 0
|
||||||
|
for _, t := range stats.Tokens {
|
||||||
|
currentTokens += t.Count
|
||||||
|
}
|
||||||
|
if currentTokens+estimatedTokens > quota.MaxTPM {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordUsage records a successful API call.
|
||||||
|
func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := rl.State[model]; !ok {
|
||||||
|
rl.State[model] = &UsageStats{
|
||||||
|
DayStart: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := rl.State[model]
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
stats.Requests = append(stats.Requests, now)
|
||||||
|
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens})
|
||||||
|
stats.DayCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForCapacity blocks until capacity is available or context is cancelled.
|
||||||
|
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if rl.CanSend(model, tokens) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
// check again
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears stats for a model (or all if model is empty).
|
||||||
|
func (rl *RateLimiter) Reset(model string) {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
if model == "" {
|
||||||
|
rl.State = make(map[string]*UsageStats)
|
||||||
|
} else {
|
||||||
|
delete(rl.State, model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelStats represents a snapshot of usage.
|
||||||
|
type ModelStats struct {
|
||||||
|
RPM int
|
||||||
|
MaxRPM int
|
||||||
|
TPM int
|
||||||
|
MaxTPM int
|
||||||
|
RPD int
|
||||||
|
MaxRPD int
|
||||||
|
DayStart time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns current stats for a model.
|
||||||
|
func (rl *RateLimiter) Stats(model string) ModelStats {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
rl.prune(model)
|
||||||
|
|
||||||
|
stats := ModelStats{}
|
||||||
|
quota, ok := rl.Quotas[model]
|
||||||
|
if ok {
|
||||||
|
stats.MaxRPM = quota.MaxRPM
|
||||||
|
stats.MaxTPM = quota.MaxTPM
|
||||||
|
stats.MaxRPD = quota.MaxRPD
|
||||||
|
}
|
||||||
|
|
||||||
|
if s, ok := rl.State[model]; ok {
|
||||||
|
stats.RPM = len(s.Requests)
|
||||||
|
stats.RPD = s.DayCount
|
||||||
|
stats.DayStart = s.DayStart
|
||||||
|
for _, t := range s.Tokens {
|
||||||
|
stats.TPM += t.Count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllStats returns stats for all tracked models.
|
||||||
|
func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
||||||
|
rl.mu.Lock()
|
||||||
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
result := make(map[string]ModelStats)
|
||||||
|
|
||||||
|
// Collect all model names
|
||||||
|
for m := range rl.Quotas {
|
||||||
|
result[m] = ModelStats{}
|
||||||
|
}
|
||||||
|
for m := range rl.State {
|
||||||
|
result[m] = ModelStats{}
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
window := now.Add(-1 * time.Minute)
|
||||||
|
|
||||||
|
for m := range result {
|
||||||
|
// Prune inline
|
||||||
|
if s, ok := rl.State[m]; ok {
|
||||||
|
validReqs := 0
|
||||||
|
for _, t := range s.Requests {
|
||||||
|
if t.After(window) {
|
||||||
|
s.Requests[validReqs] = t
|
||||||
|
validReqs++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.Requests = s.Requests[:validReqs]
|
||||||
|
|
||||||
|
validTokens := 0
|
||||||
|
for _, t := range s.Tokens {
|
||||||
|
if t.Time.After(window) {
|
||||||
|
s.Tokens[validTokens] = t
|
||||||
|
validTokens++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.Tokens = s.Tokens[:validTokens]
|
||||||
|
|
||||||
|
if now.Sub(s.DayStart) >= 24*time.Hour {
|
||||||
|
s.DayStart = now
|
||||||
|
s.DayCount = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ms := ModelStats{}
|
||||||
|
if q, ok := rl.Quotas[m]; ok {
|
||||||
|
ms.MaxRPM = q.MaxRPM
|
||||||
|
ms.MaxTPM = q.MaxTPM
|
||||||
|
ms.MaxRPD = q.MaxRPD
|
||||||
|
}
|
||||||
|
if s, ok := rl.State[m]; ok {
|
||||||
|
ms.RPM = len(s.Requests)
|
||||||
|
ms.RPD = s.DayCount
|
||||||
|
ms.DayStart = s.DayStart
|
||||||
|
for _, t := range s.Tokens {
|
||||||
|
ms.TPM += t.Count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result[m] = ms
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountTokens calls the Google API to count tokens for a prompt.
|
||||||
|
func CountTokens(apiKey, model, text string) (int, error) {
|
||||||
|
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens?key=%s", model, apiKey)
|
||||||
|
|
||||||
|
reqBody := map[string]any{
|
||||||
|
"contents": []any{
|
||||||
|
map[string]any{
|
||||||
|
"parts": []any{
|
||||||
|
map[string]string{"text": text},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBody, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
return 0, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var result struct {
|
||||||
|
TotalTokens int `json:"totalTokens"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result.TotalTokens, nil
|
||||||
|
}
|
||||||
176
pkg/ratelimit/ratelimit_test.go
Normal file
176
pkg/ratelimit/ratelimit_test.go
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCanSend_Good(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml")
|
||||||
|
|
||||||
|
model := "test-model"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 100}
|
||||||
|
|
||||||
|
if !rl.CanSend(model, 100) {
|
||||||
|
t.Errorf("Expected CanSend to return true for fresh state")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanSend_RPMExceeded_Bad(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
model := "test-rpm"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 2, MaxTPM: 1000000, MaxRPD: 100}
|
||||||
|
|
||||||
|
rl.RecordUsage(model, 10, 10)
|
||||||
|
rl.RecordUsage(model, 10, 10)
|
||||||
|
|
||||||
|
if rl.CanSend(model, 10) {
|
||||||
|
t.Errorf("Expected CanSend to return false after exceeding RPM")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanSend_TPMExceeded_Bad(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
model := "test-tpm"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 100}
|
||||||
|
|
||||||
|
rl.RecordUsage(model, 50, 40) // 90 tokens used
|
||||||
|
|
||||||
|
if rl.CanSend(model, 20) { // 90 + 20 = 110 > 100
|
||||||
|
t.Errorf("Expected CanSend to return false when estimated tokens exceed TPM")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanSend_RPDExceeded_Bad(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
model := "test-rpd"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 2}
|
||||||
|
|
||||||
|
rl.RecordUsage(model, 10, 10)
|
||||||
|
rl.RecordUsage(model, 10, 10)
|
||||||
|
|
||||||
|
if rl.CanSend(model, 10) {
|
||||||
|
t.Errorf("Expected CanSend to return false after exceeding RPD")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCanSend_UnlimitedModel_Good(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
model := "test-unlimited"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}
|
||||||
|
|
||||||
|
// Should always be allowed
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
rl.RecordUsage(model, 100, 100)
|
||||||
|
}
|
||||||
|
if !rl.CanSend(model, 999999) {
|
||||||
|
t.Errorf("Expected unlimited model to always allow sends")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecordUsage_PrunesOldEntries_Good(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
model := "test-prune"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 1000000, MaxRPD: 100}
|
||||||
|
|
||||||
|
// Manually inject old data
|
||||||
|
oldTime := time.Now().Add(-2 * time.Minute)
|
||||||
|
rl.State[model] = &UsageStats{
|
||||||
|
Requests: []time.Time{oldTime, oldTime, oldTime},
|
||||||
|
Tokens: []TokenEntry{
|
||||||
|
{Time: oldTime, Count: 100},
|
||||||
|
{Time: oldTime, Count: 100},
|
||||||
|
},
|
||||||
|
DayStart: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// CanSend triggers prune
|
||||||
|
if !rl.CanSend(model, 10) {
|
||||||
|
t.Errorf("Expected CanSend to return true after pruning old entries")
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := rl.State[model]
|
||||||
|
if len(stats.Requests) != 0 {
|
||||||
|
t.Errorf("Expected 0 requests after pruning old entries, got %d", len(stats.Requests))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistAndLoad_Good(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
path := filepath.Join(tmpDir, "ratelimits.yaml")
|
||||||
|
|
||||||
|
rl1, _ := New()
|
||||||
|
rl1.filePath = path
|
||||||
|
model := "persist-test"
|
||||||
|
rl1.Quotas[model] = ModelQuota{MaxRPM: 50, MaxTPM: 5000, MaxRPD: 500}
|
||||||
|
rl1.RecordUsage(model, 100, 100)
|
||||||
|
|
||||||
|
if err := rl1.Persist(); err != nil {
|
||||||
|
t.Fatalf("Persist failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rl2, _ := New()
|
||||||
|
rl2.filePath = path
|
||||||
|
if err := rl2.Load(); err != nil {
|
||||||
|
t.Fatalf("Load failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := rl2.Stats(model)
|
||||||
|
if stats.RPM != 1 {
|
||||||
|
t.Errorf("Expected RPM 1 after load, got %d", stats.RPM)
|
||||||
|
}
|
||||||
|
if stats.TPM != 200 {
|
||||||
|
t.Errorf("Expected TPM 200 after load, got %d", stats.TPM)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWaitForCapacity_Ugly(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
model := "wait-test"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000000, MaxRPD: 100}
|
||||||
|
|
||||||
|
rl.RecordUsage(model, 10, 10) // Use up the 1 RPM
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := rl.WaitForCapacity(ctx, model, 10)
|
||||||
|
if err != context.DeadlineExceeded {
|
||||||
|
t.Errorf("Expected DeadlineExceeded, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultQuotas_Good(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
expected := []string{
|
||||||
|
"gemini-3-pro-preview",
|
||||||
|
"gemini-3-flash-preview",
|
||||||
|
"gemini-2.0-flash",
|
||||||
|
}
|
||||||
|
for _, m := range expected {
|
||||||
|
if _, ok := rl.Quotas[m]; !ok {
|
||||||
|
t.Errorf("Expected default quota for %s", m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllStats_Good(t *testing.T) {
|
||||||
|
rl, _ := New()
|
||||||
|
rl.RecordUsage("gemini-3-pro-preview", 1000, 500)
|
||||||
|
|
||||||
|
all := rl.AllStats()
|
||||||
|
if len(all) < 5 {
|
||||||
|
t.Errorf("Expected at least 5 models in AllStats, got %d", len(all))
|
||||||
|
}
|
||||||
|
|
||||||
|
pro := all["gemini-3-pro-preview"]
|
||||||
|
if pro.RPM != 1 {
|
||||||
|
t.Errorf("Expected RPM 1 for pro, got %d", pro.RPM)
|
||||||
|
}
|
||||||
|
if pro.TPM != 1500 {
|
||||||
|
t.Errorf("Expected TPM 1500 for pro, got %d", pro.TPM)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue