commit fa1a6fc0307ff82f7ef55c212283f44f04b18ec7
Author: Snider
Date:   Thu Feb 19 16:09:13 2026 +0000

    feat: extract go-ratelimit from core/go pkg/ratelimit

    Token counting, model quotas, sliding window rate limiter.
    One external dependency: gopkg.in/yaml.v3 (state persistence).
    Module: forge.lthn.ai/core/go-ratelimit

    Co-Authored-By: Virgil

diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..a3bff6a
--- /dev/null
+++ b/CLAUDE.md

# CLAUDE.md

## What This Is

Token counting, model quotas, and sliding window rate limiter. Module: `forge.lthn.ai/core/go-ratelimit`

## Commands

```bash
go test ./...        # Run all tests
go test -v -run Name # Run single test
```

## Coding Standards

- UK English
- `go test ./...` must pass before commit
- Conventional commits: `type(scope): description`
- Co-Author: `Co-Authored-By: Virgil`

diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..6568819
--- /dev/null
+++ b/go.mod

module forge.lthn.ai/core/go-ratelimit

go 1.25.5

// ratelimit.go imports gopkg.in/yaml.v3; without this require directive
// `go test ./...` fails to build the module.
require gopkg.in/yaml.v3 v3.0.1

diff --git a/ratelimit.go b/ratelimit.go
new file mode 100644
index 0000000..bb51d49
--- /dev/null
+++ b/ratelimit.go

// Package ratelimit provides per-model request/token quotas enforced over a
// one-minute sliding window, with an additional daily request counter and
// YAML persistence of state under ~/.core/ratelimits.yaml.
package ratelimit

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"time"

	"gopkg.in/yaml.v3"
)

// ModelQuota defines the rate limits for a specific model.
// A zero value for any field means that limit is not enforced.
type ModelQuota struct {
	MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
	MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited)
	MaxRPD int `yaml:"max_rpd"` // Requests per day (0 = unlimited)
}

// TokenEntry records a token usage event at a point in time.
type TokenEntry struct {
	Time  time.Time `yaml:"time"`
	Count int       `yaml:"count"`
}

// UsageStats tracks usage history for a model.
type UsageStats struct {
	Requests []time.Time  `yaml:"requests"` // Sliding window (1m)
	Tokens   []TokenEntry `yaml:"tokens"`   // Sliding window (1m)
	DayStart time.Time    `yaml:"day_start"`
	DayCount int          `yaml:"day_count"`
}

// RateLimiter manages rate limits across multiple models.
// All exported methods are safe for concurrent use.
type RateLimiter struct {
	mu       sync.RWMutex
	Quotas   map[string]ModelQuota  `yaml:"quotas"`
	State    map[string]*UsageStats `yaml:"state"`
	filePath string
}

// New creates a new RateLimiter with default quotas.
// State is persisted to ~/.core/ratelimits.yaml (see Load/Persist).
func New() (*RateLimiter, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return nil, err
	}

	rl := &RateLimiter{
		Quotas:   make(map[string]ModelQuota),
		State:    make(map[string]*UsageStats),
		filePath: filepath.Join(home, ".core", "ratelimits.yaml"),
	}

	// Default quotas based on Tier 1 observations (Feb 2026).
	rl.Quotas["gemini-3-pro-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
	rl.Quotas["gemini-3-flash-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
	rl.Quotas["gemini-2.5-pro"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
	rl.Quotas["gemini-2.0-flash"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0} // Unlimited RPD
	rl.Quotas["gemini-2.0-flash-lite"] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}    // Unlimited

	return rl, nil
}

// Load reads the state from disk. A missing file is not an error: the
// limiter simply starts from its in-memory defaults.
func (rl *RateLimiter) Load() error {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	data, err := os.ReadFile(rl.filePath)
	if os.IsNotExist(err) {
		return nil
	}
	if err != nil {
		return err
	}

	return yaml.Unmarshal(data, rl)
}

// Persist writes the state to disk.
+func (rl *RateLimiter) Persist() error { + rl.mu.RLock() + defer rl.mu.RUnlock() + + data, err := yaml.Marshal(rl) + if err != nil { + return err + } + + dir := filepath.Dir(rl.filePath) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + return os.WriteFile(rl.filePath, data, 0644) +} + +// prune removes entries older than the sliding window (1 minute). +// Caller must hold lock. +func (rl *RateLimiter) prune(model string) { + stats, ok := rl.State[model] + if !ok { + return + } + + now := time.Now() + window := now.Add(-1 * time.Minute) + + // Prune requests + validReqs := 0 + for _, t := range stats.Requests { + if t.After(window) { + stats.Requests[validReqs] = t + validReqs++ + } + } + stats.Requests = stats.Requests[:validReqs] + + // Prune tokens + validTokens := 0 + for _, t := range stats.Tokens { + if t.Time.After(window) { + stats.Tokens[validTokens] = t + validTokens++ + } + } + stats.Tokens = stats.Tokens[:validTokens] + + // Reset daily counter if day has passed + if now.Sub(stats.DayStart) >= 24*time.Hour { + stats.DayStart = now + stats.DayCount = 0 + } +} + +// CanSend checks if a request can be sent without violating limits. 
+func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { + rl.mu.Lock() + defer rl.mu.Unlock() + + quota, ok := rl.Quotas[model] + if !ok { + return true // Unknown models are allowed + } + + // Unlimited check + if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 { + return true + } + + // Ensure state exists + if _, ok := rl.State[model]; !ok { + rl.State[model] = &UsageStats{ + DayStart: time.Now(), + } + } + + rl.prune(model) + stats := rl.State[model] + + // Check RPD + if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD { + return false + } + + // Check RPM + if quota.MaxRPM > 0 && len(stats.Requests) >= quota.MaxRPM { + return false + } + + // Check TPM + if quota.MaxTPM > 0 { + currentTokens := 0 + for _, t := range stats.Tokens { + currentTokens += t.Count + } + if currentTokens+estimatedTokens > quota.MaxTPM { + return false + } + } + + return true +} + +// RecordUsage records a successful API call. +func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) { + rl.mu.Lock() + defer rl.mu.Unlock() + + if _, ok := rl.State[model]; !ok { + rl.State[model] = &UsageStats{ + DayStart: time.Now(), + } + } + + stats := rl.State[model] + now := time.Now() + + stats.Requests = append(stats.Requests, now) + stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens}) + stats.DayCount++ +} + +// WaitForCapacity blocks until capacity is available or context is cancelled. +func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + if rl.CanSend(model, tokens) { + return nil + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + // check again + } + } +} + +// Reset clears stats for a model (or all if model is empty). 
func (rl *RateLimiter) Reset(model string) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if model == "" {
		rl.State = make(map[string]*UsageStats)
	} else {
		delete(rl.State, model)
	}
}

// ModelStats represents a snapshot of usage against the configured quota.
type ModelStats struct {
	RPM      int
	MaxRPM   int
	TPM      int
	MaxTPM   int
	RPD      int
	MaxRPD   int
	DayStart time.Time
}

// Stats returns current stats for a model, pruning expired entries first.
func (rl *RateLimiter) Stats(model string) ModelStats {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	rl.prune(model)

	stats := ModelStats{}
	quota, ok := rl.Quotas[model]
	if ok {
		stats.MaxRPM = quota.MaxRPM
		stats.MaxTPM = quota.MaxTPM
		stats.MaxRPD = quota.MaxRPD
	}

	if s, ok := rl.State[model]; ok {
		stats.RPM = len(s.Requests)
		stats.RPD = s.DayCount
		stats.DayStart = s.DayStart
		for _, t := range s.Tokens {
			stats.TPM += t.Count
		}
	}

	return stats
}

// AllStats returns stats for all tracked models (union of quota keys and
// state keys).
func (rl *RateLimiter) AllStats() map[string]ModelStats {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	result := make(map[string]ModelStats)

	// Collect all model names from both quotas and recorded state.
	for m := range rl.Quotas {
		result[m] = ModelStats{}
	}
	for m := range rl.State {
		result[m] = ModelStats{}
	}

	for m := range result {
		// FIX: delegate to prune instead of duplicating the sliding-window
		// logic inline; prune only requires that rl.mu is held, and it is.
		rl.prune(m)

		ms := ModelStats{}
		if q, ok := rl.Quotas[m]; ok {
			ms.MaxRPM = q.MaxRPM
			ms.MaxTPM = q.MaxTPM
			ms.MaxRPD = q.MaxRPD
		}
		if s, ok := rl.State[m]; ok {
			ms.RPM = len(s.Requests)
			ms.RPD = s.DayCount
			ms.DayStart = s.DayStart
			for _, t := range s.Tokens {
				ms.TPM += t.Count
			}
		}
		result[m] = ms
	}

	return result
}

// CountTokens calls the Google API to count tokens for a prompt.
// It returns the model's totalTokens figure for the given text.
func CountTokens(apiKey, model, text string) (int, error) {
	url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model)

	reqBody := map[string]any{
		"contents": []any{
			map[string]any{
				"parts": []any{
					map[string]string{"text": text},
				},
			},
		},
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return 0, err
	}

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
	if err != nil {
		return 0, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", apiKey)

	// FIX: http.DefaultClient has no timeout; a stalled server would block
	// the caller forever. Bound the whole request/response cycle.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return 0, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body))
	}

	var result struct {
		TotalTokens int `json:"totalTokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return 0, err
	}

	return result.TotalTokens, nil
}

diff --git a/ratelimit_test.go b/ratelimit_test.go
new file mode 100644
index 0000000..1247960
--- /dev/null
+++ b/ratelimit_test.go

package ratelimit

import (
	"context"
	"path/filepath"
	"testing"
	"time"
)

func TestCanSend_Good(t *testing.T) {
	rl, _ := New()
	rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml")

	model := "test-model"
	rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 100}

	if !rl.CanSend(model, 100) {
		t.Errorf("Expected CanSend to return true for fresh state")
	}
}

func TestCanSend_RPMExceeded_Bad(t *testing.T) {
	rl, _ := New()
	model := "test-rpm"
	rl.Quotas[model] = ModelQuota{MaxRPM: 2, MaxTPM: 1000000, MaxRPD: 100}

	rl.RecordUsage(model, 10, 10)
	rl.RecordUsage(model, 10, 10)

	if rl.CanSend(model, 10) {
		t.Errorf("Expected CanSend to return false after exceeding RPM")
	}
}

func TestCanSend_TPMExceeded_Bad(t *testing.T) {
	rl, _ := New()
	model := "test-tpm"
	rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 100}

	rl.RecordUsage(model, 50, 40) // 90 tokens used

	if rl.CanSend(model, 20) { // 90 + 20 = 110 > 100
		t.Errorf("Expected CanSend to return false when estimated tokens exceed TPM")
	}
}

func TestCanSend_RPDExceeded_Bad(t *testing.T) {
	rl, _ := New()
	model := "test-rpd"
	rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 2}

	rl.RecordUsage(model, 10, 10)
	rl.RecordUsage(model, 10, 10)

	if rl.CanSend(model, 10) {
		t.Errorf("Expected CanSend to return false after exceeding RPD")
	}
}

func TestCanSend_UnlimitedModel_Good(t *testing.T) {
	rl, _ := New()
	model := "test-unlimited"
	rl.Quotas[model] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}

	// Should always be allowed
	for i := 0; i < 1000; i++ {
		rl.RecordUsage(model, 100, 100)
	}
	if !rl.CanSend(model, 999999) {
		t.Errorf("Expected unlimited model to always allow sends")
	}
}

func TestRecordUsage_PrunesOldEntries_Good(t *testing.T) {
	rl, _ := New()
	model := "test-prune"
	rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 1000000, MaxRPD: 100}

	// Manually inject old data
	oldTime := time.Now().Add(-2 * time.Minute)
	rl.State[model] = &UsageStats{
		Requests: []time.Time{oldTime, oldTime, oldTime},
		Tokens: []TokenEntry{
			{Time: oldTime, Count: 100},
			{Time: oldTime, Count: 100},
		},
		DayStart: time.Now(),
	}

	// CanSend triggers prune
	if !rl.CanSend(model, 10) {
		t.Errorf("Expected CanSend to return true after pruning old entries")
	}

	stats := rl.State[model]
	if len(stats.Requests) != 0 {
		t.Errorf("Expected 0 requests after pruning old entries, got %d", len(stats.Requests))
	}
}

func TestPersistAndLoad_Good(t *testing.T) {
	tmpDir := t.TempDir()
	path := filepath.Join(tmpDir, "ratelimits.yaml")

	rl1, _ := New()
	rl1.filePath = path
	model := "persist-test"
	rl1.Quotas[model] = ModelQuota{MaxRPM: 50, MaxTPM: 5000, MaxRPD: 500}
	rl1.RecordUsage(model, 100, 100)

	if err := rl1.Persist(); err != nil {
		t.Fatalf("Persist failed: %v", err)
	}

	rl2, _ := New()
	rl2.filePath = path
	if err := rl2.Load(); err != nil {
		t.Fatalf("Load failed: %v", err)
	}

	stats := rl2.Stats(model)
	if stats.RPM != 1 {
		t.Errorf("Expected RPM 1 after load, got %d", stats.RPM)
	}
	if stats.TPM != 200 {
		t.Errorf("Expected TPM 200 after load, got %d", stats.TPM)
	}
}

func TestWaitForCapacity_Ugly(t *testing.T) {
	rl, _ := New()
	model := "wait-test"
	rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000000, MaxRPD: 100}

	rl.RecordUsage(model, 10, 10) // Use up the 1 RPM

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	err := rl.WaitForCapacity(ctx, model, 10)
	if err != context.DeadlineExceeded {
		t.Errorf("Expected DeadlineExceeded, got %v", err)
	}
}

func TestDefaultQuotas_Good(t *testing.T) {
	rl, _ := New()
	expected := []string{
		"gemini-3-pro-preview",
		"gemini-3-flash-preview",
		"gemini-2.0-flash",
	}
	for _, m := range expected {
		if _, ok := rl.Quotas[m]; !ok {
			t.Errorf("Expected default quota for %s", m)
		}
	}
}

func TestAllStats_Good(t *testing.T) {
	rl, _ := New()
	rl.RecordUsage("gemini-3-pro-preview", 1000, 500)

	all := rl.AllStats()
	if len(all) < 5 {
		t.Errorf("Expected at least 5 models in AllStats, got %d", len(all))
	}

	pro := all["gemini-3-pro-preview"]
	if pro.RPM != 1 {
		t.Errorf("Expected RPM 1 for pro, got %d", pro.RPM)
	}
	if pro.TPM != 1500 {
		t.Errorf("Expected TPM 1500 for pro, got %d", pro.TPM)
	}
}