fix(ratelimit): harden audit edge cases

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-23 07:26:15 +00:00
parent 0db35c4ce9
commit d1c90b937d
5 changed files with 715 additions and 133 deletions

View file

@ -31,13 +31,19 @@ func TestIterators(t *testing.T) {
assert.Contains(t, models, "model-a") assert.Contains(t, models, "model-a")
assert.Contains(t, models, "model-b") assert.Contains(t, models, "model-b")
assert.Contains(t, models, "model-c") assert.Contains(t, models, "model-c")
// Check sorting of our specific models // Check sorting of our specific models
foundA, foundB, foundC := -1, -1, -1 foundA, foundB, foundC := -1, -1, -1
for i, m := range models { for i, m := range models {
if m == "model-a" { foundA = i } if m == "model-a" {
if m == "model-b" { foundB = i } foundA = i
if m == "model-c" { foundC = i } }
if m == "model-b" {
foundB = i
}
if m == "model-c" {
foundC = i
}
} }
assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c") assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c")
}) })
@ -57,9 +63,15 @@ func TestIterators(t *testing.T) {
// Check sorting // Check sorting
foundA, foundB, foundC := -1, -1, -1 foundA, foundB, foundC := -1, -1, -1
for i, m := range models { for i, m := range models {
if m == "model-a" { foundA = i } if m == "model-a" {
if m == "model-b" { foundB = i } foundA = i
if m == "model-c" { foundC = i } }
if m == "model-b" {
foundB = i
}
if m == "model-c" {
foundC = i
}
} }
assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c") assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c")
}) })
@ -99,9 +111,8 @@ func TestIterEarlyBreak(t *testing.T) {
} }
func TestCountTokensFull(t *testing.T) { func TestCountTokensFull(t *testing.T) {
t.Run("invalid URL/network error", func(t *testing.T) { t.Run("empty model is rejected", func(t *testing.T) {
// Using an invalid character in model name to trigger URL error or similar _, err := CountTokens(context.Background(), "key", "", "text")
_, err := CountTokens(context.Background(), "key", "invalid model", "text")
assert.Error(t, err) assert.Error(t, err)
}) })
@ -112,15 +123,15 @@ func TestCountTokensFull(t *testing.T) {
})) }))
defer server.Close() defer server.Close()
// We can't easily override the URL in CountTokens without changing the code, _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "key", "model", "text")
// but we can test the logic if we make it slightly more testable. require.Error(t, err)
// For now, I've already updated ratelimit_test.go with some of this. assert.Contains(t, err.Error(), "status 400")
}) })
t.Run("context cancelled", func(t *testing.T) { t.Run("context cancelled", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
cancel() cancel()
_, err := CountTokens(ctx, "key", "model", "text") _, err := countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", "key", "model", "text")
assert.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "do request") assert.Contains(t, err.Error(), "do request")
}) })

View file

@ -9,9 +9,11 @@ import (
"iter" "iter"
"maps" "maps"
"net/http" "net/http"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strings"
"sync" "sync"
"time" "time"
@ -34,6 +36,16 @@ const (
ProviderLocal Provider = "local" ProviderLocal Provider = "local"
) )
const (
	// defaultStateDirName is the directory under the user's home
	// directory where rate-limit state files are stored.
	defaultStateDirName = ".core"
	// defaultYAMLStateFile and defaultSQLiteStateFile are the default
	// state file names for the YAML and SQLite backends respectively.
	defaultYAMLStateFile   = "ratelimits.yaml"
	defaultSQLiteStateFile = "ratelimits.db"
	// backendYAML and backendSQLite are the recognised backend
	// identifiers accepted in Config.Backend (after normalisation).
	backendYAML   = "yaml"
	backendSQLite = "sqlite"
	// countTokensErrorBodyLimit caps how many bytes of an API error body
	// are read into the returned error message.
	countTokensErrorBodyLimit = 8 * 1024
	// countTokensSuccessBodyLimit caps how many bytes of a successful
	// countTokens response body are decoded.
	countTokensSuccessBodyLimit = 1 * 1024 * 1024
)
// ModelQuota defines the rate limits for a specific model. // ModelQuota defines the rate limits for a specific model.
type ModelQuota struct { type ModelQuota struct {
MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited) MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
@ -143,39 +155,30 @@ func New() (*RateLimiter, error) {
// NewWithConfig creates a RateLimiter from explicit configuration. // NewWithConfig creates a RateLimiter from explicit configuration.
// If no providers or quotas are specified, Gemini defaults are used. // If no providers or quotas are specified, Gemini defaults are used.
func NewWithConfig(cfg Config) (*RateLimiter, error) { func NewWithConfig(cfg Config) (*RateLimiter, error) {
backend, err := normaliseBackend(cfg.Backend)
if err != nil {
return nil, err
}
filePath := cfg.FilePath filePath := cfg.FilePath
if filePath == "" { if filePath == "" {
home, err := os.UserHomeDir() filePath, err = defaultStatePath(backend)
if err != nil { if err != nil {
return nil, err return nil, err
} }
filePath = filepath.Join(home, ".core", "ratelimits.yaml")
} }
rl := &RateLimiter{ if backend == backendSQLite {
Quotas: make(map[string]ModelQuota), if cfg.FilePath == "" {
State: make(map[string]*UsageStats), if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
filePath: filePath, return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err)
} }
// Load provider profiles
profiles := DefaultProfiles()
providers := cfg.Providers
// If nothing specified at all, default to Gemini
if len(providers) == 0 && len(cfg.Quotas) == 0 {
providers = []Provider{ProviderGemini}
}
for _, p := range providers {
if profile, ok := profiles[p]; ok {
maps.Copy(rl.Quotas, profile.Models)
} }
return NewWithSQLiteConfig(filePath, cfg)
} }
// Merge explicit quotas on top (allows overrides) rl := newConfiguredRateLimiter(cfg)
maps.Copy(rl.Quotas, cfg.Quotas) rl.filePath = filePath
return rl, nil return rl, nil
} }
@ -225,15 +228,17 @@ func (rl *RateLimiter) loadSQLite() error {
if err != nil { if err != nil {
return err return err
} }
// Merge loaded quotas (loaded quotas override in-memory defaults). // Persisted quotas are authoritative when present; otherwise keep config defaults.
maps.Copy(rl.Quotas, quotas) if len(quotas) > 0 {
rl.Quotas = maps.Clone(quotas)
}
state, err := rl.sqlite.loadState() state, err := rl.sqlite.loadState()
if err != nil { if err != nil {
return err return err
} }
// Replace in-memory state with persisted state. // Replace in-memory state with the persisted snapshot.
maps.Copy(rl.State, state) rl.State = state
return nil return nil
} }
@ -244,6 +249,9 @@ func (rl *RateLimiter) Persist() error {
quotas := maps.Clone(rl.Quotas) quotas := maps.Clone(rl.Quotas)
state := make(map[string]*UsageStats, len(rl.State)) state := make(map[string]*UsageStats, len(rl.State))
for k, v := range rl.State { for k, v := range rl.State {
if v == nil {
continue
}
state[k] = &UsageStats{ state[k] = &UsageStats{
Requests: slices.Clone(v.Requests), Requests: slices.Clone(v.Requests),
Tokens: slices.Clone(v.Tokens), Tokens: slices.Clone(v.Tokens),
@ -256,11 +264,8 @@ func (rl *RateLimiter) Persist() error {
rl.mu.Unlock() rl.mu.Unlock()
if sqlite != nil { if sqlite != nil {
if err := sqlite.saveQuotas(quotas); err != nil { if err := sqlite.saveSnapshot(quotas, state); err != nil {
return coreerr.E("ratelimit.Persist", "sqlite quotas", err) return coreerr.E("ratelimit.Persist", "sqlite snapshot", err)
}
if err := sqlite.saveState(state); err != nil {
return coreerr.E("ratelimit.Persist", "sqlite state", err)
} }
return nil return nil
} }
@ -292,6 +297,10 @@ func (rl *RateLimiter) prune(model string) {
if !ok { if !ok {
return return
} }
if stats == nil {
delete(rl.State, model)
return
}
now := time.Now() now := time.Now()
window := now.Add(-1 * time.Minute) window := now.Add(-1 * time.Minute)
@ -326,6 +335,10 @@ func (rl *RateLimiter) prune(model string) {
// BackgroundPrune starts a goroutine that periodically prunes all model states. // BackgroundPrune starts a goroutine that periodically prunes all model states.
// It returns a function to stop the pruner. // It returns a function to stop the pruner.
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
if interval <= 0 {
return func() {}
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
go func() { go func() {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)
@ -348,6 +361,10 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
// CanSend checks if a request can be sent without violating limits. // CanSend checks if a request can be sent without violating limits.
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
if estimatedTokens < 0 {
return false
}
rl.mu.Lock() rl.mu.Lock()
defer rl.mu.Unlock() defer rl.mu.Unlock()
@ -380,11 +397,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
// Check TPM // Check TPM
if quota.MaxTPM > 0 { if quota.MaxTPM > 0 {
currentTokens := 0 currentTokens := totalTokenCount(stats.Tokens)
for _, t := range stats.Tokens { if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens {
currentTokens += t.Count
}
if currentTokens+estimatedTokens > quota.MaxTPM {
return false return false
} }
} }
@ -405,13 +419,18 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
} }
now := time.Now() now := time.Now()
totalTokens := safeTokenSum(promptTokens, outputTokens)
stats.Requests = append(stats.Requests, now) stats.Requests = append(stats.Requests, now)
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens}) stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: totalTokens})
stats.DayCount++ stats.DayCount++
} }
// WaitForCapacity blocks until capacity is available or context is cancelled. // WaitForCapacity blocks until capacity is available or context is cancelled.
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error { func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
if tokens < 0 {
return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil)
}
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@ -522,24 +541,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
result[m] = ModelStats{} result[m] = ModelStats{}
} }
now := time.Now()
window := now.Add(-1 * time.Minute)
for m := range result { for m := range result {
// Prune inline rl.prune(m)
if s, ok := rl.State[m]; ok {
s.Requests = slices.DeleteFunc(s.Requests, func(t time.Time) bool {
return !t.After(window)
})
s.Tokens = slices.DeleteFunc(s.Tokens, func(t TokenEntry) bool {
return !t.Time.After(window)
})
if now.Sub(s.DayStart) >= 24*time.Hour {
s.DayStart = now
s.DayCount = 0
}
}
ms := ModelStats{} ms := ModelStats{}
if q, ok := rl.Quotas[m]; ok { if q, ok := rl.Quotas[m]; ok {
@ -547,13 +550,11 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
ms.MaxTPM = q.MaxTPM ms.MaxTPM = q.MaxTPM
ms.MaxRPD = q.MaxRPD ms.MaxRPD = q.MaxRPD
} }
if s, ok := rl.State[m]; ok { if s, ok := rl.State[m]; ok && s != nil {
ms.RPM = len(s.Requests) ms.RPM = len(s.Requests)
ms.RPD = s.DayCount ms.RPD = s.DayCount
ms.DayStart = s.DayStart ms.DayStart = s.DayStart
for _, t := range s.Tokens { ms.TPM = totalTokenCount(s.Tokens)
ms.TPM += t.Count
}
} }
result[m] = ms result[m] = ms
} }
@ -579,25 +580,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
return nil, err return nil, err
} }
rl := &RateLimiter{ rl := newConfiguredRateLimiter(cfg)
Quotas: make(map[string]ModelQuota), rl.sqlite = store
State: make(map[string]*UsageStats),
sqlite: store,
}
// Load provider profiles.
profiles := DefaultProfiles()
providers := cfg.Providers
if len(providers) == 0 && len(cfg.Quotas) == 0 {
providers = []Provider{ProviderGemini}
}
for _, p := range providers {
if profile, ok := profiles[p]; ok {
maps.Copy(rl.Quotas, profile.Models)
}
}
maps.Copy(rl.Quotas, cfg.Quotas)
return rl, nil return rl, nil
} }
@ -633,22 +617,22 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
} }
defer store.close() defer store.close()
if rl.Quotas != nil { if err := store.saveSnapshot(rl.Quotas, rl.State); err != nil {
if err := store.saveQuotas(rl.Quotas); err != nil { return err
return err
}
}
if rl.State != nil {
if err := store.saveState(rl.State); err != nil {
return err
}
} }
return nil return nil
} }
// CountTokens calls the Google API to count tokens for a prompt. // CountTokens calls the Google API to count tokens for a prompt.
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) { func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model) return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text)
}
func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) {
requestURL, err := countTokensURL(baseURL, model)
if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "build url", err)
}
reqBody := map[string]any{ reqBody := map[string]any{
"contents": []any{ "contents": []any{
@ -665,30 +649,147 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err) return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err)
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody))
if err != nil { if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "new request", err) return 0, coreerr.E("ratelimit.CountTokens", "new request", err)
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-goog-api-key", apiKey) req.Header.Set("x-goog-api-key", apiKey)
resp, err := http.DefaultClient.Do(req) if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil { if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "do request", err) return 0, coreerr.E("ratelimit.CountTokens", "do request", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit)
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("API error (status %d): %s", resp.StatusCode, string(body)), nil) if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "read error body", err)
}
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil)
} }
var result struct { var result struct {
TotalTokens int `json:"totalTokens"` TotalTokens int `json:"totalTokens"`
} }
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "decode response", err) return 0, coreerr.E("ratelimit.CountTokens", "decode response", err)
} }
return result.TotalTokens, nil return result.TotalTokens, nil
} }
// newConfiguredRateLimiter builds a RateLimiter with fresh quota and state
// maps and applies the provider/quota configuration from cfg.
func newConfiguredRateLimiter(cfg Config) *RateLimiter {
	limiter := &RateLimiter{
		Quotas: map[string]ModelQuota{},
		State:  map[string]*UsageStats{},
	}
	applyConfig(limiter, cfg)
	return limiter
}
// applyConfig merges provider profile quotas and explicit config quotas
// into rl.Quotas. Explicit quotas override profile defaults.
func applyConfig(rl *RateLimiter, cfg Config) {
	selected := cfg.Providers
	// With no explicit providers and no explicit quotas, fall back to Gemini.
	if len(selected) == 0 && len(cfg.Quotas) == 0 {
		selected = []Provider{ProviderGemini}
	}
	known := DefaultProfiles()
	for _, provider := range selected {
		profile, ok := known[provider]
		if !ok {
			continue
		}
		maps.Copy(rl.Quotas, profile.Models)
	}
	// Explicit quotas win over anything loaded from a profile.
	maps.Copy(rl.Quotas, cfg.Quotas)
}
// normaliseBackend canonicalises a user-supplied backend name. An empty
// string defaults to the YAML backend; unknown names are an error.
func normaliseBackend(backend string) (string, error) {
	normalised := strings.ToLower(strings.TrimSpace(backend))
	if normalised == "" || normalised == backendYAML {
		return backendYAML, nil
	}
	if normalised == backendSQLite {
		return backendSQLite, nil
	}
	return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil)
}
func defaultStatePath(backend string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
fileName := defaultYAMLStateFile
if backend == backendSQLite {
fileName = defaultSQLiteStateFile
}
return filepath.Join(home, defaultStateDirName, fileName), nil
}
// safeTokenSum adds two token counts using the same negative-clamping and
// overflow-saturation rules as safeTokenTotal.
func safeTokenSum(a, b int) int {
	return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}})
}

