fix(ratelimit): harden audit edge cases
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
0db35c4ce9
commit
d1c90b937d
5 changed files with 715 additions and 133 deletions
39
iter_test.go
39
iter_test.go
|
|
@ -31,13 +31,19 @@ func TestIterators(t *testing.T) {
|
||||||
assert.Contains(t, models, "model-a")
|
assert.Contains(t, models, "model-a")
|
||||||
assert.Contains(t, models, "model-b")
|
assert.Contains(t, models, "model-b")
|
||||||
assert.Contains(t, models, "model-c")
|
assert.Contains(t, models, "model-c")
|
||||||
|
|
||||||
// Check sorting of our specific models
|
// Check sorting of our specific models
|
||||||
foundA, foundB, foundC := -1, -1, -1
|
foundA, foundB, foundC := -1, -1, -1
|
||||||
for i, m := range models {
|
for i, m := range models {
|
||||||
if m == "model-a" { foundA = i }
|
if m == "model-a" {
|
||||||
if m == "model-b" { foundB = i }
|
foundA = i
|
||||||
if m == "model-c" { foundC = i }
|
}
|
||||||
|
if m == "model-b" {
|
||||||
|
foundB = i
|
||||||
|
}
|
||||||
|
if m == "model-c" {
|
||||||
|
foundC = i
|
||||||
|
}
|
||||||
}
|
}
|
||||||
assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c")
|
assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c")
|
||||||
})
|
})
|
||||||
|
|
@ -57,9 +63,15 @@ func TestIterators(t *testing.T) {
|
||||||
// Check sorting
|
// Check sorting
|
||||||
foundA, foundB, foundC := -1, -1, -1
|
foundA, foundB, foundC := -1, -1, -1
|
||||||
for i, m := range models {
|
for i, m := range models {
|
||||||
if m == "model-a" { foundA = i }
|
if m == "model-a" {
|
||||||
if m == "model-b" { foundB = i }
|
foundA = i
|
||||||
if m == "model-c" { foundC = i }
|
}
|
||||||
|
if m == "model-b" {
|
||||||
|
foundB = i
|
||||||
|
}
|
||||||
|
if m == "model-c" {
|
||||||
|
foundC = i
|
||||||
|
}
|
||||||
}
|
}
|
||||||
assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c")
|
assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c")
|
||||||
})
|
})
|
||||||
|
|
@ -99,9 +111,8 @@ func TestIterEarlyBreak(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCountTokensFull(t *testing.T) {
|
func TestCountTokensFull(t *testing.T) {
|
||||||
t.Run("invalid URL/network error", func(t *testing.T) {
|
t.Run("empty model is rejected", func(t *testing.T) {
|
||||||
// Using an invalid character in model name to trigger URL error or similar
|
_, err := CountTokens(context.Background(), "key", "", "text")
|
||||||
_, err := CountTokens(context.Background(), "key", "invalid model", "text")
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
@ -112,15 +123,15 @@ func TestCountTokensFull(t *testing.T) {
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
// We can't easily override the URL in CountTokens without changing the code,
|
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "key", "model", "text")
|
||||||
// but we can test the logic if we make it slightly more testable.
|
require.Error(t, err)
|
||||||
// For now, I've already updated ratelimit_test.go with some of this.
|
assert.Contains(t, err.Error(), "status 400")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("context cancelled", func(t *testing.T) {
|
t.Run("context cancelled", func(t *testing.T) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
cancel()
|
cancel()
|
||||||
_, err := CountTokens(ctx, "key", "model", "text")
|
_, err := countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", "key", "model", "text")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "do request")
|
assert.Contains(t, err.Error(), "do request")
|
||||||
})
|
})
|
||||||
|
|
|
||||||
287
ratelimit.go
287
ratelimit.go
|
|
@ -9,9 +9,11 @@ import (
|
||||||
"iter"
|
"iter"
|
||||||
"maps"
|
"maps"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -34,6 +36,16 @@ const (
|
||||||
ProviderLocal Provider = "local"
|
ProviderLocal Provider = "local"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultStateDirName = ".core"
|
||||||
|
defaultYAMLStateFile = "ratelimits.yaml"
|
||||||
|
defaultSQLiteStateFile = "ratelimits.db"
|
||||||
|
backendYAML = "yaml"
|
||||||
|
backendSQLite = "sqlite"
|
||||||
|
countTokensErrorBodyLimit = 8 * 1024
|
||||||
|
countTokensSuccessBodyLimit = 1 * 1024 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
// ModelQuota defines the rate limits for a specific model.
|
// ModelQuota defines the rate limits for a specific model.
|
||||||
type ModelQuota struct {
|
type ModelQuota struct {
|
||||||
MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
|
MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
|
||||||
|
|
@ -143,39 +155,30 @@ func New() (*RateLimiter, error) {
|
||||||
// NewWithConfig creates a RateLimiter from explicit configuration.
|
// NewWithConfig creates a RateLimiter from explicit configuration.
|
||||||
// If no providers or quotas are specified, Gemini defaults are used.
|
// If no providers or quotas are specified, Gemini defaults are used.
|
||||||
func NewWithConfig(cfg Config) (*RateLimiter, error) {
|
func NewWithConfig(cfg Config) (*RateLimiter, error) {
|
||||||
|
backend, err := normaliseBackend(cfg.Backend)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
filePath := cfg.FilePath
|
filePath := cfg.FilePath
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
home, err := os.UserHomeDir()
|
filePath, err = defaultStatePath(backend)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
filePath = filepath.Join(home, ".core", "ratelimits.yaml")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rl := &RateLimiter{
|
if backend == backendSQLite {
|
||||||
Quotas: make(map[string]ModelQuota),
|
if cfg.FilePath == "" {
|
||||||
State: make(map[string]*UsageStats),
|
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
||||||
filePath: filePath,
|
return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load provider profiles
|
|
||||||
profiles := DefaultProfiles()
|
|
||||||
providers := cfg.Providers
|
|
||||||
|
|
||||||
// If nothing specified at all, default to Gemini
|
|
||||||
if len(providers) == 0 && len(cfg.Quotas) == 0 {
|
|
||||||
providers = []Provider{ProviderGemini}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range providers {
|
|
||||||
if profile, ok := profiles[p]; ok {
|
|
||||||
maps.Copy(rl.Quotas, profile.Models)
|
|
||||||
}
|
}
|
||||||
|
return NewWithSQLiteConfig(filePath, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Merge explicit quotas on top (allows overrides)
|
rl := newConfiguredRateLimiter(cfg)
|
||||||
maps.Copy(rl.Quotas, cfg.Quotas)
|
rl.filePath = filePath
|
||||||
|
|
||||||
return rl, nil
|
return rl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -225,15 +228,17 @@ func (rl *RateLimiter) loadSQLite() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Merge loaded quotas (loaded quotas override in-memory defaults).
|
// Persisted quotas are authoritative when present; otherwise keep config defaults.
|
||||||
maps.Copy(rl.Quotas, quotas)
|
if len(quotas) > 0 {
|
||||||
|
rl.Quotas = maps.Clone(quotas)
|
||||||
|
}
|
||||||
|
|
||||||
state, err := rl.sqlite.loadState()
|
state, err := rl.sqlite.loadState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// Replace in-memory state with persisted state.
|
// Replace in-memory state with the persisted snapshot.
|
||||||
maps.Copy(rl.State, state)
|
rl.State = state
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -244,6 +249,9 @@ func (rl *RateLimiter) Persist() error {
|
||||||
quotas := maps.Clone(rl.Quotas)
|
quotas := maps.Clone(rl.Quotas)
|
||||||
state := make(map[string]*UsageStats, len(rl.State))
|
state := make(map[string]*UsageStats, len(rl.State))
|
||||||
for k, v := range rl.State {
|
for k, v := range rl.State {
|
||||||
|
if v == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
state[k] = &UsageStats{
|
state[k] = &UsageStats{
|
||||||
Requests: slices.Clone(v.Requests),
|
Requests: slices.Clone(v.Requests),
|
||||||
Tokens: slices.Clone(v.Tokens),
|
Tokens: slices.Clone(v.Tokens),
|
||||||
|
|
@ -256,11 +264,8 @@ func (rl *RateLimiter) Persist() error {
|
||||||
rl.mu.Unlock()
|
rl.mu.Unlock()
|
||||||
|
|
||||||
if sqlite != nil {
|
if sqlite != nil {
|
||||||
if err := sqlite.saveQuotas(quotas); err != nil {
|
if err := sqlite.saveSnapshot(quotas, state); err != nil {
|
||||||
return coreerr.E("ratelimit.Persist", "sqlite quotas", err)
|
return coreerr.E("ratelimit.Persist", "sqlite snapshot", err)
|
||||||
}
|
|
||||||
if err := sqlite.saveState(state); err != nil {
|
|
||||||
return coreerr.E("ratelimit.Persist", "sqlite state", err)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -292,6 +297,10 @@ func (rl *RateLimiter) prune(model string) {
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if stats == nil {
|
||||||
|
delete(rl.State, model)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
window := now.Add(-1 * time.Minute)
|
window := now.Add(-1 * time.Minute)
|
||||||
|
|
@ -326,6 +335,10 @@ func (rl *RateLimiter) prune(model string) {
|
||||||
// BackgroundPrune starts a goroutine that periodically prunes all model states.
|
// BackgroundPrune starts a goroutine that periodically prunes all model states.
|
||||||
// It returns a function to stop the pruner.
|
// It returns a function to stop the pruner.
|
||||||
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
|
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
|
||||||
|
if interval <= 0 {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
|
|
@ -348,6 +361,10 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
|
||||||
|
|
||||||
// CanSend checks if a request can be sent without violating limits.
|
// CanSend checks if a request can be sent without violating limits.
|
||||||
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
||||||
|
if estimatedTokens < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
rl.mu.Lock()
|
rl.mu.Lock()
|
||||||
defer rl.mu.Unlock()
|
defer rl.mu.Unlock()
|
||||||
|
|
||||||
|
|
@ -380,11 +397,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
||||||
|
|
||||||
// Check TPM
|
// Check TPM
|
||||||
if quota.MaxTPM > 0 {
|
if quota.MaxTPM > 0 {
|
||||||
currentTokens := 0
|
currentTokens := totalTokenCount(stats.Tokens)
|
||||||
for _, t := range stats.Tokens {
|
if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens {
|
||||||
currentTokens += t.Count
|
|
||||||
}
|
|
||||||
if currentTokens+estimatedTokens > quota.MaxTPM {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -405,13 +419,18 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
totalTokens := safeTokenSum(promptTokens, outputTokens)
|
||||||
stats.Requests = append(stats.Requests, now)
|
stats.Requests = append(stats.Requests, now)
|
||||||
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens})
|
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: totalTokens})
|
||||||
stats.DayCount++
|
stats.DayCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
// WaitForCapacity blocks until capacity is available or context is cancelled.
|
// WaitForCapacity blocks until capacity is available or context is cancelled.
|
||||||
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
|
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
|
||||||
|
if tokens < 0 {
|
||||||
|
return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil)
|
||||||
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(1 * time.Second)
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
|
@ -522,24 +541,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
||||||
result[m] = ModelStats{}
|
result[m] = ModelStats{}
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
window := now.Add(-1 * time.Minute)
|
|
||||||
|
|
||||||
for m := range result {
|
for m := range result {
|
||||||
// Prune inline
|
rl.prune(m)
|
||||||
if s, ok := rl.State[m]; ok {
|
|
||||||
s.Requests = slices.DeleteFunc(s.Requests, func(t time.Time) bool {
|
|
||||||
return !t.After(window)
|
|
||||||
})
|
|
||||||
s.Tokens = slices.DeleteFunc(s.Tokens, func(t TokenEntry) bool {
|
|
||||||
return !t.Time.After(window)
|
|
||||||
})
|
|
||||||
|
|
||||||
if now.Sub(s.DayStart) >= 24*time.Hour {
|
|
||||||
s.DayStart = now
|
|
||||||
s.DayCount = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ms := ModelStats{}
|
ms := ModelStats{}
|
||||||
if q, ok := rl.Quotas[m]; ok {
|
if q, ok := rl.Quotas[m]; ok {
|
||||||
|
|
@ -547,13 +550,11 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
||||||
ms.MaxTPM = q.MaxTPM
|
ms.MaxTPM = q.MaxTPM
|
||||||
ms.MaxRPD = q.MaxRPD
|
ms.MaxRPD = q.MaxRPD
|
||||||
}
|
}
|
||||||
if s, ok := rl.State[m]; ok {
|
if s, ok := rl.State[m]; ok && s != nil {
|
||||||
ms.RPM = len(s.Requests)
|
ms.RPM = len(s.Requests)
|
||||||
ms.RPD = s.DayCount
|
ms.RPD = s.DayCount
|
||||||
ms.DayStart = s.DayStart
|
ms.DayStart = s.DayStart
|
||||||
for _, t := range s.Tokens {
|
ms.TPM = totalTokenCount(s.Tokens)
|
||||||
ms.TPM += t.Count
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
result[m] = ms
|
result[m] = ms
|
||||||
}
|
}
|
||||||
|
|
@ -579,25 +580,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
rl := &RateLimiter{
|
rl := newConfiguredRateLimiter(cfg)
|
||||||
Quotas: make(map[string]ModelQuota),
|
rl.sqlite = store
|
||||||
State: make(map[string]*UsageStats),
|
|
||||||
sqlite: store,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load provider profiles.
|
|
||||||
profiles := DefaultProfiles()
|
|
||||||
providers := cfg.Providers
|
|
||||||
if len(providers) == 0 && len(cfg.Quotas) == 0 {
|
|
||||||
providers = []Provider{ProviderGemini}
|
|
||||||
}
|
|
||||||
for _, p := range providers {
|
|
||||||
if profile, ok := profiles[p]; ok {
|
|
||||||
maps.Copy(rl.Quotas, profile.Models)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
maps.Copy(rl.Quotas, cfg.Quotas)
|
|
||||||
|
|
||||||
return rl, nil
|
return rl, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -633,22 +617,22 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
|
||||||
}
|
}
|
||||||
defer store.close()
|
defer store.close()
|
||||||
|
|
||||||
if rl.Quotas != nil {
|
if err := store.saveSnapshot(rl.Quotas, rl.State); err != nil {
|
||||||
if err := store.saveQuotas(rl.Quotas); err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if rl.State != nil {
|
|
||||||
if err := store.saveState(rl.State); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountTokens calls the Google API to count tokens for a prompt.
|
// CountTokens calls the Google API to count tokens for a prompt.
|
||||||
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
|
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
|
||||||
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model)
|
return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) {
|
||||||
|
requestURL, err := countTokensURL(baseURL, model)
|
||||||
|
if err != nil {
|
||||||
|
return 0, coreerr.E("ratelimit.CountTokens", "build url", err)
|
||||||
|
}
|
||||||
|
|
||||||
reqBody := map[string]any{
|
reqBody := map[string]any{
|
||||||
"contents": []any{
|
"contents": []any{
|
||||||
|
|
@ -665,30 +649,147 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
|
||||||
return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err)
|
return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, coreerr.E("ratelimit.CountTokens", "new request", err)
|
return 0, coreerr.E("ratelimit.CountTokens", "new request", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
req.Header.Set("x-goog-api-key", apiKey)
|
req.Header.Set("x-goog-api-key", apiKey)
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, coreerr.E("ratelimit.CountTokens", "do request", err)
|
return 0, coreerr.E("ratelimit.CountTokens", "do request", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit)
|
||||||
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("API error (status %d): %s", resp.StatusCode, string(body)), nil)
|
if err != nil {
|
||||||
|
return 0, coreerr.E("ratelimit.CountTokens", "read error body", err)
|
||||||
|
}
|
||||||
|
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var result struct {
|
var result struct {
|
||||||
TotalTokens int `json:"totalTokens"`
|
TotalTokens int `json:"totalTokens"`
|
||||||
}
|
}
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil {
|
||||||
return 0, coreerr.E("ratelimit.CountTokens", "decode response", err)
|
return 0, coreerr.E("ratelimit.CountTokens", "decode response", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result.TotalTokens, nil
|
return result.TotalTokens, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newConfiguredRateLimiter(cfg Config) *RateLimiter {
|
||||||
|
rl := &RateLimiter{
|
||||||
|
Quotas: make(map[string]ModelQuota),
|
||||||
|
State: make(map[string]*UsageStats),
|
||||||
|
}
|
||||||
|
applyConfig(rl, cfg)
|
||||||
|
return rl
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyConfig(rl *RateLimiter, cfg Config) {
|
||||||
|
profiles := DefaultProfiles()
|
||||||
|
providers := cfg.Providers
|
||||||
|
|
||||||
|
if len(providers) == 0 && len(cfg.Quotas) == 0 {
|
||||||
|
providers = []Provider{ProviderGemini}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range providers {
|
||||||
|
if profile, ok := profiles[p]; ok {
|
||||||
|
maps.Copy(rl.Quotas, profile.Models)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maps.Copy(rl.Quotas, cfg.Quotas)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normaliseBackend(backend string) (string, error) {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(backend)) {
|
||||||
|
case "", backendYAML:
|
||||||
|
return backendYAML, nil
|
||||||
|
case backendSQLite:
|
||||||
|
return backendSQLite, nil
|
||||||
|
default:
|
||||||
|
return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultStatePath(backend string) (string, error) {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
fileName := defaultYAMLStateFile
|
||||||
|
if backend == backendSQLite {
|
||||||
|
fileName = defaultSQLiteStateFile
|
||||||
|
}
|
||||||
|
|
||||||
|
return filepath.Join(home, defaultStateDirName, fileName), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeTokenSum(a, b int) int {
|
||||||
|
return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}})
|
||||||
|
}
|
||||||
|
|
||||||
|
func totalTokenCount(tokens []TokenEntry) int {
|
||||||
|
return safeTokenTotal(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeTokenTotal(tokens []TokenEntry) int {
|
||||||
|
const maxInt = int(^uint(0) >> 1)
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for _, token := range tokens {
|
||||||
|
count := token.Count
|
||||||
|
if count < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if total > maxInt-count {
|
||||||
|
return maxInt
|
||||||
|
}
|
||||||
|
total += count
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func countTokensURL(baseURL, model string) (string, error) {
|
||||||
|
if strings.TrimSpace(model) == "" {
|
||||||
|
return "", fmt.Errorf("empty model")
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := url.Parse(baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if parsed.Scheme == "" || parsed.Host == "" {
|
||||||
|
return "", fmt.Errorf("invalid base url")
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.TrimRight(parsed.String(), "/") + "/v1beta/models/" + url.PathEscape(model) + ":countTokens", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLimitedBody(r io.Reader, limit int64) (string, error) {
|
||||||
|
body, err := io.ReadAll(io.LimitReader(r, limit+1))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
truncated := int64(len(body)) > limit
|
||||||
|
if truncated {
|
||||||
|
body = body[:limit]
|
||||||
|
}
|
||||||
|
|
||||||
|
result := string(body)
|
||||||
|
if truncated {
|
||||||
|
result += "..."
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,10 +4,12 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
@ -25,6 +27,18 @@ func newTestLimiter(t *testing.T) *RateLimiter {
|
||||||
return rl
|
return rl
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||||
|
|
||||||
|
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return f(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
type errReader struct{}
|
||||||
|
|
||||||
|
func (errReader) Read([]byte) (int, error) {
|
||||||
|
return 0, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
|
||||||
// --- Phase 0: CanSend boundary conditions ---
|
// --- Phase 0: CanSend boundary conditions ---
|
||||||
|
|
||||||
func TestCanSend(t *testing.T) {
|
func TestCanSend(t *testing.T) {
|
||||||
|
|
@ -162,6 +176,15 @@ func TestCanSend(t *testing.T) {
|
||||||
rl.RecordUsage(model, 50, 50) // exactly 100 tokens
|
rl.RecordUsage(model, 50, 50) // exactly 100 tokens
|
||||||
assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit")
|
assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("negative estimated tokens are rejected", func(t *testing.T) {
|
||||||
|
rl := newTestLimiter(t)
|
||||||
|
model := "negative-est"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 100}
|
||||||
|
rl.RecordUsage(model, 50, 50)
|
||||||
|
|
||||||
|
assert.False(t, rl.CanSend(model, -1), "negative estimated tokens should not bypass TPM limits")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Phase 0: Sliding window / prune tests ---
|
// --- Phase 0: Sliding window / prune tests ---
|
||||||
|
|
@ -335,6 +358,19 @@ func TestRecordUsage(t *testing.T) {
|
||||||
assert.Len(t, stats.Tokens, 2)
|
assert.Len(t, stats.Tokens, 2)
|
||||||
assert.Equal(t, 6, stats.DayCount)
|
assert.Equal(t, 6, stats.DayCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("negative token inputs are clamped to zero", func(t *testing.T) {
|
||||||
|
rl := newTestLimiter(t)
|
||||||
|
model := "record-negative"
|
||||||
|
|
||||||
|
rl.RecordUsage(model, -100, 25)
|
||||||
|
|
||||||
|
stats := rl.State[model]
|
||||||
|
require.NotNil(t, stats)
|
||||||
|
assert.Len(t, stats.Requests, 1)
|
||||||
|
assert.Len(t, stats.Tokens, 1)
|
||||||
|
assert.Equal(t, 25, stats.Tokens[0].Count)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Phase 0: Reset ---
|
// --- Phase 0: Reset ---
|
||||||
|
|
@ -422,6 +458,58 @@ func TestWaitForCapacity(t *testing.T) {
|
||||||
err := rl.WaitForCapacity(ctx, model, 10)
|
err := rl.WaitForCapacity(ctx, model, 10)
|
||||||
assert.Error(t, err, "should return error for already-cancelled context")
|
assert.Error(t, err, "should return error for already-cancelled context")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("negative tokens return error", func(t *testing.T) {
|
||||||
|
rl := newTestLimiter(t)
|
||||||
|
err := rl.WaitForCapacity(context.Background(), "wait-negative", -1)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "negative tokens")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNilUsageStats(t *testing.T) {
|
||||||
|
t.Run("CanSend replaces nil state without panicking", func(t *testing.T) {
|
||||||
|
rl := newTestLimiter(t)
|
||||||
|
model := "nil-cansend"
|
||||||
|
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 10}
|
||||||
|
rl.State[model] = nil
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
assert.True(t, rl.CanSend(model, 10))
|
||||||
|
})
|
||||||
|
assert.NotNil(t, rl.State[model])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RecordUsage replaces nil state without panicking", func(t *testing.T) {
|
||||||
|
rl := newTestLimiter(t)
|
||||||
|
model := "nil-record"
|
||||||
|
rl.State[model] = nil
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
rl.RecordUsage(model, 10, 10)
|
||||||
|
})
|
||||||
|
require.NotNil(t, rl.State[model])
|
||||||
|
assert.Equal(t, 1, rl.State[model].DayCount)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Stats and AllStats tolerate nil state entries", func(t *testing.T) {
|
||||||
|
rl := newTestLimiter(t)
|
||||||
|
rl.Quotas["nil-stats"] = ModelQuota{MaxRPM: 1, MaxTPM: 2, MaxRPD: 3}
|
||||||
|
rl.State["nil-stats"] = nil
|
||||||
|
rl.State["nil-all-stats"] = nil
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
stats := rl.Stats("nil-stats")
|
||||||
|
assert.Equal(t, 1, stats.MaxRPM)
|
||||||
|
assert.Equal(t, 0, stats.TPM)
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
all := rl.AllStats()
|
||||||
|
assert.Contains(t, all, "nil-stats")
|
||||||
|
assert.Contains(t, all, "nil-all-stats")
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Phase 0: Stats ---
|
// --- Phase 0: Stats ---
|
||||||
|
|
@ -738,6 +826,19 @@ func TestBackgroundPrune(t *testing.T) {
|
||||||
_, exists := rl.State[model]
|
_, exists := rl.State[model]
|
||||||
return !exists
|
return !exists
|
||||||
}, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned")
|
}, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned")
|
||||||
|
|
||||||
|
t.Run("non-positive interval is a safe no-op", func(t *testing.T) {
|
||||||
|
rl := newTestLimiter(t)
|
||||||
|
rl.State["still-here"] = &UsageStats{
|
||||||
|
Requests: []time.Time{time.Now().Add(-2 * time.Minute)},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
stop := rl.BackgroundPrune(0)
|
||||||
|
stop()
|
||||||
|
})
|
||||||
|
assert.Contains(t, rl.State, "still-here")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Phase 0: CountTokens (with mock HTTP server) ---
|
// --- Phase 0: CountTokens (with mock HTTP server) ---
|
||||||
|
|
@ -748,26 +849,149 @@ func TestCountTokens(t *testing.T) {
|
||||||
assert.Equal(t, http.MethodPost, r.Method)
|
assert.Equal(t, http.MethodPost, r.Method)
|
||||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||||
assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key"))
|
assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key"))
|
||||||
|
assert.Equal(t, "/v1beta/models/test-model:countTokens", r.URL.EscapedPath())
|
||||||
|
|
||||||
|
var body struct {
|
||||||
|
Contents []struct {
|
||||||
|
Parts []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"parts"`
|
||||||
|
} `json:"contents"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
|
||||||
|
require.Len(t, body.Contents, 1)
|
||||||
|
require.Len(t, body.Contents[0].Parts, 1)
|
||||||
|
assert.Equal(t, "hello", body.Contents[0].Parts[0].Text)
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42})
|
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}))
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
// For testing purposes, we would need to make the base URL configurable.
|
tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "test-model", "hello")
|
||||||
// Since we're just checking the signature and basic logic, we test the error paths.
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 42, tokens)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("API error returns error", func(t *testing.T) {
|
t.Run("model name is path escaped", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath())
|
||||||
|
assert.Empty(t, r.URL.RawQuery)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7}))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "folder/model?debug=1", "hello")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 7, tokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("API error body is truncated", func(t *testing.T) {
|
||||||
|
largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256)
|
||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusUnauthorized)
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
fmt.Fprint(w, `{"error": "invalid API key"}`)
|
_, err := fmt.Fprint(w, largeBody)
|
||||||
|
require.NoError(t, err)
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
_, err := CountTokens(context.Background(), "fake-key", "test-model", "hello")
|
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
|
||||||
assert.Error(t, err, "should fail with invalid API endpoint")
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "api error status 401")
|
||||||
|
assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded")
|
||||||
|
assert.Contains(t, err.Error(), "...")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("empty model is rejected before request", func(t *testing.T) {
|
||||||
|
_, err := CountTokens(context.Background(), "fake-key", "", "hello")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "build url")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid base URL returns error", func(t *testing.T) {
|
||||||
|
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "://bad-url", "fake-key", "test-model", "hello")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "build url")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("base URL without host returns error", func(t *testing.T) {
|
||||||
|
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "/relative", "fake-key", "test-model", "hello")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "build url")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid JSON response returns error", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := w.Write([]byte(`{"totalTokens":`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "decode response")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error body read failures are returned", func(t *testing.T) {
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
Body: io.NopCloser(errReader{}),
|
||||||
|
Header: make(http.Header),
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := countTokensWithClient(context.Background(), client, "https://generativelanguage.googleapis.com", "fake-key", "test-model", "hello")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "read error body")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11}))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
originalClient := http.DefaultClient
|
||||||
|
http.DefaultClient = server.Client()
|
||||||
|
defer func() {
|
||||||
|
http.DefaultClient = originalClient
|
||||||
|
}()
|
||||||
|
|
||||||
|
tokens, err := countTokensWithClient(context.Background(), nil, server.URL, "fake-key", "test-model", "hello")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 11, tokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistSkipsNilState(t *testing.T) {
|
||||||
|
path := filepath.Join(t.TempDir(), "nil-state.yaml")
|
||||||
|
|
||||||
|
rl, err := New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rl.filePath = path
|
||||||
|
rl.State["nil-model"] = nil
|
||||||
|
|
||||||
|
require.NoError(t, rl.Persist())
|
||||||
|
|
||||||
|
rl2, err := New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rl2.filePath = path
|
||||||
|
require.NoError(t, rl2.Load())
|
||||||
|
assert.NotContains(t, rl2.State, "nil-model")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenTotals(t *testing.T) {
|
||||||
|
maxInt := int(^uint(0) >> 1)
|
||||||
|
|
||||||
|
assert.Equal(t, 25, safeTokenSum(-100, 25))
|
||||||
|
assert.Equal(t, maxInt, safeTokenSum(maxInt, 1))
|
||||||
|
assert.Equal(t, 10, totalTokenCount([]TokenEntry{{Count: -5}, {Count: 10}}))
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Phase 0: Benchmarks ---
|
// --- Phase 0: Benchmarks ---
|
||||||
|
|
@ -983,6 +1207,25 @@ func TestNewWithConfig(t *testing.T) {
|
||||||
assert.Equal(t, 5, q.MaxRPM)
|
assert.Equal(t, 5, q.MaxRPM)
|
||||||
assert.Equal(t, 50000, q.MaxTPM)
|
assert.Equal(t, 50000, q.MaxTPM)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("invalid backend returns error", func(t *testing.T) {
|
||||||
|
_, err := NewWithConfig(Config{
|
||||||
|
Backend: "bogus",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "unknown backend")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("default YAML path uses home directory", func(t *testing.T) {
|
||||||
|
home := t.TempDir()
|
||||||
|
t.Setenv("HOME", home)
|
||||||
|
t.Setenv("USERPROFILE", "")
|
||||||
|
t.Setenv("home", "")
|
||||||
|
|
||||||
|
rl, err := NewWithConfig(Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewBackwardCompatibility(t *testing.T) {
|
func TestNewBackwardCompatibility(t *testing.T) {
|
||||||
|
|
|
||||||
91
sqlite.go
91
sqlite.go
|
|
@ -78,7 +78,7 @@ func createSchema(db *sql.DB) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// saveQuotas upserts all quotas into the quotas table.
|
// saveQuotas writes a complete quota snapshot to the quotas table.
|
||||||
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
|
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
|
||||||
tx, err := s.db.Begin()
|
tx, err := s.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -86,24 +86,15 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
stmt, err := tx.Prepare(`INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd)
|
if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
|
||||||
VALUES (?, ?, ?, ?)
|
return coreerr.E("ratelimit.saveQuotas", "clear", err)
|
||||||
ON CONFLICT(model) DO UPDATE SET
|
|
||||||
max_rpm = excluded.max_rpm,
|
|
||||||
max_tpm = excluded.max_tpm,
|
|
||||||
max_rpd = excluded.max_rpd`)
|
|
||||||
if err != nil {
|
|
||||||
return coreerr.E("ratelimit.saveQuotas", "prepare", err)
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
for model, q := range quotas {
|
|
||||||
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
|
|
||||||
return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return tx.Commit()
|
if err := insertQuotas(tx, quotas); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return commitTx(tx, "ratelimit.saveQuotas")
|
||||||
}
|
}
|
||||||
|
|
||||||
// loadQuotas reads all rows from the quotas table.
|
// loadQuotas reads all rows from the quotas table.
|
||||||
|
|
@ -129,6 +120,28 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// saveSnapshot writes quotas and state as a single atomic snapshot.
|
||||||
|
func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error {
|
||||||
|
tx, err := s.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return coreerr.E("ratelimit.saveSnapshot", "begin", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
if err := clearSnapshotTables(tx, true); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := insertQuotas(tx, quotas); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := insertState(tx, state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return commitTx(tx, "ratelimit.saveSnapshot")
|
||||||
|
}
|
||||||
|
|
||||||
// saveState writes all usage state to SQLite in a single transaction.
|
// saveState writes all usage state to SQLite in a single transaction.
|
||||||
// It uses a truncate-and-insert approach for simplicity in this version,
|
// It uses a truncate-and-insert approach for simplicity in this version,
|
||||||
// but ensures atomicity via a single transaction.
|
// but ensures atomicity via a single transaction.
|
||||||
|
|
@ -139,7 +152,23 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
|
|
||||||
// Clear existing state in the transaction.
|
if err := clearSnapshotTables(tx, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := insertState(tx, state); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return commitTx(tx, "ratelimit.saveState")
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
|
||||||
|
if includeQuotas {
|
||||||
|
if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
|
||||||
|
return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
if _, err := tx.Exec("DELETE FROM requests"); err != nil {
|
if _, err := tx.Exec("DELETE FROM requests"); err != nil {
|
||||||
return coreerr.E("ratelimit.saveState", "clear requests", err)
|
return coreerr.E("ratelimit.saveState", "clear requests", err)
|
||||||
}
|
}
|
||||||
|
|
@ -149,7 +178,25 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
||||||
if _, err := tx.Exec("DELETE FROM daily"); err != nil {
|
if _, err := tx.Exec("DELETE FROM daily"); err != nil {
|
||||||
return coreerr.E("ratelimit.saveState", "clear daily", err)
|
return coreerr.E("ratelimit.saveState", "clear daily", err)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
|
||||||
|
stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)")
|
||||||
|
if err != nil {
|
||||||
|
return coreerr.E("ratelimit.saveQuotas", "prepare", err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
for model, q := range quotas {
|
||||||
|
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
|
||||||
|
return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
|
||||||
reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
|
reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return coreerr.E("ratelimit.saveState", "prepare requests", err)
|
return coreerr.E("ratelimit.saveState", "prepare requests", err)
|
||||||
|
|
@ -169,6 +216,9 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
||||||
defer dayStmt.Close()
|
defer dayStmt.Close()
|
||||||
|
|
||||||
for model, stats := range state {
|
for model, stats := range state {
|
||||||
|
if stats == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
for _, t := range stats.Requests {
|
for _, t := range stats.Requests {
|
||||||
if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
|
if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
|
||||||
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err)
|
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err)
|
||||||
|
|
@ -183,9 +233,12 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
||||||
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err)
|
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func commitTx(tx *sql.Tx, scope string) error {
|
||||||
if err := tx.Commit(); err != nil {
|
if err := tx.Commit(); err != nil {
|
||||||
return coreerr.E("ratelimit.saveState", "commit", err)
|
return coreerr.E(scope, "commit", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
174
sqlite_test.go
174
sqlite_test.go
|
|
@ -435,6 +435,45 @@ func TestConfigBackendDefault_Good(t *testing.T) {
|
||||||
assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
|
assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConfigBackendSQLite_Good(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "config-backend.db")
|
||||||
|
rl, err := NewWithConfig(Config{
|
||||||
|
Backend: backendSQLite,
|
||||||
|
FilePath: dbPath,
|
||||||
|
Quotas: map[string]ModelQuota{
|
||||||
|
"backend-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rl.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, rl.sqlite)
|
||||||
|
rl.RecordUsage("backend-model", 10, 10)
|
||||||
|
require.NoError(t, rl.Persist())
|
||||||
|
|
||||||
|
_, statErr := os.Stat(dbPath)
|
||||||
|
assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
|
||||||
|
home := t.TempDir()
|
||||||
|
t.Setenv("HOME", home)
|
||||||
|
t.Setenv("USERPROFILE", "")
|
||||||
|
t.Setenv("home", "")
|
||||||
|
|
||||||
|
rl, err := NewWithConfig(Config{
|
||||||
|
Backend: backendSQLite,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rl.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, rl.sqlite)
|
||||||
|
require.NoError(t, rl.Persist())
|
||||||
|
|
||||||
|
_, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile))
|
||||||
|
assert.NoError(t, statErr, "sqlite backend should use the default home DB path")
|
||||||
|
}
|
||||||
|
|
||||||
// --- Phase 2: MigrateYAMLToSQLite ---
|
// --- Phase 2: MigrateYAMLToSQLite ---
|
||||||
|
|
||||||
func TestMigrateYAMLToSQLite_Good(t *testing.T) {
|
func TestMigrateYAMLToSQLite_Good(t *testing.T) {
|
||||||
|
|
@ -492,6 +531,69 @@ func TestMigrateYAMLToSQLite_Bad(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
yamlPath := filepath.Join(tmpDir, "atomic.yaml")
|
||||||
|
sqlitePath := filepath.Join(tmpDir, "atomic.db")
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
store, err := newSQLiteStore(sqlitePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
originalQuotas := map[string]ModelQuota{
|
||||||
|
"old-model": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3},
|
||||||
|
}
|
||||||
|
originalState := map[string]*UsageStats{
|
||||||
|
"old-model": {
|
||||||
|
Requests: []time.Time{now},
|
||||||
|
Tokens: []TokenEntry{{Time: now, Count: 9}},
|
||||||
|
DayStart: now,
|
||||||
|
DayCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
require.NoError(t, store.saveSnapshot(originalQuotas, originalState))
|
||||||
|
|
||||||
|
_, err = store.db.Exec(`CREATE TRIGGER fail_daily_migrate BEFORE INSERT ON daily
|
||||||
|
BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, store.close())
|
||||||
|
|
||||||
|
migrated := &RateLimiter{
|
||||||
|
Quotas: map[string]ModelQuota{
|
||||||
|
"new-model": {MaxRPM: 10, MaxTPM: 20, MaxRPD: 30},
|
||||||
|
},
|
||||||
|
State: map[string]*UsageStats{
|
||||||
|
"new-model": {
|
||||||
|
Requests: []time.Time{now.Add(5 * time.Second)},
|
||||||
|
Tokens: []TokenEntry{{Time: now.Add(5 * time.Second), Count: 99}},
|
||||||
|
DayStart: now.Add(5 * time.Second),
|
||||||
|
DayCount: 2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, err := yaml.Marshal(migrated)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(yamlPath, data, 0o644))
|
||||||
|
|
||||||
|
err = MigrateYAMLToSQLite(yamlPath, sqlitePath)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
store, err = newSQLiteStore(sqlitePath)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer store.close()
|
||||||
|
|
||||||
|
quotas, err := store.loadQuotas()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, originalQuotas, quotas)
|
||||||
|
|
||||||
|
state, err := store.loadState()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Contains(t, state, "old-model")
|
||||||
|
assert.Equal(t, originalState["old-model"].DayCount, state["old-model"].DayCount)
|
||||||
|
assert.Equal(t, originalState["old-model"].Tokens[0].Count, state["old-model"].Tokens[0].Count)
|
||||||
|
assert.NotContains(t, state, "new-model")
|
||||||
|
}
|
||||||
|
|
||||||
func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
|
func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
yamlPath := filepath.Join(tmpDir, "full.yaml")
|
yamlPath := filepath.Join(tmpDir, "full.yaml")
|
||||||
|
|
@ -650,6 +752,78 @@ func TestSQLiteEndToEnd_Good(t *testing.T) {
|
||||||
assert.Equal(t, 5, custom.MaxRPM)
|
assert.Equal(t, 5, custom.MaxRPM)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "replace.db")
|
||||||
|
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||||
|
Quotas: map[string]ModelQuota{
|
||||||
|
"model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl.RecordUsage("model-a", 10, 10)
|
||||||
|
require.NoError(t, rl.Persist())
|
||||||
|
|
||||||
|
delete(rl.Quotas, "model-a")
|
||||||
|
rl.Quotas["model-b"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
|
||||||
|
rl.Reset("")
|
||||||
|
rl.RecordUsage("model-b", 5, 5)
|
||||||
|
require.NoError(t, rl.Persist())
|
||||||
|
require.NoError(t, rl.Close())
|
||||||
|
|
||||||
|
rl2, err := NewWithSQLiteConfig(dbPath, Config{
|
||||||
|
Providers: []Provider{ProviderGemini},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rl2.Close()
|
||||||
|
|
||||||
|
rl2.State["stale-memory"] = &UsageStats{DayStart: time.Now(), DayCount: 99}
|
||||||
|
require.NoError(t, rl2.Load())
|
||||||
|
|
||||||
|
assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview")
|
||||||
|
assert.NotContains(t, rl2.Quotas, "model-a")
|
||||||
|
assert.Contains(t, rl2.Quotas, "model-b")
|
||||||
|
assert.NotContains(t, rl2.State, "stale-memory")
|
||||||
|
assert.NotContains(t, rl2.State, "model-a")
|
||||||
|
assert.Equal(t, 1, rl2.Stats("model-b").RPD)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSQLitePersistAtomic_Good(t *testing.T) {
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "persist-atomic.db")
|
||||||
|
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||||
|
Quotas: map[string]ModelQuota{
|
||||||
|
"old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl.RecordUsage("old-model", 10, 10)
|
||||||
|
require.NoError(t, rl.Persist())
|
||||||
|
|
||||||
|
_, err = rl.sqlite.db.Exec(`CREATE TRIGGER fail_daily_persist BEFORE INSERT ON daily
|
||||||
|
BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
delete(rl.Quotas, "old-model")
|
||||||
|
rl.Quotas["new-model"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
|
||||||
|
rl.Reset("")
|
||||||
|
rl.RecordUsage("new-model", 50, 50)
|
||||||
|
|
||||||
|
err = rl.Persist()
|
||||||
|
require.Error(t, err)
|
||||||
|
require.NoError(t, rl.Close())
|
||||||
|
|
||||||
|
rl2, err := NewWithSQLiteConfig(dbPath, Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rl2.Close()
|
||||||
|
require.NoError(t, rl2.Load())
|
||||||
|
|
||||||
|
assert.Contains(t, rl2.Quotas, "old-model")
|
||||||
|
assert.NotContains(t, rl2.Quotas, "new-model")
|
||||||
|
assert.Equal(t, 1, rl2.Stats("old-model").RPD)
|
||||||
|
assert.Equal(t, 0, rl2.Stats("new-model").RPD)
|
||||||
|
}
|
||||||
|
|
||||||
// --- Phase 2: Benchmark ---
|
// --- Phase 2: Benchmark ---
|
||||||
|
|
||||||
func BenchmarkSQLitePersist(b *testing.B) {
|
func BenchmarkSQLitePersist(b *testing.B) {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue