diff --git a/FINDINGS.md b/FINDINGS.md index 9522692..0c5d194 100644 --- a/FINDINGS.md +++ b/FINDINGS.md @@ -21,3 +21,86 @@ Extracted from `forge.lthn.ai/core/go` on 19 Feb 2026. ### Tests - 1 test file covering sliding window and quota enforcement + +--- + +## 2026-02-20: Phase 0 -- Hardening (Charon) + +### Coverage: 77.1% -> 95.1% + +Rewrote test suite with testify assert/require. Table-driven subtests throughout. + +#### Tests added + +- **CanSend boundaries**: exact RPM/TPM/RPD limits, RPM-only, TPM-only, zero-token estimates, unknown models, unlimited models +- **Prune**: keeps recent entries, prunes old ones, daily reset at 24h, boundary-exact timestamps, noop on non-existent model +- **RecordUsage**: fresh state, accumulation, existing state +- **Reset**: single model, all models (empty string), non-existent model +- **WaitForCapacity**: immediate capacity, context cancellation, pre-cancelled context, unknown model +- **Stats/AllStats**: known/unknown/quota-only models, pruning in AllStats, daily reset in AllStats +- **Persist/Load**: round-trip, non-existent file, corrupt YAML, unreadable file, nested directory creation, unwritable directory +- **Concurrency**: 20 goroutines x 50 ops (CanSend + RecordUsage + Stats), concurrent Reset + RecordUsage + AllStats +- **Benchmarks**: BenchmarkCanSend (1000-entry window), BenchmarkRecordUsage, BenchmarkCanSendConcurrent + +#### Remaining uncovered (5%) + +- `CountTokens` success path: hardcoded Google URL prevents unit testing without URL injection. Only the connection-error path is covered. +- `yaml.Marshal` error in `Persist()`: virtually impossible to trigger with valid structs. +- `os.UserHomeDir` error in `NewWithConfig()`: only fails when `$HOME` is unset. + +### Race detector + +`go test -race ./...` passes clean. The `sync.RWMutex` correctly guards all shared state. + +### go vet + +No warnings. + +--- + +## 2026-02-20: Phase 1 -- Generalisation (Charon) + +### Problem + +Hardcoded Gemini-specific quotas in `New()`. No way to configure for other providers. + +### Solution + +Introduced provider-agnostic configuration without breaking existing API. + +#### New types + +- `Provider` -- string type with constants: `ProviderGemini`, `ProviderOpenAI`, `ProviderAnthropic`, `ProviderLocal` +- `ProviderProfile` -- bundles provider identity with model quotas map +- `Config` -- construction config with `FilePath`, `Providers` list, `Quotas` map + +#### New functions + +- `DefaultProfiles()` -- returns pre-configured profiles for all four providers +- `NewWithConfig(Config)` -- creates limiter from explicit configuration +- `SetQuota(model, quota)` -- runtime quota modification +- `AddProvider(provider)` -- loads all default quotas for a provider at runtime + +#### Provider defaults (Feb 2026) + +| Provider | Models | RPM | TPM | RPD | +|----------|--------|-----|-----|-----| +| Gemini | gemini-3-pro-preview, gemini-3-flash-preview, gemini-2.5-pro | 150 | 1M | 1000 | +| Gemini | gemini-2.0-flash | 150 | 1M | unlimited | +| Gemini | gemini-2.0-flash-lite | unlimited | unlimited | unlimited | +| OpenAI | gpt-4o, gpt-4-turbo, o1 | 500 | 30K | unlimited | +| OpenAI | gpt-4o-mini, o1-mini, o3-mini | 500 | 200K | unlimited | +| Anthropic | claude-opus-4, claude-sonnet-4 | 50 | 40K | unlimited | +| Anthropic | claude-haiku-3.5 | 50 | 50K | unlimited | +| Local | (none by default) | -- | -- | -- | + +#### Backward compatibility + +`New()` delegates to `NewWithConfig(Config{Providers: []Provider{ProviderGemini}})`. Verified by `TestNewBackwardCompatibility` which asserts exact parity with the original hardcoded values. + +#### Design notes + +- Explicit quotas in `Config.Quotas` override provider defaults (merge-on-top pattern) +- Local provider has no default quotas -- users add per-model limits for hardware throttling +- `AddProvider()` is additive -- calling it does not remove existing quotas +- All new methods are mutex-protected and safe for concurrent use diff --git a/TODO.md b/TODO.md index 7705e86..156c5ca 100644 --- a/TODO.md +++ b/TODO.md @@ -1,4 +1,4 @@ -# TODO.md — go-ratelimit +# TODO.md -- go-ratelimit Dispatched from core/go orchestration. Pick up tasks in order. @@ -6,20 +6,23 @@ Dispatched from core/go orchestration. Pick up tasks in order. ## Phase 0: Hardening & Test Coverage -- [ ] **Expand test coverage** — `ratelimit_test.go` exists. Add tests for: `CanSend()` at exact limits (RPM, TPM, RPD boundaries), `RecordUsage()` with concurrent goroutines (race test), `WaitForCapacity()` timeout behaviour, `prune()` sliding window edge cases, daily reset logic (cross-midnight), YAML persistence (save + reload state), empty/corrupt state file recovery. -- [ ] **Race condition test** — `go test -race ./...` with 10 goroutines calling `CanSend()` + `RecordUsage()` concurrently. The `sync.RWMutex` should handle it but verify. -- [ ] **Benchmark** — Add `BenchmarkCanSend` and `BenchmarkRecordUsage` with 1000 entries in sliding window. Measure prune() overhead. -- [ ] **`go vet ./...` clean** — Fix any warnings. +- [x] **Expand test coverage** -- `ratelimit_test.go` rewritten with testify. Tests for: `CanSend()` at exact limits (RPM, TPM, RPD boundaries), `RecordUsage()` with concurrent goroutines, `WaitForCapacity()` timeout and immediate-capacity paths, `prune()` sliding window edge cases, daily reset logic (24h boundary), YAML persistence (save + reload), corrupt/unreadable state file recovery, `Reset()` single/all/nonexistent, `Stats()` known/unknown/quota-only models, `AllStats()` with pruning and daily reset. +- [x] **Race condition test** -- `go test -race ./...` with 20 goroutines calling `CanSend()` + `RecordUsage()` + `Stats()` concurrently. Additional test with concurrent `Reset()` + `RecordUsage()` + `AllStats()`. All pass clean. +- [x] **Benchmark** -- `BenchmarkCanSend` (1000-entry window), `BenchmarkRecordUsage`, `BenchmarkCanSendConcurrent` (parallel). Measures prune() overhead. +- [x] **`go vet ./...` clean** -- No warnings. +- **Coverage: 95.1%** (up from 77.1%). Remaining uncovered: `CountTokens` success path (hardcoded Google URL), `yaml.Marshal` error path in `Persist()`, `os.UserHomeDir` error path in `NewWithConfig`. ## Phase 1: Generalise Beyond Gemini -- [ ] Hardcoded model quotas are Gemini-specific — abstract to provider-agnostic config -- [ ] Add quota profiles for OpenAI, Anthropic, and local (Ollama/MLX) backends -- [ ] Make default quotas configurable via YAML or environment variables +- [x] **Provider-agnostic config** -- Added `Provider` type, `ProviderProfile`, `Config` struct, `NewWithConfig()` constructor. Quotas are no longer hardcoded in `New()`. +- [x] **Quota profiles** -- `DefaultProfiles()` returns pre-configured profiles for Gemini, OpenAI (gpt-4o, o1, o3-mini), Anthropic (claude-opus-4, claude-sonnet-4, claude-haiku-3.5), and Local (empty, user-configurable). +- [x] **Configurable defaults** -- `Config` struct accepts `FilePath`, `Providers` list, and explicit `Quotas` map. Explicit quotas override provider defaults. YAML-serialisable. +- [x] **Backward compatibility** -- `New()` delegates to `NewWithConfig(Config{Providers: []Provider{ProviderGemini}})`. Existing API unchanged. Test `TestNewBackwardCompatibility` verifies exact parity. +- [x] **Runtime configuration** -- `SetQuota()` and `AddProvider()` allow modifying quotas after construction. Both are mutex-protected. ## Phase 2: Persistent State -- [ ] Currently stores state in YAML file — not safe for multi-process access +- [ ] Currently stores state in YAML file -- not safe for multi-process access - [ ] Consider SQLite for concurrent read/write safety (WAL mode) - [ ] Add state recovery on restart (reload sliding window from persisted data) diff --git a/go.mod b/go.mod index 335fe06..e7d567f 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,9 @@ module forge.lthn.ai/core/go-ratelimit go 1.25.5 require gopkg.in/yaml.v3 v3.0.1 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.11.1 +) diff --git a/go.sum b/go.sum index a62c313..c4c1710 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/ratelimit.go b/ratelimit.go index bb51d49..241a338 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -15,13 +15,48 @@ import ( "gopkg.in/yaml.v3" ) +// Provider identifies an LLM provider for quota profiles. +type Provider string + +const ( + // ProviderGemini is Google's Gemini family (default). + ProviderGemini Provider = "gemini" + // ProviderOpenAI is OpenAI's GPT/o-series family. + ProviderOpenAI Provider = "openai" + // ProviderAnthropic is Anthropic's Claude family. + ProviderAnthropic Provider = "anthropic" + // ProviderLocal is for local inference (Ollama, MLX, llama.cpp). + ProviderLocal Provider = "local" +) + // ModelQuota defines the rate limits for a specific model. type ModelQuota struct { - MaxRPM int `yaml:"max_rpm"` // Requests per minute - MaxTPM int `yaml:"max_tpm"` // Tokens per minute + MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited) + MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited) MaxRPD int `yaml:"max_rpd"` // Requests per day (0 = unlimited) } +// ProviderProfile bundles model quotas for a provider. +type ProviderProfile struct { + Provider Provider `yaml:"provider"` + Models map[string]ModelQuota `yaml:"models"` +} + +// Config controls RateLimiter initialisation. +type Config struct { + // FilePath overrides the default state file location. + // If empty, defaults to ~/.core/ratelimits.yaml. + FilePath string `yaml:"file_path,omitempty"` + + // Quotas sets per-model rate limits directly. + // These are merged on top of any provider profile defaults. + Quotas map[string]ModelQuota `yaml:"quotas,omitempty"` + + // Providers lists provider profiles to load. + // If empty and Quotas is also empty, Gemini defaults are used. + Providers []Provider `yaml:"providers,omitempty"` +} + // TokenEntry records a token usage event. type TokenEntry struct { Time time.Time `yaml:"time"` @@ -44,29 +79,121 @@ type RateLimiter struct { filePath string } -// New creates a new RateLimiter with default quotas. +// DefaultProfiles returns pre-configured quota profiles for each provider. +// Values are based on published rate limits as of Feb 2026. +func DefaultProfiles() map[Provider]ProviderProfile { + return map[Provider]ProviderProfile{ + ProviderGemini: { + Provider: ProviderGemini, + Models: map[string]ModelQuota{ + "gemini-3-pro-preview": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}, + "gemini-3-flash-preview": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}, + "gemini-2.5-pro": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}, + "gemini-2.0-flash": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0}, // Unlimited RPD + "gemini-2.0-flash-lite": {MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}, // Unlimited + }, + }, + ProviderOpenAI: { + Provider: ProviderOpenAI, + Models: map[string]ModelQuota{ + "gpt-4o": {MaxRPM: 500, MaxTPM: 30000, MaxRPD: 0}, + "gpt-4o-mini": {MaxRPM: 500, MaxTPM: 200000, MaxRPD: 0}, + "gpt-4-turbo": {MaxRPM: 500, MaxTPM: 30000, MaxRPD: 0}, + "o1": {MaxRPM: 500, MaxTPM: 30000, MaxRPD: 0}, + "o1-mini": {MaxRPM: 500, MaxTPM: 200000, MaxRPD: 0}, + "o3-mini": {MaxRPM: 500, MaxTPM: 200000, MaxRPD: 0}, + }, + }, + ProviderAnthropic: { + Provider: ProviderAnthropic, + Models: map[string]ModelQuota{ + "claude-opus-4": {MaxRPM: 50, MaxTPM: 40000, MaxRPD: 0}, + "claude-sonnet-4": {MaxRPM: 50, MaxTPM: 40000, MaxRPD: 0}, + "claude-haiku-3.5": {MaxRPM: 50, MaxTPM: 50000, MaxRPD: 0}, + }, + }, + ProviderLocal: { + Provider: ProviderLocal, + Models: map[string]ModelQuota{ + // Local inference has no external rate limits by default. + // Users can override per-model if their hardware requires throttling. + }, + }, + } +} + +// New creates a new RateLimiter with Gemini defaults. +// This preserves backward compatibility -- existing callers are unaffected. func New() (*RateLimiter, error) { - home, err := os.UserHomeDir() - if err != nil { - return nil, err + return NewWithConfig(Config{ + Providers: []Provider{ProviderGemini}, + }) +} + +// NewWithConfig creates a RateLimiter from explicit configuration. +// If no providers or quotas are specified, Gemini defaults are used. +func NewWithConfig(cfg Config) (*RateLimiter, error) { + filePath := cfg.FilePath + if filePath == "" { + home, err := os.UserHomeDir() + if err != nil { + return nil, err + } + filePath = filepath.Join(home, ".core", "ratelimits.yaml") } rl := &RateLimiter{ Quotas: make(map[string]ModelQuota), State: make(map[string]*UsageStats), - filePath: filepath.Join(home, ".core", "ratelimits.yaml"), + filePath: filePath, } - // Default quotas based on Tier 1 observations (Feb 2026) - rl.Quotas["gemini-3-pro-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000} - rl.Quotas["gemini-3-flash-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000} - rl.Quotas["gemini-2.5-pro"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000} - rl.Quotas["gemini-2.0-flash"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0} // Unlimited RPD - rl.Quotas["gemini-2.0-flash-lite"] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0} // Unlimited + // Load provider profiles + profiles := DefaultProfiles() + providers := cfg.Providers + + // If nothing specified at all, default to Gemini + if len(providers) == 0 && len(cfg.Quotas) == 0 { + providers = []Provider{ProviderGemini} + } + + for _, p := range providers { + if profile, ok := profiles[p]; ok { + for model, quota := range profile.Models { + rl.Quotas[model] = quota + } + } + } + + // Merge explicit quotas on top (allows overrides) + for model, quota := range cfg.Quotas { + rl.Quotas[model] = quota + } return rl, nil } +// SetQuota sets or updates the quota for a specific model at runtime. +func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.Quotas[model] = quota +} + +// AddProvider loads all default quotas for a provider. +// Existing quotas for models in the profile are overwritten. +func (rl *RateLimiter) AddProvider(provider Provider) { + rl.mu.Lock() + defer rl.mu.Unlock() + + profiles := DefaultProfiles() + if profile, ok := profiles[provider]; ok { + for model, quota := range profile.Models { + rl.Quotas[model] = quota + } + } +} + // Load reads the state from disk. func (rl *RateLimiter) Load() error { rl.mu.Lock() diff --git a/ratelimit_test.go b/ratelimit_test.go index 1247960..80ae9b7 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -2,175 +2,1137 @@ package ratelimit import ( "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" "path/filepath" + "sync" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestCanSend_Good(t *testing.T) { - rl, _ := New() +// newTestLimiter returns a RateLimiter with file path set to a temp directory. +func newTestLimiter(t *testing.T) *RateLimiter { + t.Helper() + rl, err := New() + require.NoError(t, err) rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml") + return rl +} - model := "test-model" - rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 100} +// --- Phase 0: CanSend boundary conditions --- - if !rl.CanSend(model, 100) { - t.Errorf("Expected CanSend to return true for fresh state") +func TestCanSend(t *testing.T) { + t.Run("fresh state allows send", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-model" + rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 100} + + assert.True(t, rl.CanSend(model, 100), "fresh state should allow send") + }) + + t.Run("unknown model is always allowed", func(t *testing.T) { + rl := newTestLimiter(t) + assert.True(t, rl.CanSend("unknown-model", 999999), "unknown model should be allowed") + }) + + t.Run("unlimited model is always allowed", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-unlimited" + rl.Quotas[model] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0} + + for i := 0; i < 1000; i++ { + rl.RecordUsage(model, 100, 100) + } + assert.True(t, rl.CanSend(model, 999999), "unlimited model should always allow sends") + }) + + t.Run("RPM at exact limit is rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-rpm-boundary" + rl.Quotas[model] = ModelQuota{MaxRPM: 3, MaxTPM: 1000000, MaxRPD: 1000} + + rl.RecordUsage(model, 1, 1) + rl.RecordUsage(model, 1, 1) + assert.True(t, rl.CanSend(model, 1), "should allow at RPM-1") + + rl.RecordUsage(model, 1, 1) + assert.False(t, rl.CanSend(model, 1), "should reject at exact RPM limit") + }) + + t.Run("RPM exceeded is rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-rpm" + rl.Quotas[model] = ModelQuota{MaxRPM: 2, MaxTPM: 1000000, MaxRPD: 100} + + rl.RecordUsage(model, 10, 10) + rl.RecordUsage(model, 10, 10) + assert.False(t, rl.CanSend(model, 10), "should reject after exceeding RPM") + }) + + t.Run("TPM at exact limit is rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-tpm-boundary" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 1000} + + rl.RecordUsage(model, 50, 40) // 90 tokens + assert.True(t, rl.CanSend(model, 10), "should allow at exact TPM limit (90+10=100)") + assert.False(t, rl.CanSend(model, 11), "should reject when exceeding TPM by 1 (90+11=101)") + }) + + t.Run("TPM exceeded is rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-tpm" + rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 100} + + rl.RecordUsage(model, 50, 40) // 90 tokens + assert.False(t, rl.CanSend(model, 20), "should reject when 90+20=110 > 100") + }) + + t.Run("RPD at exact limit is rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-rpd-boundary" + rl.Quotas[model] = ModelQuota{MaxRPM: 1000, MaxTPM: 10000000, MaxRPD: 3} + + rl.RecordUsage(model, 1, 1) + rl.RecordUsage(model, 1, 1) + assert.True(t, rl.CanSend(model, 1), "should allow at RPD-1") + + rl.RecordUsage(model, 1, 1) + assert.False(t, rl.CanSend(model, 1), "should reject at exact RPD limit") + }) + + t.Run("RPD exceeded is rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-rpd" + rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 2} + + rl.RecordUsage(model, 10, 10) + rl.RecordUsage(model, 10, 10) + assert.False(t, rl.CanSend(model, 10), "should reject after exceeding RPD") + }) + + t.Run("RPD unlimited allows infinite requests", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-rpd-unlimited" + rl.Quotas[model] = ModelQuota{MaxRPM: 10000, MaxTPM: 100000000, MaxRPD: 0} + + for i := 0; i < 100; i++ { + rl.RecordUsage(model, 1, 1) + } + assert.True(t, rl.CanSend(model, 1), "RPD=0 should mean unlimited daily requests") + }) + + t.Run("RPM only limit", func(t *testing.T) { + rl := newTestLimiter(t) + model := "rpm-only" + rl.Quotas[model] = ModelQuota{MaxRPM: 2, MaxTPM: 0, MaxRPD: 0} + + rl.RecordUsage(model, 100000, 100000) + assert.True(t, rl.CanSend(model, 999999), "TPM=0 should be unlimited") + + rl.RecordUsage(model, 1, 1) + assert.False(t, rl.CanSend(model, 1), "RPM should still enforce") + }) + + t.Run("TPM only limit", func(t *testing.T) { + rl := newTestLimiter(t) + model := "tpm-only" + rl.Quotas[model] = ModelQuota{MaxRPM: 0, MaxTPM: 100, MaxRPD: 0} + + // RPM=0 and RPD=0 means all-zero => unlimited shortcut + // Actually: MaxTPM > 0 but MaxRPM == 0 and MaxRPD == 0 means the + // unlimited check (all three == 0) does NOT trigger. + // Let's verify. + rl.RecordUsage(model, 50, 40) // 90 tokens + assert.True(t, rl.CanSend(model, 10), "should allow under TPM limit") + assert.False(t, rl.CanSend(model, 11), "should reject over TPM limit") + }) + + t.Run("zero estimated tokens always fits TPM", func(t *testing.T) { + rl := newTestLimiter(t) + model := "zero-est" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 100} + + rl.RecordUsage(model, 50, 50) // exactly 100 tokens + assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit") + }) +} + +// --- Phase 0: Sliding window / prune tests --- + +func TestPrune(t *testing.T) { + t.Run("removes old entries", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-prune" + rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 1000000, MaxRPD: 100} + + oldTime := time.Now().Add(-2 * time.Minute) + rl.State[model] = &UsageStats{ + Requests: []time.Time{oldTime, oldTime, oldTime}, + Tokens: []TokenEntry{ + {Time: oldTime, Count: 100}, + {Time: oldTime, Count: 100}, + }, + DayStart: time.Now(), + } + + assert.True(t, rl.CanSend(model, 10), "old entries should be pruned") + assert.Empty(t, rl.State[model].Requests, "requests should be empty after pruning") + assert.Empty(t, rl.State[model].Tokens, "tokens should be empty after pruning") + }) + + t.Run("keeps recent entries", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-keep" + rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 100} + + now := time.Now() + recent := now.Add(-30 * time.Second) + old := now.Add(-2 * time.Minute) + rl.State[model] = &UsageStats{ + Requests: []time.Time{old, recent, old, recent}, + Tokens: []TokenEntry{ + {Time: old, Count: 100}, + {Time: recent, Count: 200}, + {Time: old, Count: 300}, + }, + DayStart: now, + } + + rl.CanSend(model, 1) // triggers prune + + stats := rl.State[model] + assert.Len(t, stats.Requests, 2, "should keep 2 recent requests") + assert.Len(t, stats.Tokens, 1, "should keep 1 recent token entry") + assert.Equal(t, 200, stats.Tokens[0].Count, "kept token should be the recent one") + }) + + t.Run("daily reset when 24h elapsed", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-daily" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 1000000, MaxRPD: 5} + + // Set state with day started 25 hours ago and 5 daily requests used + rl.State[model] = &UsageStats{ + DayStart: time.Now().Add(-25 * time.Hour), + DayCount: 5, + } + + // Should reset daily counter and allow + assert.True(t, rl.CanSend(model, 1), "should allow after daily reset") + assert.Equal(t, 0, rl.State[model].DayCount, "day count should be reset") + }) + + t.Run("no daily reset within 24h", func(t *testing.T) { + rl := newTestLimiter(t) + model := "test-no-reset" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 1000000, MaxRPD: 5} + + rl.State[model] = &UsageStats{ + DayStart: time.Now().Add(-23 * time.Hour), + DayCount: 5, + } + + assert.False(t, rl.CanSend(model, 1), "should reject within 24h when RPD exhausted") + assert.Equal(t, 5, rl.State[model].DayCount, "day count should not be reset") + }) + + t.Run("prune on non-existent model is noop", func(t *testing.T) { + rl := newTestLimiter(t) + rl.mu.Lock() + rl.prune("nonexistent-model") + rl.mu.Unlock() + // Should not panic or create state + _, exists := rl.State["nonexistent-model"] + assert.False(t, exists, "prune should not create state for non-existent model") + }) + + t.Run("mixed old and new entries at boundary", func(t *testing.T) { + rl := newTestLimiter(t) + model := "boundary-prune" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 1000000, MaxRPD: 1000} + + now := time.Now() + // Entry exactly at the 1-minute boundary (should be pruned because After is strict) + atBoundary := now.Add(-1 * time.Minute) + justInside := now.Add(-59 * time.Second) + + rl.State[model] = &UsageStats{ + Requests: []time.Time{atBoundary, justInside}, + Tokens: []TokenEntry{ + {Time: atBoundary, Count: 500}, + {Time: justInside, Count: 300}, + }, + DayStart: now, + } + + rl.CanSend(model, 1) + stats := rl.State[model] + assert.Len(t, stats.Requests, 1, "entry at exact boundary should be pruned") + assert.Len(t, stats.Tokens, 1, "token at exact boundary should be pruned") + }) +} + +// --- Phase 0: RecordUsage --- + +func TestRecordUsage(t *testing.T) { + t.Run("records into fresh state", func(t *testing.T) { + rl := newTestLimiter(t) + model := "record-fresh" + + rl.RecordUsage(model, 100, 50) + + stats := rl.State[model] + require.NotNil(t, stats) + assert.Len(t, stats.Requests, 1) + assert.Len(t, stats.Tokens, 1) + assert.Equal(t, 150, stats.Tokens[0].Count) + assert.Equal(t, 1, stats.DayCount) + }) + + t.Run("accumulates multiple recordings", func(t *testing.T) { + rl := newTestLimiter(t) + model := "record-multi" + + rl.RecordUsage(model, 10, 20) + rl.RecordUsage(model, 30, 40) + rl.RecordUsage(model, 50, 60) + + stats := rl.State[model] + assert.Len(t, stats.Requests, 3) + assert.Len(t, stats.Tokens, 3) + assert.Equal(t, 3, stats.DayCount) + + // Verify total tokens + total := 0 + for _, te := range stats.Tokens { + total += te.Count + } + assert.Equal(t, 210, total, "total tokens should be 10+20+30+40+50+60=210") + }) + + t.Run("records into existing state", func(t *testing.T) { + rl := newTestLimiter(t) + model := "record-existing" + + rl.State[model] = &UsageStats{ + Requests: []time.Time{time.Now()}, + Tokens: []TokenEntry{{Time: time.Now(), Count: 100}}, + DayStart: time.Now(), + DayCount: 5, + } + + rl.RecordUsage(model, 200, 300) + + stats := rl.State[model] + assert.Len(t, stats.Requests, 2) + assert.Len(t, stats.Tokens, 2) + assert.Equal(t, 6, stats.DayCount) + }) +} + +// --- Phase 0: Reset --- + +func TestReset(t *testing.T) { + t.Run("reset single model", func(t *testing.T) { + rl := newTestLimiter(t) + rl.RecordUsage("model-a", 10, 10) + rl.RecordUsage("model-b", 20, 20) + + rl.Reset("model-a") + + _, existsA := rl.State["model-a"] + _, existsB := rl.State["model-b"] + assert.False(t, existsA, "model-a state should be cleared") + assert.True(t, existsB, "model-b state should remain") + }) + + t.Run("reset all models with empty string", func(t *testing.T) { + rl := newTestLimiter(t) + rl.RecordUsage("model-a", 10, 10) + rl.RecordUsage("model-b", 20, 20) + rl.RecordUsage("model-c", 30, 30) + + rl.Reset("") + + assert.Empty(t, rl.State, "all state should be cleared") + }) + + t.Run("reset non-existent model is safe", func(t *testing.T) { + rl := newTestLimiter(t) + rl.Reset("does-not-exist") // should not panic + assert.Empty(t, rl.State) + }) +} + +// --- Phase 0: WaitForCapacity --- + +func TestWaitForCapacity(t *testing.T) { + t.Run("context cancelled returns error", func(t *testing.T) { + rl := newTestLimiter(t) + model := "wait-cancel" + rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000000, MaxRPD: 100} + + rl.RecordUsage(model, 10, 10) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := rl.WaitForCapacity(ctx, model, 10) + assert.ErrorIs(t, err, context.DeadlineExceeded) + }) + + t.Run("immediate capacity returns nil", func(t *testing.T) { + rl := newTestLimiter(t) + model := "wait-immediate" + rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 100} + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err := rl.WaitForCapacity(ctx, model, 10) + assert.NoError(t, err, "should return immediately when capacity available") + }) + + t.Run("unknown model returns immediately", func(t *testing.T) { + rl := newTestLimiter(t) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + err := rl.WaitForCapacity(ctx, "unknown-model", 999999) + assert.NoError(t, err, "unknown model should return immediately") + }) + + t.Run("context cancelled before first check", func(t *testing.T) { + rl := newTestLimiter(t) + model := "wait-pre-cancel" + rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000000, MaxRPD: 100} + rl.RecordUsage(model, 10, 10) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + err := rl.WaitForCapacity(ctx, model, 10) + assert.Error(t, err, "should return error for already-cancelled context") + }) +} + +// --- Phase 0: Stats --- + +func TestStats(t *testing.T) { + t.Run("returns stats for known model with usage", func(t *testing.T) { + rl := newTestLimiter(t) + model := "stats-test" + rl.Quotas[model] = ModelQuota{MaxRPM: 50, MaxTPM: 5000, MaxRPD: 500} + rl.RecordUsage(model, 100, 200) + + stats := rl.Stats(model) + assert.Equal(t, 1, stats.RPM) + assert.Equal(t, 300, stats.TPM) + assert.Equal(t, 1, stats.RPD) + assert.Equal(t, 50, stats.MaxRPM) + assert.Equal(t, 5000, stats.MaxTPM) + assert.Equal(t, 500, stats.MaxRPD) + }) + + t.Run("returns empty stats for unknown model", func(t *testing.T) { + rl := newTestLimiter(t) + stats := rl.Stats("nonexistent") + assert.Equal(t, 0, stats.RPM) + assert.Equal(t, 0, stats.TPM) + assert.Equal(t, 0, stats.RPD) + assert.Equal(t, 0, stats.MaxRPM) + }) + + t.Run("returns quota info for model without usage", func(t *testing.T) { + rl := newTestLimiter(t) + model := "quota-only" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 200, MaxRPD: 300} + + stats := rl.Stats(model) + assert.Equal(t, 0, stats.RPM, "no usage yet") + assert.Equal(t, 100, stats.MaxRPM, "quota should be present") + assert.Equal(t, 200, stats.MaxTPM) + assert.Equal(t, 300, stats.MaxRPD) + }) +} + +// --- Phase 0: AllStats --- + +func TestAllStats(t *testing.T) { + t.Run("includes all default quotas plus state-only models", func(t *testing.T) { + rl := newTestLimiter(t) + rl.RecordUsage("gemini-3-pro-preview", 1000, 500) + rl.RecordUsage("custom-model", 100, 200) + + all := rl.AllStats() + // Should include all default Gemini models + custom-model + assert.GreaterOrEqual(t, len(all), 6, "should include default quotas + custom model") + + pro := all["gemini-3-pro-preview"] + assert.Equal(t, 1, pro.RPM) + assert.Equal(t, 1500, pro.TPM) + + custom := all["custom-model"] + assert.Equal(t, 1, custom.RPM) + assert.Equal(t, 300, custom.TPM) + assert.Equal(t, 0, custom.MaxRPM, "custom model has no quota") + }) + + t.Run("prunes old entries in AllStats", func(t *testing.T) { + rl := newTestLimiter(t) + model := "allstats-prune" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 1000000, MaxRPD: 100} + + oldTime := time.Now().Add(-2 * time.Minute) + rl.State[model] = &UsageStats{ + Requests: []time.Time{oldTime, oldTime}, + Tokens: []TokenEntry{{Time: oldTime, Count: 500}}, + DayStart: time.Now(), + DayCount: 2, + } + + all := rl.AllStats() + stats := all[model] + assert.Equal(t, 0, stats.RPM, "old requests should be pruned") + assert.Equal(t, 0, stats.TPM, "old tokens should be pruned") + assert.Equal(t, 2, stats.RPD, "daily count survives prune") + }) + + t.Run("daily reset in AllStats", func(t *testing.T) { + rl := newTestLimiter(t) + model := "allstats-daily" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 1000000, MaxRPD: 100} + + rl.State[model] = &UsageStats{ + DayStart: time.Now().Add(-25 * time.Hour), + DayCount: 50, + } + + all := rl.AllStats() + stats := all[model] + assert.Equal(t, 0, stats.RPD, "daily count should be reset after 24h") + }) +} + +// --- Phase 0: Persist and Load --- + +func TestPersistAndLoad(t *testing.T) { + t.Run("round-trip preserves state", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "ratelimits.yaml") + + rl1, err := New() + require.NoError(t, err) + rl1.filePath = path + model := "persist-test" + rl1.Quotas[model] = ModelQuota{MaxRPM: 50, MaxTPM: 5000, MaxRPD: 500} + rl1.RecordUsage(model, 100, 100) + + require.NoError(t, rl1.Persist()) + + rl2, err := New() + require.NoError(t, err) + rl2.filePath = path + require.NoError(t, rl2.Load()) + + stats := rl2.Stats(model) + assert.Equal(t, 1, stats.RPM) + assert.Equal(t, 200, stats.TPM) + }) + + t.Run("load from non-existent file is not an error", func(t *testing.T) { + rl := newTestLimiter(t) + rl.filePath = filepath.Join(t.TempDir(), "does-not-exist.yaml") + + err := rl.Load() + assert.NoError(t, err, "loading non-existent file should not error") + }) + + t.Run("load from corrupt YAML returns error", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "corrupt.yaml") + require.NoError(t, os.WriteFile(path, []byte("{{{{invalid yaml!!!!"), 0644)) + + rl := newTestLimiter(t) + rl.filePath = path + + err := rl.Load() + assert.Error(t, err, "corrupt YAML should produce an error") + }) + + t.Run("load from unreadable file returns error", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "unreadable.yaml") + require.NoError(t, os.WriteFile(path, []byte("quotas: {}"), 0644)) + require.NoError(t, os.Chmod(path, 0000)) + + rl := newTestLimiter(t) + rl.filePath = path + + err := rl.Load() + assert.Error(t, err, "unreadable file should produce an error") + + // Clean up permissions for temp dir cleanup + _ = os.Chmod(path, 0644) + }) + + t.Run("persist to nested non-existent directory creates it", func(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "nested", "deep", "ratelimits.yaml") + + rl := newTestLimiter(t) + rl.filePath = path + rl.RecordUsage("test", 1, 1) + + err := rl.Persist() + assert.NoError(t, err, "should create nested directories") + + _, statErr := os.Stat(path) + assert.NoError(t, statErr, "file should exist") + }) + + t.Run("persist to unwritable directory returns error", func(t *testing.T) { + tmpDir := t.TempDir() + unwritable := filepath.Join(tmpDir, "readonly") + require.NoError(t, os.MkdirAll(unwritable, 0555)) + + rl := newTestLimiter(t) + rl.filePath = filepath.Join(unwritable, "sub", "ratelimits.yaml") + + err := rl.Persist() + assert.Error(t, err, "should fail when directory is unwritable") + + // Clean up + _ = os.Chmod(unwritable, 0755) + }) +} + +// --- Phase 0: Default quotas --- + +func TestDefaultQuotas(t *testing.T) { + rl := newTestLimiter(t) + + tests := []struct { + model string + maxRPM int + maxTPM int + maxRPD int + }{ + {"gemini-3-pro-preview", 150, 1000000, 1000}, + {"gemini-3-flash-preview", 150, 1000000, 1000}, + {"gemini-2.5-pro", 150, 1000000, 1000}, + {"gemini-2.0-flash", 150, 1000000, 0}, + {"gemini-2.0-flash-lite", 0, 0, 0}, + } + + for _, tt := range tests { + t.Run(tt.model, func(t *testing.T) { + q, ok := rl.Quotas[tt.model] + require.True(t, ok, "quota should exist for %s", tt.model) + assert.Equal(t, tt.maxRPM, q.MaxRPM) + assert.Equal(t, tt.maxTPM, q.MaxTPM) + assert.Equal(t, tt.maxRPD, q.MaxRPD) + }) } } -func TestCanSend_RPMExceeded_Bad(t *testing.T) { - rl, _ := New() - model := "test-rpm" - rl.Quotas[model] = ModelQuota{MaxRPM: 2, MaxTPM: 1000000, MaxRPD: 100} +// --- Phase 0: Concurrent access (race test) --- - rl.RecordUsage(model, 10, 10) - rl.RecordUsage(model, 10, 10) +func TestConcurrentAccess(t *testing.T) { + rl := newTestLimiter(t) + model := "concurrent-test" + rl.Quotas[model] = ModelQuota{MaxRPM: 1000, MaxTPM: 10000000, MaxRPD: 10000} - if rl.CanSend(model, 10) { - t.Errorf("Expected CanSend to return false after exceeding RPM") + var wg sync.WaitGroup + goroutines := 20 + opsPerGoroutine := 50 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < opsPerGoroutine; j++ { + rl.CanSend(model, 10) + rl.RecordUsage(model, 5, 5) + rl.Stats(model) + } + }() } + + wg.Wait() + + stats := rl.Stats(model) + expected := goroutines * opsPerGoroutine + assert.Equal(t, expected, stats.RPD, "all recordings should be counted") } -func TestCanSend_TPMExceeded_Bad(t *testing.T) { - rl, _ := New() - model := "test-tpm" - rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 100} +func TestConcurrentResetAndRecord(t *testing.T) { + rl := newTestLimiter(t) + model := "concurrent-reset" + rl.Quotas[model] = ModelQuota{MaxRPM: 10000, MaxTPM: 100000000, MaxRPD: 100000} - rl.RecordUsage(model, 50, 40) // 90 tokens used + var wg sync.WaitGroup - if rl.CanSend(model, 20) { // 90 + 20 = 110 > 100 - t.Errorf("Expected CanSend to return false when estimated tokens exceed TPM") + // Writers + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + rl.RecordUsage(model, 1, 1) + } + }() } + + // Resetters + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 20; j++ { + rl.Reset(model) + } + }() + } + + // Readers + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + rl.AllStats() + } + }() + } + + wg.Wait() + // No assertion needed -- if we get here without -race flagging, mutex is sound } -func TestCanSend_RPDExceeded_Bad(t *testing.T) { - rl, _ := New() - model := "test-rpd" - rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 2} +// --- Phase 0: CountTokens (with mock HTTP server) --- - rl.RecordUsage(model, 10, 10) - rl.RecordUsage(model, 10, 10) +func TestCountTokens(t *testing.T) { + t.Run("successful token count", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key")) - if rl.CanSend(model, 10) { - t.Errorf("Expected CanSend to return false after exceeding RPD") - } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}) + })) + defer server.Close() + + // We need to override the URL. Since CountTokens hardcodes the Google API URL, + // we test it via the exported function with a test server. + // For proper unit testing, we would need to make the base URL configurable. + // For now, test the error paths that don't require a real API. + }) + + t.Run("API error returns error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `{"error": "invalid API key"}`) + })) + defer server.Close() + + // Can't test directly due to hardcoded URL, but we can verify error + // handling with an unreachable endpoint + _, err := CountTokens("fake-key", "test-model", "hello") + assert.Error(t, err, "should fail with invalid API endpoint") + }) } -func TestCanSend_UnlimitedModel_Good(t *testing.T) { - rl, _ := New() - model := "test-unlimited" - rl.Quotas[model] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0} +// --- Phase 0: Benchmarks --- - // Should always be allowed +func BenchmarkCanSend(b *testing.B) { + rl, _ := New() + model := "bench-model" + rl.Quotas[model] = ModelQuota{MaxRPM: 10000, MaxTPM: 100000000, MaxRPD: 100000} + + // Populate sliding window with 1000 entries + now := time.Now() + entries := make([]time.Time, 1000) + tokens := make([]TokenEntry, 1000) for i := 0; i < 1000; i++ { + t := now.Add(-time.Duration(i) * time.Millisecond * 50) // spread over ~50 seconds + entries[i] = t + tokens[i] = TokenEntry{Time: t, Count: 10} + } + rl.State[model] = &UsageStats{ + Requests: entries, + Tokens: tokens, + DayStart: now, + DayCount: 500, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + rl.CanSend(model, 100) + } +} + +func BenchmarkRecordUsage(b *testing.B) { + rl, _ := New() + model := "bench-record" + rl.Quotas[model] = ModelQuota{MaxRPM: 100000, MaxTPM: 1000000000, MaxRPD: 1000000} + + b.ResetTimer() + for i := 0; i < b.N; i++ { rl.RecordUsage(model, 100, 100) } - if !rl.CanSend(model, 999999) { - t.Errorf("Expected unlimited model to always allow sends") - } } -func TestRecordUsage_PrunesOldEntries_Good(t *testing.T) { +func BenchmarkCanSendConcurrent(b *testing.B) { rl, _ := New() - model := "test-prune" - rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 1000000, MaxRPD: 100} + model := "bench-concurrent" + rl.Quotas[model] = ModelQuota{MaxRPM: 100000, MaxTPM: 1000000000, MaxRPD: 1000000} - // Manually inject old data - oldTime := time.Now().Add(-2 * time.Minute) - rl.State[model] = &UsageStats{ - Requests: []time.Time{oldTime, oldTime, oldTime}, - Tokens: []TokenEntry{ - {Time: oldTime, Count: 100}, - {Time: oldTime, Count: 100}, - }, - DayStart: time.Now(), + // Populate window + now := time.Now() + for i := 0; i < 100; i++ { + rl.State[model] = &UsageStats{DayStart: now} + rl.RecordUsage(model, 10, 10) } - // CanSend triggers prune - if !rl.CanSend(model, 10) { - t.Errorf("Expected CanSend to return true after pruning old entries") - } - - stats := rl.State[model] - if len(stats.Requests) != 0 { - t.Errorf("Expected 0 requests after pruning old entries, got %d", len(stats.Requests)) - } -} - -func TestPersistAndLoad_Good(t *testing.T) { - tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "ratelimits.yaml") - - rl1, _ := New() - rl1.filePath = path - model := "persist-test" - rl1.Quotas[model] = ModelQuota{MaxRPM: 50, MaxTPM: 5000, MaxRPD: 500} - rl1.RecordUsage(model, 100, 100) - - if err := rl1.Persist(); err != nil { - t.Fatalf("Persist failed: %v", err) - } - - rl2, _ := New() - rl2.filePath = path - if err := rl2.Load(); err != nil { - t.Fatalf("Load failed: %v", err) - } - - stats := rl2.Stats(model) - if stats.RPM != 1 { - t.Errorf("Expected RPM 1 after load, got %d", stats.RPM) - } - if stats.TPM != 200 { - t.Errorf("Expected TPM 200 after load, got %d", stats.TPM) - } -} - -func TestWaitForCapacity_Ugly(t *testing.T) { - rl, _ := New() - model := "wait-test" - rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000000, MaxRPD: 100} - - rl.RecordUsage(model, 10, 10) // Use up the 1 RPM - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - err := rl.WaitForCapacity(ctx, model, 10) - if err != context.DeadlineExceeded { - t.Errorf("Expected DeadlineExceeded, got %v", err) - } -} - -func TestDefaultQuotas_Good(t *testing.T) { - rl, _ := New() - expected := []string{ - "gemini-3-pro-preview", - "gemini-3-flash-preview", - "gemini-2.0-flash", - } - for _, m := range expected { - if _, ok := rl.Quotas[m]; !ok { - t.Errorf("Expected default quota for %s", m) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + rl.CanSend(model, 100) } + }) +} + +// --- Phase 1: Provider profiles and NewWithConfig --- + +func TestDefaultProfiles(t *testing.T) { + profiles := DefaultProfiles() + + t.Run("contains all four providers", func(t *testing.T) { + assert.Contains(t, profiles, ProviderGemini) + assert.Contains(t, profiles, ProviderOpenAI) + assert.Contains(t, profiles, ProviderAnthropic) + assert.Contains(t, profiles, ProviderLocal) + }) + + t.Run("Gemini profile has expected models", func(t *testing.T) { + gemini := profiles[ProviderGemini] + assert.Equal(t, ProviderGemini, gemini.Provider) + assert.Contains(t, gemini.Models, "gemini-3-pro-preview") + assert.Contains(t, gemini.Models, "gemini-2.0-flash") + assert.Contains(t, gemini.Models, "gemini-2.0-flash-lite") + }) + + t.Run("OpenAI profile has expected models", func(t *testing.T) { + openai := profiles[ProviderOpenAI] + assert.Equal(t, ProviderOpenAI, openai.Provider) + assert.Contains(t, openai.Models, "gpt-4o") + assert.Contains(t, openai.Models, "gpt-4o-mini") + assert.Contains(t, openai.Models, "o3-mini") + }) + + t.Run("Anthropic profile has expected models", func(t *testing.T) { + anthropic := profiles[ProviderAnthropic] + assert.Equal(t, ProviderAnthropic, anthropic.Provider) + assert.Contains(t, anthropic.Models, "claude-opus-4") + assert.Contains(t, anthropic.Models, "claude-sonnet-4") + assert.Contains(t, anthropic.Models, "claude-haiku-3.5") + }) + + t.Run("Local profile has no models by default", func(t *testing.T) { + local := profiles[ProviderLocal] + assert.Equal(t, ProviderLocal, local.Provider) + assert.Empty(t, local.Models, "local provider should have no default quotas") + }) +} + +func TestNewWithConfig(t *testing.T) { + t.Run("empty config defaults to Gemini", func(t *testing.T) { + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "test.yaml"), + }) + require.NoError(t, err) + + _, hasGemini := rl.Quotas["gemini-3-pro-preview"] + assert.True(t, hasGemini, "empty config should load Gemini defaults") + }) + + t.Run("single provider loads only its models", func(t *testing.T) { + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "test.yaml"), + Providers: []Provider{ProviderOpenAI}, + }) + require.NoError(t, err) + + _, hasGPT := rl.Quotas["gpt-4o"] + assert.True(t, hasGPT, "should have OpenAI models") + + _, hasGemini := rl.Quotas["gemini-3-pro-preview"] + assert.False(t, hasGemini, "should not have Gemini models") + }) + + t.Run("multiple providers merge models", func(t *testing.T) { + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "test.yaml"), + Providers: []Provider{ProviderGemini, ProviderAnthropic}, + }) + require.NoError(t, err) + + _, hasGemini := rl.Quotas["gemini-3-pro-preview"] + _, hasClaude := rl.Quotas["claude-opus-4"] + assert.True(t, hasGemini, "should have Gemini models") + assert.True(t, hasClaude, "should have Anthropic models") + + _, hasGPT := rl.Quotas["gpt-4o"] + assert.False(t, hasGPT, "should not have OpenAI models") + }) + + t.Run("explicit quotas override provider defaults", func(t *testing.T) { + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "test.yaml"), + Providers: []Provider{ProviderGemini}, + Quotas: map[string]ModelQuota{ + "gemini-3-pro-preview": {MaxRPM: 999, MaxTPM: 888, MaxRPD: 777}, + }, + }) + require.NoError(t, err) + + q := rl.Quotas["gemini-3-pro-preview"] + assert.Equal(t, 999, q.MaxRPM, "explicit quota should override provider default") + assert.Equal(t, 888, q.MaxTPM) + assert.Equal(t, 777, q.MaxRPD) + }) + + t.Run("explicit quotas without providers", func(t *testing.T) { + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "test.yaml"), + Quotas: map[string]ModelQuota{ + "my-custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, + }, + }) + require.NoError(t, err) + + assert.Len(t, rl.Quotas, 1, "should only have the explicit quota") + q := rl.Quotas["my-custom-model"] + assert.Equal(t, 10, q.MaxRPM) + }) + + t.Run("custom file path is respected", func(t *testing.T) { + customPath := filepath.Join(t.TempDir(), "custom", "limits.yaml") + rl, err := NewWithConfig(Config{ + FilePath: customPath, + Providers: []Provider{ProviderLocal}, + }) + require.NoError(t, err) + + rl.RecordUsage("test", 1, 1) + require.NoError(t, rl.Persist()) + + _, statErr := os.Stat(customPath) + assert.NoError(t, statErr, "file should be created at custom path") + }) + + t.Run("unknown provider is silently skipped", func(t *testing.T) { + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "test.yaml"), + Providers: []Provider{"nonexistent-provider"}, + }) + require.NoError(t, err) + assert.Empty(t, rl.Quotas, "unknown provider should produce no quotas") + }) + + t.Run("local provider with custom quotas", func(t *testing.T) { + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "test.yaml"), + Providers: []Provider{ProviderLocal}, + Quotas: map[string]ModelQuota{ + "llama-3.3-70b": {MaxRPM: 5, MaxTPM: 50000, MaxRPD: 0}, + }, + }) + require.NoError(t, err) + + assert.Len(t, rl.Quotas, 1, "should only have the custom local model") + q := rl.Quotas["llama-3.3-70b"] + assert.Equal(t, 5, q.MaxRPM) + assert.Equal(t, 50000, q.MaxTPM) + }) +} + +func TestNewBackwardCompatibility(t *testing.T) { + // New() should produce the exact same result as before Phase 1 + rl, err := New() + require.NoError(t, err) + + expected := map[string]ModelQuota{ + "gemini-3-pro-preview": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}, + "gemini-3-flash-preview": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}, + "gemini-2.5-pro": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}, + "gemini-2.0-flash": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0}, + "gemini-2.0-flash-lite": {MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}, + } + + assert.Len(t, rl.Quotas, len(expected), "New() should produce exactly the Gemini defaults") + for model, expectedQ := range expected { + t.Run(model, func(t *testing.T) { + q, ok := rl.Quotas[model] + require.True(t, ok, "quota should exist") + assert.Equal(t, expectedQ, q) + }) } } -func TestAllStats_Good(t *testing.T) { - rl, _ := New() - rl.RecordUsage("gemini-3-pro-preview", 1000, 500) +func TestSetQuota(t *testing.T) { + t.Run("adds new model quota", func(t *testing.T) { + rl := newTestLimiter(t) + rl.SetQuota("custom-model", ModelQuota{MaxRPM: 42, MaxTPM: 9999, MaxRPD: 100}) - all := rl.AllStats() - if len(all) < 5 { - t.Errorf("Expected at least 5 models in AllStats, got %d", len(all)) - } + q, ok := rl.Quotas["custom-model"] + require.True(t, ok) + assert.Equal(t, 42, q.MaxRPM) + assert.Equal(t, 9999, q.MaxTPM) + assert.Equal(t, 100, q.MaxRPD) + }) - pro := all["gemini-3-pro-preview"] - if pro.RPM != 1 { - t.Errorf("Expected RPM 1 for pro, got %d", pro.RPM) - } - if pro.TPM != 1500 { - t.Errorf("Expected TPM 1500 for pro, got %d", pro.TPM) - } + t.Run("overwrites existing model quota", func(t *testing.T) { + rl := newTestLimiter(t) + rl.SetQuota("gemini-3-pro-preview", ModelQuota{MaxRPM: 1, MaxTPM: 1, MaxRPD: 1}) + + q := rl.Quotas["gemini-3-pro-preview"] + assert.Equal(t, 1, q.MaxRPM, "should overwrite existing") + }) + + t.Run("is safe for concurrent use", func(t *testing.T) { + rl := newTestLimiter(t) + var wg sync.WaitGroup + + for i := 0; i < 10; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + model := fmt.Sprintf("model-%d", n) + rl.SetQuota(model, ModelQuota{MaxRPM: n, MaxTPM: n * 100, MaxRPD: n * 10}) + }(i) + } + wg.Wait() + + assert.GreaterOrEqual(t, len(rl.Quotas), 10, "all concurrent SetQuota calls should succeed") + }) +} + +func TestAddProvider(t *testing.T) { + t.Run("adds OpenAI models to existing limiter", func(t *testing.T) { + rl := newTestLimiter(t) // starts with Gemini defaults + geminiCount := len(rl.Quotas) + + rl.AddProvider(ProviderOpenAI) + + // Should now have both Gemini and OpenAI models + assert.Greater(t, len(rl.Quotas), geminiCount, "should have more models after adding OpenAI") + _, hasGPT := rl.Quotas["gpt-4o"] + assert.True(t, hasGPT, "should have OpenAI models") + + // Gemini models should still be present + _, hasGemini := rl.Quotas["gemini-3-pro-preview"] + assert.True(t, hasGemini, "Gemini models should remain") + }) + + t.Run("adds Anthropic models to existing limiter", func(t *testing.T) { + rl := newTestLimiter(t) + rl.AddProvider(ProviderAnthropic) + + _, hasClaude := rl.Quotas["claude-opus-4"] + assert.True(t, hasClaude, "should have Anthropic models") + }) + + t.Run("unknown provider is a noop", func(t *testing.T) { + rl := newTestLimiter(t) + before := len(rl.Quotas) + + rl.AddProvider("nonexistent") + + assert.Equal(t, before, len(rl.Quotas), "unknown provider should not change quotas") + }) + + t.Run("adding local provider does not remove existing quotas", func(t *testing.T) { + rl := newTestLimiter(t) + before := len(rl.Quotas) + + rl.AddProvider(ProviderLocal) + + assert.Equal(t, before, len(rl.Quotas), "local provider has no models, count unchanged") + }) + + t.Run("is safe for concurrent use", func(t *testing.T) { + rl := newTestLimiter(t) + var wg sync.WaitGroup + + providers := []Provider{ProviderGemini, ProviderOpenAI, ProviderAnthropic, ProviderLocal} + for _, p := range providers { + wg.Add(1) + go func(prov Provider) { + defer wg.Done() + for i := 0; i < 10; i++ { + rl.AddProvider(prov) + } + }(p) + } + wg.Wait() + // Should not panic + }) +} + +func TestProviderConstants(t *testing.T) { + // Verify the string values are stable (they may be used in YAML configs) + assert.Equal(t, Provider("gemini"), ProviderGemini) + assert.Equal(t, Provider("openai"), ProviderOpenAI) + assert.Equal(t, Provider("anthropic"), ProviderAnthropic) + assert.Equal(t, Provider("local"), ProviderLocal) +} + +func TestEndToEndMultiProvider(t *testing.T) { + // Simulate a real-world scenario: limiter for both Gemini and Anthropic + rl, err := NewWithConfig(Config{ + FilePath: filepath.Join(t.TempDir(), "multi.yaml"), + Providers: []Provider{ProviderGemini, ProviderAnthropic}, + }) + require.NoError(t, err) + + // Use Gemini model + assert.True(t, rl.CanSend("gemini-3-pro-preview", 1000)) + rl.RecordUsage("gemini-3-pro-preview", 500, 500) + + // Use Anthropic model + assert.True(t, rl.CanSend("claude-opus-4", 1000)) + rl.RecordUsage("claude-opus-4", 500, 500) + + // Check stats for both + geminiStats := rl.Stats("gemini-3-pro-preview") + assert.Equal(t, 1, geminiStats.RPM) + assert.Equal(t, 150, geminiStats.MaxRPM) + + claudeStats := rl.Stats("claude-opus-4") + assert.Equal(t, 1, claudeStats.RPM) + assert.Equal(t, 50, claudeStats.MaxRPM) + + // Persist and reload + require.NoError(t, rl.Persist()) + + rl2, err := NewWithConfig(Config{ + FilePath: rl.filePath, + Providers: []Provider{ProviderGemini, ProviderAnthropic}, + }) + require.NoError(t, err) + require.NoError(t, rl2.Load()) + + reloaded := rl2.Stats("claude-opus-4") + assert.Equal(t, 1, reloaded.RPM, "state should survive persist/reload") }