// totalTokenCount reports the saturated sum of all entry counts.
func totalTokenCount(tokens []TokenEntry) int {
	return safeTokenTotal(tokens)
}

// safeTokenTotal sums entry counts, skipping negative values and saturating
// at the maximum int instead of wrapping around on overflow.
func safeTokenTotal(tokens []TokenEntry) int {
	const maxInt = int(^uint(0) >> 1)
	sum := 0
	for i := range tokens {
		c := tokens[i].Count
		if c < 0 {
			continue // negative counts contribute nothing
		}
		if c > maxInt-sum {
			return maxInt // saturate rather than overflow
		}
		sum += c
	}
	return sum
}
// countTokensURL builds the countTokens endpoint URL for a model, rejecting
// blank model names and base URLs without a scheme and host. The model name
// is path-escaped so separators cannot alter the request path.
func countTokensURL(baseURL, model string) (string, error) {
	if strings.TrimSpace(model) == "" {
		return "", fmt.Errorf("empty model")
	}
	base, err := url.Parse(baseURL)
	if err != nil {
		return "", err
	}
	if base.Scheme == "" || base.Host == "" {
		return "", fmt.Errorf("invalid base url")
	}
	var b strings.Builder
	b.WriteString(strings.TrimRight(base.String(), "/"))
	b.WriteString("/v1beta/models/")
	b.WriteString(url.PathEscape(model))
	b.WriteString(":countTokens")
	return b.String(), nil
}
// readLimitedBody reads at most limit bytes from r and returns them as a
// string. If the input exceeds the limit, the result is truncated to limit
// bytes and "..." is appended to signal truncation.
func readLimitedBody(r io.Reader, limit int64) (string, error) {
	// Read one extra byte so truncation can be detected without
	// consuming an unbounded body.
	data, err := io.ReadAll(io.LimitReader(r, limit+1))
	if err != nil {
		return "", err
	}
	if int64(len(data)) <= limit {
		return string(data), nil
	}
	return string(data[:limit]) + "...", nil
}

View file

@ -4,10 +4,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -25,6 +27,18 @@ func newTestLimiter(t *testing.T) *RateLimiter {
return rl return rl
} }
// roundTripFunc adapts a plain function to the http.RoundTripper interface,
// letting tests stub transport behaviour without starting a real server.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by delegating to the function itself.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return f(req)
}

// errReader is an io.Reader that always fails, used to exercise
// response-body read error paths.
type errReader struct{}

// Read implements io.Reader and unconditionally returns io.ErrUnexpectedEOF.
func (errReader) Read([]byte) (int, error) {
	return 0, io.ErrUnexpectedEOF
}
// --- Phase 0: CanSend boundary conditions --- // --- Phase 0: CanSend boundary conditions ---
func TestCanSend(t *testing.T) { func TestCanSend(t *testing.T) {
@ -162,6 +176,15 @@ func TestCanSend(t *testing.T) {
rl.RecordUsage(model, 50, 50) // exactly 100 tokens rl.RecordUsage(model, 50, 50) // exactly 100 tokens
assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit") assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit")
}) })
t.Run("negative estimated tokens are rejected", func(t *testing.T) {
rl := newTestLimiter(t)
model := "negative-est"
rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 100}
rl.RecordUsage(model, 50, 50)
assert.False(t, rl.CanSend(model, -1), "negative estimated tokens should not bypass TPM limits")
})
} }
// --- Phase 0: Sliding window / prune tests --- // --- Phase 0: Sliding window / prune tests ---
@ -335,6 +358,19 @@ func TestRecordUsage(t *testing.T) {
assert.Len(t, stats.Tokens, 2) assert.Len(t, stats.Tokens, 2)
assert.Equal(t, 6, stats.DayCount) assert.Equal(t, 6, stats.DayCount)
}) })
t.Run("negative token inputs are clamped to zero", func(t *testing.T) {
rl := newTestLimiter(t)
model := "record-negative"
rl.RecordUsage(model, -100, 25)
stats := rl.State[model]
require.NotNil(t, stats)
assert.Len(t, stats.Requests, 1)
assert.Len(t, stats.Tokens, 1)
assert.Equal(t, 25, stats.Tokens[0].Count)
})
} }
// --- Phase 0: Reset --- // --- Phase 0: Reset ---
@ -422,6 +458,58 @@ func TestWaitForCapacity(t *testing.T) {
err := rl.WaitForCapacity(ctx, model, 10) err := rl.WaitForCapacity(ctx, model, 10)
assert.Error(t, err, "should return error for already-cancelled context") assert.Error(t, err, "should return error for already-cancelled context")
}) })
t.Run("negative tokens return error", func(t *testing.T) {
rl := newTestLimiter(t)
err := rl.WaitForCapacity(context.Background(), "wait-negative", -1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "negative tokens")
})
}
func TestNilUsageStats(t *testing.T) {
t.Run("CanSend replaces nil state without panicking", func(t *testing.T) {
rl := newTestLimiter(t)
model := "nil-cansend"
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 10}
rl.State[model] = nil
assert.NotPanics(t, func() {
assert.True(t, rl.CanSend(model, 10))
})
assert.NotNil(t, rl.State[model])
})
t.Run("RecordUsage replaces nil state without panicking", func(t *testing.T) {
rl := newTestLimiter(t)
model := "nil-record"
rl.State[model] = nil
assert.NotPanics(t, func() {
rl.RecordUsage(model, 10, 10)
})
require.NotNil(t, rl.State[model])
assert.Equal(t, 1, rl.State[model].DayCount)
})
t.Run("Stats and AllStats tolerate nil state entries", func(t *testing.T) {
rl := newTestLimiter(t)
rl.Quotas["nil-stats"] = ModelQuota{MaxRPM: 1, MaxTPM: 2, MaxRPD: 3}
rl.State["nil-stats"] = nil
rl.State["nil-all-stats"] = nil
assert.NotPanics(t, func() {
stats := rl.Stats("nil-stats")
assert.Equal(t, 1, stats.MaxRPM)
assert.Equal(t, 0, stats.TPM)
})
assert.NotPanics(t, func() {
all := rl.AllStats()
assert.Contains(t, all, "nil-stats")
assert.Contains(t, all, "nil-all-stats")
})
})
} }
// --- Phase 0: Stats --- // --- Phase 0: Stats ---
@ -738,6 +826,19 @@ func TestBackgroundPrune(t *testing.T) {
_, exists := rl.State[model] _, exists := rl.State[model]
return !exists return !exists
}, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned") }, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned")
t.Run("non-positive interval is a safe no-op", func(t *testing.T) {
rl := newTestLimiter(t)
rl.State["still-here"] = &UsageStats{
Requests: []time.Time{time.Now().Add(-2 * time.Minute)},
}
assert.NotPanics(t, func() {
stop := rl.BackgroundPrune(0)
stop()
})
assert.Contains(t, rl.State, "still-here")
})
} }
// --- Phase 0: CountTokens (with mock HTTP server) --- // --- Phase 0: CountTokens (with mock HTTP server) ---
@ -748,26 +849,149 @@ func TestCountTokens(t *testing.T) {
assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "application/json", r.Header.Get("Content-Type")) assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key")) assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key"))
assert.Equal(t, "/v1beta/models/test-model:countTokens", r.URL.EscapedPath())
var body struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"contents"`
}
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
require.Len(t, body.Contents, 1)
require.Len(t, body.Contents[0].Parts, 1)
assert.Equal(t, "hello", body.Contents[0].Parts[0].Text)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}) require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}))
})) }))
defer server.Close() defer server.Close()
// For testing purposes, we would need to make the base URL configurable. tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "test-model", "hello")
// Since we're just checking the signature and basic logic, we test the error paths. require.NoError(t, err)
assert.Equal(t, 42, tokens)
}) })
t.Run("API error returns error", func(t *testing.T) { t.Run("model name is path escaped", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath())
assert.Empty(t, r.URL.RawQuery)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7}))
}))
defer server.Close()
tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "folder/model?debug=1", "hello")
require.NoError(t, err)
assert.Equal(t, 7, tokens)
})
t.Run("API error body is truncated", func(t *testing.T) {
largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(w, `{"error": "invalid API key"}`) _, err := fmt.Fprint(w, largeBody)
require.NoError(t, err)
})) }))
defer server.Close() defer server.Close()
_, err := CountTokens(context.Background(), "fake-key", "test-model", "hello") _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
assert.Error(t, err, "should fail with invalid API endpoint") require.Error(t, err)
assert.Contains(t, err.Error(), "api error status 401")
assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded")
assert.Contains(t, err.Error(), "...")
}) })
t.Run("empty model is rejected before request", func(t *testing.T) {
_, err := CountTokens(context.Background(), "fake-key", "", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "build url")
})
t.Run("invalid base URL returns error", func(t *testing.T) {
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "://bad-url", "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "build url")
})
t.Run("base URL without host returns error", func(t *testing.T) {
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "/relative", "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "build url")
})
t.Run("invalid JSON response returns error", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write([]byte(`{"totalTokens":`))
require.NoError(t, err)
}))
defer server.Close()
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "decode response")
})
t.Run("error body read failures are returned", func(t *testing.T) {
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: io.NopCloser(errReader{}),
Header: make(http.Header),
}, nil
}),
}
_, err := countTokensWithClient(context.Background(), client, "https://generativelanguage.googleapis.com", "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "read error body")
})
t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11}))
}))
defer server.Close()
originalClient := http.DefaultClient
http.DefaultClient = server.Client()
defer func() {
http.DefaultClient = originalClient
}()
tokens, err := countTokensWithClient(context.Background(), nil, server.URL, "fake-key", "test-model", "hello")
require.NoError(t, err)
assert.Equal(t, 11, tokens)
})
}
// TestPersistSkipsNilState verifies a nil *UsageStats entry does not
// survive a persist/load round trip through the YAML backend.
func TestPersistSkipsNilState(t *testing.T) {
	statePath := filepath.Join(t.TempDir(), "nil-state.yaml")

	writer, err := New()
	require.NoError(t, err)
	writer.filePath = statePath
	writer.State["nil-model"] = nil
	require.NoError(t, writer.Persist())

	reader, err := New()
	require.NoError(t, err)
	reader.filePath = statePath
	require.NoError(t, reader.Load())
	assert.NotContains(t, reader.State, "nil-model")
}
func TestTokenTotals(t *testing.T) {
maxInt := int(^uint(0) >> 1)
assert.Equal(t, 25, safeTokenSum(-100, 25))
assert.Equal(t, maxInt, safeTokenSum(maxInt, 1))
assert.Equal(t, 10, totalTokenCount([]TokenEntry{{Count: -5}, {Count: 10}}))
} }
// --- Phase 0: Benchmarks --- // --- Phase 0: Benchmarks ---
@ -983,6 +1207,25 @@ func TestNewWithConfig(t *testing.T) {
assert.Equal(t, 5, q.MaxRPM) assert.Equal(t, 5, q.MaxRPM)
assert.Equal(t, 50000, q.MaxTPM) assert.Equal(t, 50000, q.MaxTPM)
}) })
t.Run("invalid backend returns error", func(t *testing.T) {
_, err := NewWithConfig(Config{
Backend: "bogus",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown backend")
})
t.Run("default YAML path uses home directory", func(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", "")
t.Setenv("home", "")
rl, err := NewWithConfig(Config{})
require.NoError(t, err)
assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
})
} }
func TestNewBackwardCompatibility(t *testing.T) { func TestNewBackwardCompatibility(t *testing.T) {

View file

@ -78,7 +78,7 @@ func createSchema(db *sql.DB) error {
return nil return nil
} }
// saveQuotas upserts all quotas into the quotas table. // saveQuotas writes a complete quota snapshot to the quotas table.
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error { func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
@ -86,24 +86,15 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
} }
defer tx.Rollback() defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
VALUES (?, ?, ?, ?) return coreerr.E("ratelimit.saveQuotas", "clear", err)
ON CONFLICT(model) DO UPDATE SET
max_rpm = excluded.max_rpm,
max_tpm = excluded.max_tpm,
max_rpd = excluded.max_rpd`)
if err != nil {
return coreerr.E("ratelimit.saveQuotas", "prepare", err)
}
defer stmt.Close()
for model, q := range quotas {
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
}
} }
return tx.Commit() if err := insertQuotas(tx, quotas); err != nil {
return err
}
return commitTx(tx, "ratelimit.saveQuotas")
} }
// loadQuotas reads all rows from the quotas table. // loadQuotas reads all rows from the quotas table.
@ -129,6 +120,28 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
return result, nil return result, nil
} }
// saveSnapshot writes quotas and state as a single atomic snapshot: all
// tables are cleared and repopulated inside one transaction, so readers
// never observe a partially written snapshot.
func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error {
	tx, err := s.db.Begin()
	if err != nil {
		return coreerr.E("ratelimit.saveSnapshot", "begin", err)
	}
	defer tx.Rollback()

	// Clear, then repopulate, in order; any failure rolls the whole
	// transaction back via the deferred Rollback.
	steps := []func() error{
		func() error { return clearSnapshotTables(tx, true) },
		func() error { return insertQuotas(tx, quotas) },
		func() error { return insertState(tx, state) },
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return commitTx(tx, "ratelimit.saveSnapshot")
}
// saveState writes all usage state to SQLite in a single transaction. // saveState writes all usage state to SQLite in a single transaction.
// It uses a truncate-and-insert approach for simplicity in this version, // It uses a truncate-and-insert approach for simplicity in this version,
// but ensures atomicity via a single transaction. // but ensures atomicity via a single transaction.
@ -139,7 +152,23 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
} }
defer tx.Rollback() defer tx.Rollback()
// Clear existing state in the transaction. if err := clearSnapshotTables(tx, false); err != nil {
return err
}
if err := insertState(tx, state); err != nil {
return err
}
return commitTx(tx, "ratelimit.saveState")
}
// clearSnapshotTables deletes persisted rate-limit rows inside the
// caller's transaction. When includeQuotas is true (the saveSnapshot
// path) the quotas table is wiped as well; otherwise only the usage
// tables are cleared, as saveState does. Commit/rollback is left to
// the caller.
func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
	if includeQuotas {
		if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
			return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err)
		}
	}
	// NOTE(review): the lines below are a side-by-side diff rendering —
	// old and new columns share each line and a hunk header hides some
	// intermediate lines (presumably the tokens-table delete; confirm in
	// the actual source). The code clears the requests and daily tables.
	if _, err := tx.Exec("DELETE FROM requests"); err != nil { if _, err := tx.Exec("DELETE FROM requests"); err != nil {
	return coreerr.E("ratelimit.saveState", "clear requests", err) return coreerr.E("ratelimit.saveState", "clear requests", err)
	} }
	@ -149,7 +178,25 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
	if _, err := tx.Exec("DELETE FROM daily"); err != nil { if _, err := tx.Exec("DELETE FROM daily"); err != nil {
	return coreerr.E("ratelimit.saveState", "clear daily", err) return coreerr.E("ratelimit.saveState", "clear daily", err)
	} }
	return nil
}
// insertQuotas writes every model quota into the quotas table through a
// single prepared statement. It assumes the table has already been
// cleared and leaves commit/rollback to the caller's transaction.
func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
	ins, prepErr := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)")
	if prepErr != nil {
		return coreerr.E("ratelimit.saveQuotas", "prepare", prepErr)
	}
	defer ins.Close()

	for name, quota := range quotas {
		// One row per model; map iteration order is irrelevant here.
		_, execErr := ins.Exec(name, quota.MaxRPM, quota.MaxTPM, quota.MaxRPD)
		if execErr != nil {
			return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", name), execErr)
		}
	}
	return nil
}
// insertState writes per-model usage state (request timestamps as
// UnixNano, token entries, daily counters) into the usage tables via
// prepared statements. Nil stats entries are skipped. The caller owns
// the transaction and its commit/rollback.
func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
	// NOTE(review): this span is a side-by-side diff rendering — old and
	// new columns share each line, and the hunk headers hide some lines
	// (presumably the tokens-statement prepare and the tokens loop;
	// confirm in the actual source).
	reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)") reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
	if err != nil { if err != nil {
	return coreerr.E("ratelimit.saveState", "prepare requests", err) return coreerr.E("ratelimit.saveState", "prepare requests", err)
	@ -169,6 +216,9 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
	defer dayStmt.Close() defer dayStmt.Close()
	for model, stats := range state { for model, stats := range state {
	// Guard added by this change: a nil *UsageStats entry would
	// otherwise panic on stats.Requests below.
	if stats == nil {
	continue
	}
	for _, t := range stats.Requests { for _, t := range stats.Requests {
	if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil { if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
	return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err) return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err)
	@ -183,9 +233,12 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
	return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err) return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err)
	} }
	} }
	return nil
}
// commitTx commits the transaction, wrapping any failure in the
// caller-supplied error scope (e.g. "ratelimit.saveState" or
// "ratelimit.saveSnapshot").
func commitTx(tx *sql.Tx, scope string) error {
	// NOTE(review): side-by-side diff rendering below — the left column
	// is the pre-refactor hard-coded scope, the right column the new
	// parameterized code.
	if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
	return coreerr.E("ratelimit.saveState", "commit", err) return coreerr.E(scope, "commit", err)
	} }
	return nil return nil
} }

View file

@ -435,6 +435,45 @@ func TestConfigBackendDefault_Good(t *testing.T) {
assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)") assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
} }
func TestConfigBackendSQLite_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "config-backend.db")
rl, err := NewWithConfig(Config{
Backend: backendSQLite,
FilePath: dbPath,
Quotas: map[string]ModelQuota{
"backend-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
},
})
require.NoError(t, err)
defer rl.Close()
require.NotNil(t, rl.sqlite)
rl.RecordUsage("backend-model", 10, 10)
require.NoError(t, rl.Persist())
_, statErr := os.Stat(dbPath)
assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path")
}
// TestConfigBackendSQLiteDefaultPath_Good verifies that when no
// FilePath is configured, the SQLite backend falls back to the default
// DB location under the user's home directory.
func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
	fakeHome := t.TempDir()
	// Redirect the Unix home variable to the temp dir and blank the
	// Windows/Plan 9 variants — presumably what os.UserHomeDir consults;
	// this keeps the default path inside the test sandbox.
	t.Setenv("HOME", fakeHome)
	t.Setenv("USERPROFILE", "")
	t.Setenv("home", "")

	limiter, err := NewWithConfig(Config{Backend: backendSQLite})
	require.NoError(t, err)
	defer limiter.Close()
	require.NotNil(t, limiter.sqlite)
	require.NoError(t, limiter.Persist())

	// Persist must have created the DB at the default home-relative path.
	wantDB := filepath.Join(fakeHome, defaultStateDirName, defaultSQLiteStateFile)
	_, statErr := os.Stat(wantDB)
	assert.NoError(t, statErr, "sqlite backend should use the default home DB path")
}
// --- Phase 2: MigrateYAMLToSQLite --- // --- Phase 2: MigrateYAMLToSQLite ---
func TestMigrateYAMLToSQLite_Good(t *testing.T) { func TestMigrateYAMLToSQLite_Good(t *testing.T) {
@ -492,6 +531,69 @@ func TestMigrateYAMLToSQLite_Bad(t *testing.T) {
}) })
} }
// TestMigrateYAMLToSQLiteAtomic_Good verifies that a failing migration
// is atomic: a SQLite trigger forces the daily-table insert to abort
// mid-migration, and the pre-existing snapshot must remain fully
// readable (and the migrated data fully absent) afterwards.
func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
	tmpDir := t.TempDir()
	yamlPath := filepath.Join(tmpDir, "atomic.yaml")
	sqlitePath := filepath.Join(tmpDir, "atomic.db")
	now := time.Now().UTC()

	// Seed the SQLite store with a known snapshot ("old-model") that the
	// failed migration must preserve.
	store, err := newSQLiteStore(sqlitePath)
	require.NoError(t, err)
	originalQuotas := map[string]ModelQuota{
		"old-model": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3},
	}
	originalState := map[string]*UsageStats{
		"old-model": {
			Requests: []time.Time{now},
			Tokens:   []TokenEntry{{Time: now, Count: 9}},
			DayStart: now,
			DayCount: 1,
		},
	}
	require.NoError(t, store.saveSnapshot(originalQuotas, originalState))
	// Sabotage: abort any INSERT into daily so the migration fails after
	// it has already written quotas/requests/tokens.
	_, err = store.db.Exec(`CREATE TRIGGER fail_daily_migrate BEFORE INSERT ON daily
	BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
	require.NoError(t, err)
	require.NoError(t, store.close())

	// YAML source for the migration, carrying a different model so any
	// leakage into SQLite is detectable.
	migrated := &RateLimiter{
		Quotas: map[string]ModelQuota{
			"new-model": {MaxRPM: 10, MaxTPM: 20, MaxRPD: 30},
		},
		State: map[string]*UsageStats{
			"new-model": {
				Requests: []time.Time{now.Add(5 * time.Second)},
				Tokens:   []TokenEntry{{Time: now.Add(5 * time.Second), Count: 99}},
				DayStart: now.Add(5 * time.Second),
				DayCount: 2,
			},
		},
	}
	data, err := yaml.Marshal(migrated)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(yamlPath, data, 0o644))

	// The trigger must make the migration fail...
	err = MigrateYAMLToSQLite(yamlPath, sqlitePath)
	require.Error(t, err)

	// ...and the original snapshot must have survived untouched.
	store, err = newSQLiteStore(sqlitePath)
	require.NoError(t, err)
	defer store.close()
	quotas, err := store.loadQuotas()
	require.NoError(t, err)
	assert.Equal(t, originalQuotas, quotas)
	state, err := store.loadState()
	require.NoError(t, err)
	require.Contains(t, state, "old-model")
	assert.Equal(t, originalState["old-model"].DayCount, state["old-model"].DayCount)
	assert.Equal(t, originalState["old-model"].Tokens[0].Count, state["old-model"].Tokens[0].Count)
	// No partial writes from the aborted migration.
	assert.NotContains(t, state, "new-model")
}
func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) { func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
tmpDir := t.TempDir() tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "full.yaml") yamlPath := filepath.Join(tmpDir, "full.yaml")
@ -650,6 +752,78 @@ func TestSQLiteEndToEnd_Good(t *testing.T) {
assert.Equal(t, 5, custom.MaxRPM) assert.Equal(t, 5, custom.MaxRPM)
} }
// TestSQLiteLoadReplacesPersistedSnapshot_Good verifies that Load
// replaces — rather than merges into — both the in-memory quotas and
// state: only the most recently persisted snapshot (model-b) survives,
// while provider defaults, earlier models, and stale in-memory entries
// are all dropped.
func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "replace.db")
	rl, err := NewWithSQLiteConfig(dbPath, Config{
		Quotas: map[string]ModelQuota{
			"model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
		},
	})
	require.NoError(t, err)
	// First snapshot: model-a with usage.
	rl.RecordUsage("model-a", 10, 10)
	require.NoError(t, rl.Persist())

	// Second snapshot: model-a removed, model-b added with fresh usage.
	// This is the state Load must reproduce exactly.
	delete(rl.Quotas, "model-a")
	rl.Quotas["model-b"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
	rl.Reset("")
	rl.RecordUsage("model-b", 5, 5)
	require.NoError(t, rl.Persist())
	require.NoError(t, rl.Close())

	// Reopen with Gemini provider defaults and a stale in-memory entry,
	// both of which Load must discard in favor of the persisted snapshot.
	rl2, err := NewWithSQLiteConfig(dbPath, Config{
		Providers: []Provider{ProviderGemini},
	})
	require.NoError(t, err)
	defer rl2.Close()
	rl2.State["stale-memory"] = &UsageStats{DayStart: time.Now(), DayCount: 99}
	require.NoError(t, rl2.Load())

	assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview")
	assert.NotContains(t, rl2.Quotas, "model-a")
	assert.Contains(t, rl2.Quotas, "model-b")
	assert.NotContains(t, rl2.State, "stale-memory")
	assert.NotContains(t, rl2.State, "model-a")
	// model-b's single recorded request survived the round trip.
	assert.Equal(t, 1, rl2.Stats("model-b").RPD)
}
// TestSQLitePersistAtomic_Good verifies that a failing Persist is
// atomic: a trigger aborts the daily-table insert mid-write, and the
// previously persisted snapshot (old-model) must survive intact while
// the new data (new-model) must be fully absent after reopening.
func TestSQLitePersistAtomic_Good(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "persist-atomic.db")
	rl, err := NewWithSQLiteConfig(dbPath, Config{
		Quotas: map[string]ModelQuota{
			"old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
		},
	})
	require.NoError(t, err)
	// Baseline snapshot that the failed Persist must not disturb.
	rl.RecordUsage("old-model", 10, 10)
	require.NoError(t, rl.Persist())

	// Sabotage: abort any INSERT into daily so the next Persist fails
	// after earlier tables have already been rewritten in-transaction.
	_, err = rl.sqlite.db.Exec(`CREATE TRIGGER fail_daily_persist BEFORE INSERT ON daily
	BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
	require.NoError(t, err)

	// Mutate in-memory state to something distinguishable from the
	// baseline, then attempt to persist it.
	delete(rl.Quotas, "old-model")
	rl.Quotas["new-model"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
	rl.Reset("")
	rl.RecordUsage("new-model", 50, 50)
	err = rl.Persist()
	require.Error(t, err)
	require.NoError(t, rl.Close())

	// Reopen: the aborted Persist must have rolled back completely.
	rl2, err := NewWithSQLiteConfig(dbPath, Config{})
	require.NoError(t, err)
	defer rl2.Close()
	require.NoError(t, rl2.Load())
	assert.Contains(t, rl2.Quotas, "old-model")
	assert.NotContains(t, rl2.Quotas, "new-model")
	assert.Equal(t, 1, rl2.Stats("old-model").RPD)
	assert.Equal(t, 0, rl2.Stats("new-model").RPD)
}
// --- Phase 2: Benchmark --- // --- Phase 2: Benchmark ---
func BenchmarkSQLitePersist(b *testing.B) { func BenchmarkSQLitePersist(b *testing.B) {