dev #12

Merged
Snider merged 8 commits from dev into main 2026-03-23 20:40:57 +00:00
8 changed files with 869 additions and 133 deletions

35
docs/api-contract.md Normal file
View file

@ -0,0 +1,35 @@
# API Contract
Test coverage is marked `yes` when the symbol is exercised by the existing test suite in `ratelimit_test.go`, `sqlite_test.go`, `error_test.go`, or `iter_test.go`.
| Kind | Name | Signature | Description | Test coverage |
| --- | --- | --- | --- | --- |
| Type | `Provider` | `type Provider string` | Identifies an LLM provider used to select default quota profiles. | yes |
| Type | `ModelQuota` | `type ModelQuota struct { MaxRPM int; MaxTPM int; MaxRPD int }` | Declares per-model RPM, TPM, and RPD limits; `0` means unlimited. | yes |
| Type | `ProviderProfile` | `type ProviderProfile struct { Provider Provider; Models map[string]ModelQuota }` | Bundles a provider identifier with its default model quota map. | yes |
| Type | `Config` | `type Config struct { FilePath string; Backend string; Quotas map[string]ModelQuota; Providers []Provider }` | Configures limiter initialisation, persistence settings, explicit quotas, and provider defaults. | yes |
| Type | `TokenEntry` | `type TokenEntry struct { Time time.Time; Count int }` | Records a timestamped token-usage event. | yes |
| Type | `UsageStats` | `type UsageStats struct { Requests []time.Time; Tokens []TokenEntry; DayStart time.Time; DayCount int }` | Stores per-model sliding-window request and token history plus rolling daily usage state. | yes |
| Type | `RateLimiter` | `type RateLimiter struct { Quotas map[string]ModelQuota; State map[string]*UsageStats }` | Manages quotas, usage state, persistence, and concurrency across models. | yes |
| Type | `ModelStats` | `type ModelStats struct { RPM int; MaxRPM int; TPM int; MaxTPM int; RPD int; MaxRPD int; DayStart time.Time }` | Represents a snapshot of current usage and configured limits for a model. | yes |
| Function | `DefaultProfiles` | `func DefaultProfiles() map[Provider]ProviderProfile` | Returns the built-in quota profiles for the supported providers. | yes |
| Function | `New` | `func New() (*RateLimiter, error)` | Creates a new limiter with Gemini defaults for backward-compatible YAML-backed usage. | yes |
| Function | `NewWithConfig` | `func NewWithConfig(cfg Config) (*RateLimiter, error)` | Creates a YAML-backed limiter from explicit configuration, defaulting to Gemini when config is empty. | yes |
| Function | `NewWithSQLite` | `func NewWithSQLite(dbPath string) (*RateLimiter, error)` | Creates a SQLite-backed limiter with Gemini defaults and opens or creates the database. | yes |
| Function | `NewWithSQLiteConfig` | `func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error)` | Creates a SQLite-backed limiter from explicit configuration and always uses SQLite persistence. | yes |
| Function | `MigrateYAMLToSQLite` | `func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error` | Migrates quotas and usage state from a YAML state file into a SQLite database. | yes |
| Function | `CountTokens` | `func CountTokens(ctx context.Context, apiKey, model, text string) (int, error)` | Calls the Gemini `countTokens` API for a prompt and returns the reported token count. | yes |
| Method | `SetQuota` | `func (rl *RateLimiter) SetQuota(model string, quota ModelQuota)` | Sets or replaces a model quota at runtime. | yes |
| Method | `AddProvider` | `func (rl *RateLimiter) AddProvider(provider Provider)` | Loads a built-in provider profile and overwrites any matching model quotas. | yes |
| Method | `Load` | `func (rl *RateLimiter) Load() error` | Loads quotas and usage state from YAML or SQLite into memory. | yes |
| Method | `Persist` | `func (rl *RateLimiter) Persist() error` | Persists a snapshot of quotas and usage state to YAML or SQLite. | yes |
| Method | `BackgroundPrune` | `func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func()` | Starts periodic pruning of expired usage state and returns a stop function. | yes |
| Method | `CanSend` | `func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool` | Reports whether a request with the estimated token count fits within current limits. | yes |
| Method | `RecordUsage` | `func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)` | Records a successful request into the sliding-window and daily counters. | yes |
| Method | `WaitForCapacity` | `func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error` | Blocks until `CanSend` succeeds or the context is cancelled. | yes |
| Method | `Reset` | `func (rl *RateLimiter) Reset(model string)` | Clears usage state for one model or for all models when `model` is empty. | yes |
| Method | `Models` | `func (rl *RateLimiter) Models() iter.Seq[string]` | Returns a sorted iterator of all model names known from quotas or state. | yes |
| Method | `Iter` | `func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats]` | Returns a sorted iterator of model names paired with current stats snapshots. | yes |
| Method | `Stats` | `func (rl *RateLimiter) Stats(model string) ModelStats` | Returns current stats for a single model after pruning expired usage. | yes |
| Method | `AllStats` | `func (rl *RateLimiter) AllStats() map[string]ModelStats` | Returns stats snapshots for every tracked model. | yes |
| Method | `Close` | `func (rl *RateLimiter) Close() error` | Closes SQLite resources for SQLite-backed limiters and is a no-op for YAML-backed limiters. | yes |

View file

@ -0,0 +1,75 @@
<!-- SPDX-License-Identifier: EUPL-1.2 -->
# Convention Drift Check
Date: 2026-03-23
`CODEX.md` was not present anywhere under `/workspace`, so this check used
`CLAUDE.md`, `docs/development.md`, and the current tree.
## stdlib -> core.*
- `CLAUDE.md:65` still documents `forge.lthn.ai/core/go-io`, while the direct
dependency in `go.mod:6` is `dappco.re/go/core/io`.
- `CLAUDE.md:66` still documents `forge.lthn.ai/core/go-log`, while the direct
dependency in `go.mod:7` is `dappco.re/go/core/log`.
- `docs/development.md:175` still mandates `fmt.Errorf("ratelimit.Function: what: %w", err)`,
but the implementation and repo guidance use `coreerr.E(...)`
(`ratelimit.go:19`, `sqlite.go:8`, `CLAUDE.md:31`).
- `docs/development.md:195` omits the `core.*` direct dependencies entirely;
`go.mod:6-7` now declares both `dappco.re/go/core/io` and
`dappco.re/go/core/log`.
- `docs/index.md:102` likewise omits the `core.*` direct dependencies that are
present in `go.mod:6-7`.
## UK English
- `README.md:2` uses `License` in the badge alt text.
- `CONTRIBUTING.md:34` uses `License` as the section heading.
## Missing Tests
- `docs/development.md:150` says current coverage is `95.1%` and the floor is
`95%`, but `go test -coverprofile=/tmp/convention-cover.out ./...` currently
reports `94.4%`.
- `docs/history.md:28` still records coverage as `95.1%`; the current tree is
below that figure at `94.4%`.
- `docs/history.md:66` and `docs/development.md:157` say the
`os.UserHomeDir()` branch in `NewWithConfig()` is untestable, but
`error_test.go:648` now exercises that path.
- `ratelimit.go:202` (`Load`) is only `90.0%` covered.
- `ratelimit.go:242` (`Persist`) is only `90.0%` covered.
- `ratelimit.go:650` (`CountTokens`) is only `71.4%` covered.
- `sqlite.go:20` (`newSQLiteStore`) is only `64.3%` covered.
- `sqlite.go:110` (`loadQuotas`) is only `92.9%` covered.
- `sqlite.go:194` (`loadState`) is only `88.6%` covered.
- `ratelimit_test.go:746` starts a mock server for `CountTokens`, but the test
contains no assertion that exercises the success path.
- `iter_test.go:108` starts a second mock server for `CountTokens`, but again
does not exercise the mocked path.
- `error_test.go:42` defines `TestSQLiteInitErrors`, but the `WAL pragma failure`
subtest is still an empty placeholder.
## SPDX Headers
- `CLAUDE.md:1` is missing an SPDX header.
- `CONTRIBUTING.md:1` is missing an SPDX header.
- `README.md:1` is missing an SPDX header.
- `docs/architecture.md:1` is missing an SPDX header.
- `docs/development.md:1` is missing an SPDX header.
- `docs/history.md:1` is missing an SPDX header.
- `docs/index.md:1` is missing an SPDX header.
- `error_test.go:1` is missing an SPDX header.
- `go.mod:1` is missing an SPDX header.
- `iter_test.go:1` is missing an SPDX header.
- `ratelimit.go:1` is missing an SPDX header.
- `ratelimit_test.go:1` is missing an SPDX header.
- `sqlite.go:1` is missing an SPDX header.
- `sqlite_test.go:1` is missing an SPDX header.
## Verification
- `go test ./...`
- `go test -coverprofile=/tmp/convention-cover.out ./...`
- `go test -race ./...`
- `go vet ./...`

View file

@ -0,0 +1,44 @@
# Security Attack Vector Mapping
Scope: external inputs that cross into this package from callers, persisted storage, or the network. This is a mapping only; it does not propose or apply fixes.
Note: `CODEX.md` was not present anywhere under `/workspace` during this scan, so conventions were taken from `CLAUDE.md` and the existing repository layout.
## Caller-Controlled API Inputs
| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector |
| --- | --- | --- | --- | --- | --- |
| `NewWithConfig(cfg Config)` | `ratelimit.go:145` | Caller-controlled `cfg.FilePath`, `cfg.Providers`, `cfg.Quotas` | `cfg.FilePath` is stored in `rl.filePath` and later used by `Load()` / `Persist()`; `cfg.Providers` selects built-in profiles; `cfg.Quotas` is copied straight into `rl.Quotas` | Empty `FilePath` defaults to `~/.core/ratelimits.yaml`; unknown providers are ignored; `cfg.Backend` is not read here; quota values and model keys are copied verbatim | Untrusted `FilePath` can later steer YAML reads and writes to arbitrary local paths because the package uses unsandboxed `coreio.Local`; negative quota values act like "unlimited" because enforcement only checks `> 0`; very large maps or model strings can drive memory and disk growth |
| `(*RateLimiter).SetQuota(model string, quota ModelQuota)` | `ratelimit.go:183` | Caller-controlled `model` and `quota` | Direct write to `rl.Quotas[model]`; later consumed by `CanSend()`, `Stats()`, `Persist()`, and SQLite/YAML persistence | None beyond Go type checking | Negative quota fields disable enforcement for that dimension; extreme values can skew blocking behaviour; arbitrary model names expand the in-memory and persisted keyspace |
| `(*RateLimiter).AddProvider(provider Provider)` | `ratelimit.go:191` | Caller-controlled `provider` selector | Looks up `DefaultProfiles()[provider]` and overwrites matching entries in `rl.Quotas` via `maps.Copy` | Unknown providers are silently ignored; no authorisation or policy guard in this package | If a higher-level service exposes this call, an attacker can overwrite stricter runtime quotas for a provider's models with the shipped defaults and relax the intended rate-limit policy |
| `(*RateLimiter).BackgroundPrune(interval time.Duration)` | `ratelimit.go:328` | Caller-controlled `interval` | Passed to `time.NewTicker(interval)` and drives a background goroutine that repeatedly locks and prunes state | None | `interval <= 0` causes a panic; very small intervals can create CPU and lock-contention DoS; repeated calls without using the returned cancel function leak goroutines |
| `(*RateLimiter).CanSend(model string, estimatedTokens int)` | `ratelimit.go:350` | Caller-controlled `model` and `estimatedTokens` | `model` indexes `rl.Quotas` / `rl.State`; `estimatedTokens` is added to the current token total before the TPM comparison | Unknown models are allowed immediately; no non-negative or range checks on `estimatedTokens` | Passing an unconfigured model name bypasses throttling entirely; negative or overflowed token values can undercount the TPM check and permit oversend |
| `(*RateLimiter).RecordUsage(model string, promptTokens, outputTokens int)` | `ratelimit.go:396` | Caller-controlled `model`, `promptTokens`, `outputTokens` | Creates or updates `rl.State[model]`; stores `promptTokens + outputTokens` in the token window and increments `DayCount` | None | Arbitrary model names create unbounded state that will later persist to YAML/SQLite; negative or overflowed token totals poison accounting and can reduce future TPM totals below the real usage |
| `(*RateLimiter).WaitForCapacity(ctx context.Context, model string, tokens int)` | `ratelimit.go:414` | Caller-controlled `ctx`, `model`, `tokens` | Calls `CanSend(model, tokens)` once per second until capacity is available or `ctx.Done()` fires | No direct validation; relies on downstream `CanSend()` and caller-supplied context cancellation | Inherits the unknown-model and negative-token bypasses from `CanSend()`; repeated calls with long-lived contexts can accumulate goroutines and lock pressure |
| `(*RateLimiter).Reset(model string)` | `ratelimit.go:433` | Caller-controlled `model` | `model == ""` replaces the entire `rl.State` map; otherwise `delete(rl.State, model)` | Empty string is treated as a wildcard reset | If reachable by an untrusted actor, an empty string clears all rate-limit history and targeted resets erase throttling state for chosen models |
| `(*RateLimiter).Stats(model string)` | `ratelimit.go:484` | Caller-controlled `model` | Prunes `rl.State[model]`, reads `rl.Quotas[model]`, and returns a usage snapshot | None | If exposed through a service boundary, it discloses per-model quota ceilings and live usage counts that can help an attacker tune evasion or timing |
| `NewWithSQLite(dbPath string)` | `ratelimit.go:567` | Caller-controlled `dbPath` | Thin wrapper that forwards `dbPath` into `NewWithSQLiteConfig()` and then `newSQLiteStore()` | No additional validation in the wrapper | Untrusted `dbPath` can steer database creation/opening to unintended local filesystem locations, including companion `-wal` and `-shm` files |
| `NewWithSQLiteConfig(dbPath string, cfg Config)` | `ratelimit.go:576` | Caller-controlled `dbPath`, `cfg.Providers`, `cfg.Quotas` | `dbPath` goes straight to `newSQLiteStore()`; provider and quota inputs are copied into `rl.Quotas` exactly as in `NewWithConfig()` | `cfg.Backend` is ignored; unknown providers are ignored; no path, range, or size checks | Combines arbitrary database-path selection with the same quota poisoning risks as `NewWithConfig()` and `SetQuota()` |
## Filesystem And YAML Inputs
| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector |
| --- | --- | --- | --- | --- | --- |
| `(*RateLimiter).Load()` (YAML backend) | `ratelimit.go:210` | File bytes read from `rl.filePath` | `coreio.Local.Read(rl.filePath)` feeds `yaml.Unmarshal(..., rl)`, which overwrites `rl.Quotas` and `rl.State` | Missing file is ignored; YAML syntax and type mismatches return an error; no semantic checks on counts, limits, model names, or slice sizes | A malicious YAML file can inject negative quotas or counters to bypass enforcement, preload very large maps/slices for memory DoS, or replace the in-memory policy/state with attacker-chosen values |
| `MigrateYAMLToSQLite(yamlPath, sqlitePath string)` | `ratelimit.go:617` | Caller-controlled `yamlPath` and `sqlitePath`, plus YAML file bytes from `yamlPath` | Reads YAML with `coreio.Local.Read(yamlPath)`, unmarshals into a temporary `RateLimiter`, then opens `sqlitePath` and writes the imported quotas/state into SQLite | Read/open errors and YAML syntax/type errors are surfaced; no path restrictions and no semantic validation of imported quotas/state | Untrusted paths enable arbitrary local file reads and database creation/clobbering; attacker-controlled YAML can permanently seed the SQLite backend with quota-bypass values or oversized state |
## SQLite Inputs
| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector |
| --- | --- | --- | --- | --- | --- |
| `newSQLiteStore(dbPath string)` | `sqlite.go:20` | Caller-controlled database path passed from `NewWithSQLite*()` or `MigrateYAMLToSQLite()` | `sql.Open("sqlite", dbPath)` opens or creates the database, then applies PRAGMAs and creates tables and indexes | Only driver/open/PRAGMA/schema errors are checked; there is no allowlist, path normalisation, or sandboxing in this package | An attacker-chosen `dbPath` can redirect state into unintended files or directories and create matching SQLite sidecar files, which is useful for tampering, data placement, or storage-exhaustion attacks |
| `(*RateLimiter).Load()` (SQLite backend via `loadSQLite()`) | `ratelimit.go:223` | SQLite content reachable through the already-open `rl.sqlite` handle | `loadQuotas()` and `loadState()` results are copied into `rl.Quotas` and `rl.State`; loaded quotas override in-memory defaults | Only lower-level scan/query errors are checked; no semantic validation after load | A tampered SQLite database can override intended quotas/state, including negative limits and poisoned counters, before any enforcement call runs |
| `(*sqliteStore).loadQuotas()` | `sqlite.go:110` | Rows from the `quotas` table (`model`, `max_rpm`, `max_tpm`, `max_rpd`) | Scanned into `map[string]ModelQuota`, then copied into `rl.Quotas` by `loadSQLite()` | SQL scan and row iteration errors only; no range or length checks | Negative or extreme values disable or destabilise later quota enforcement; a large number of rows or large model strings can cause memory growth |
| `(*sqliteStore).loadState()` | `sqlite.go:194` | Rows from the `daily`, `requests`, and `tokens` tables | Scanned into `UsageStats` maps/slices and then copied into `rl.State` by `loadSQLite()` | SQL scan and row iteration errors only; no bounds, ordering, or semantic checks on timestamps and counts | Crafted counts or timestamps can poison later `CanSend()` / `Stats()` results; oversized tables can drive memory and CPU exhaustion during load |
## Network Inputs
| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector |
| --- | --- | --- | --- | --- | --- |
| `CountTokens(ctx, apiKey, model, text)` (request construction) | `ratelimit.go:650` | Caller-controlled `ctx`, `apiKey`, `model`, `text` | `model` is interpolated directly into the Google API URL path; `text` is marshalled into JSON request body; `apiKey` goes into `x-goog-api-key`; `ctx` governs request lifetime | JSON marshalling must succeed; `http.NewRequestWithContext()` rejects some malformed URLs; there is no path escaping, length check, or output-size cap on the request body | Untrusted prompt text is exfiltrated to a remote API; very large text can consume memory/bandwidth; unescaped model strings can alter the path or query on the fixed Google host; repeated calls burn external quota |
| `CountTokens(ctx, apiKey, model, text)` (response handling) | `ratelimit.go:681` | Remote HTTP status, response body, and JSON `totalTokens` value from `generativelanguage.googleapis.com` | Non-200 bodies are read fully with `io.ReadAll()` and embedded into the returned error; 200 responses are decoded into `result.TotalTokens` and returned to the caller | Checks for HTTP 200 and JSON decode errors only; no response-body size limit and no sanity check on `TotalTokens` | A very large error body can cause memory pressure and log/telemetry pollution; a negative or extreme `totalTokens` value would be returned unchanged and could poison downstream rate-limit accounting if the caller trusts it |

View file

@ -31,13 +31,19 @@ func TestIterators(t *testing.T) {
assert.Contains(t, models, "model-a")
assert.Contains(t, models, "model-b")
assert.Contains(t, models, "model-c")
// Check sorting of our specific models
foundA, foundB, foundC := -1, -1, -1
for i, m := range models {
if m == "model-a" { foundA = i }
if m == "model-b" { foundB = i }
if m == "model-c" { foundC = i }
if m == "model-a" {
foundA = i
}
if m == "model-b" {
foundB = i
}
if m == "model-c" {
foundC = i
}
}
assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c")
})
@ -57,9 +63,15 @@ func TestIterators(t *testing.T) {
// Check sorting
foundA, foundB, foundC := -1, -1, -1
for i, m := range models {
if m == "model-a" { foundA = i }
if m == "model-b" { foundB = i }
if m == "model-c" { foundC = i }
if m == "model-a" {
foundA = i
}
if m == "model-b" {
foundB = i
}
if m == "model-c" {
foundC = i
}
}
assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c")
})
@ -99,9 +111,8 @@ func TestIterEarlyBreak(t *testing.T) {
}
func TestCountTokensFull(t *testing.T) {
t.Run("invalid URL/network error", func(t *testing.T) {
// Using an invalid character in model name to trigger URL error or similar
_, err := CountTokens(context.Background(), "key", "invalid model", "text")
t.Run("empty model is rejected", func(t *testing.T) {
_, err := CountTokens(context.Background(), "key", "", "text")
assert.Error(t, err)
})
@ -112,15 +123,15 @@ func TestCountTokensFull(t *testing.T) {
}))
defer server.Close()
// CountTokens itself always targets the production Google endpoint, so the
// HTTP error path is exercised via countTokensWithClient against a local
// test server instead.
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "key", "model", "text")
require.Error(t, err)
assert.Contains(t, err.Error(), "status 400")
})
t.Run("context cancelled", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := CountTokens(ctx, "key", "model", "text")
_, err := countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", "key", "model", "text")
assert.Error(t, err)
assert.Contains(t, err.Error(), "do request")
})

View file

@ -9,9 +9,11 @@ import (
"iter"
"maps"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"time"
@ -34,6 +36,16 @@ const (
ProviderLocal Provider = "local"
)
// File-system defaults, backend selector values, and caps on how much of a
// CountTokens HTTP response body is read into memory.
const (
	defaultStateDirName    = ".core"           // state directory created under the user home
	defaultYAMLStateFile   = "ratelimits.yaml" // default file name for the YAML backend
	defaultSQLiteStateFile = "ratelimits.db"   // default file name for the SQLite backend

	backendYAML   = "yaml"   // Config.Backend value selecting YAML persistence
	backendSQLite = "sqlite" // Config.Backend value selecting SQLite persistence

	countTokensErrorBodyLimit   = 8 * 1024        // max bytes of a non-200 body embedded into the returned error
	countTokensSuccessBodyLimit = 1 * 1024 * 1024 // max bytes of a 200 body fed to the JSON decoder
)
// ModelQuota defines the rate limits for a specific model.
type ModelQuota struct {
MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
@ -143,39 +155,30 @@ func New() (*RateLimiter, error) {
// NewWithConfig creates a RateLimiter from explicit configuration.
// If no providers or quotas are specified, Gemini defaults are used.
func NewWithConfig(cfg Config) (*RateLimiter, error) {
backend, err := normaliseBackend(cfg.Backend)
if err != nil {
return nil, err
}
filePath := cfg.FilePath
if filePath == "" {
home, err := os.UserHomeDir()
filePath, err = defaultStatePath(backend)
if err != nil {
return nil, err
}
filePath = filepath.Join(home, ".core", "ratelimits.yaml")
}
rl := &RateLimiter{
Quotas: make(map[string]ModelQuota),
State: make(map[string]*UsageStats),
filePath: filePath,
}
// Load provider profiles
profiles := DefaultProfiles()
providers := cfg.Providers
// If nothing specified at all, default to Gemini
if len(providers) == 0 && len(cfg.Quotas) == 0 {
providers = []Provider{ProviderGemini}
}
for _, p := range providers {
if profile, ok := profiles[p]; ok {
maps.Copy(rl.Quotas, profile.Models)
if backend == backendSQLite {
if cfg.FilePath == "" {
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err)
}
}
return NewWithSQLiteConfig(filePath, cfg)
}
// Merge explicit quotas on top (allows overrides)
maps.Copy(rl.Quotas, cfg.Quotas)
rl := newConfiguredRateLimiter(cfg)
rl.filePath = filePath
return rl, nil
}
@ -225,15 +228,17 @@ func (rl *RateLimiter) loadSQLite() error {
if err != nil {
return err
}
// Merge loaded quotas (loaded quotas override in-memory defaults).
maps.Copy(rl.Quotas, quotas)
// Persisted quotas are authoritative when present; otherwise keep config defaults.
if len(quotas) > 0 {
rl.Quotas = maps.Clone(quotas)
}
state, err := rl.sqlite.loadState()
if err != nil {
return err
}
// Replace in-memory state with persisted state.
maps.Copy(rl.State, state)
// Replace in-memory state with the persisted snapshot.
rl.State = state
return nil
}
@ -244,6 +249,9 @@ func (rl *RateLimiter) Persist() error {
quotas := maps.Clone(rl.Quotas)
state := make(map[string]*UsageStats, len(rl.State))
for k, v := range rl.State {
if v == nil {
continue
}
state[k] = &UsageStats{
Requests: slices.Clone(v.Requests),
Tokens: slices.Clone(v.Tokens),
@ -256,11 +264,8 @@ func (rl *RateLimiter) Persist() error {
rl.mu.Unlock()
if sqlite != nil {
if err := sqlite.saveQuotas(quotas); err != nil {
return coreerr.E("ratelimit.Persist", "sqlite quotas", err)
}
if err := sqlite.saveState(state); err != nil {
return coreerr.E("ratelimit.Persist", "sqlite state", err)
if err := sqlite.saveSnapshot(quotas, state); err != nil {
return coreerr.E("ratelimit.Persist", "sqlite snapshot", err)
}
return nil
}
@ -292,6 +297,10 @@ func (rl *RateLimiter) prune(model string) {
if !ok {
return
}
if stats == nil {
delete(rl.State, model)
return
}
now := time.Now()
window := now.Add(-1 * time.Minute)
@ -326,6 +335,10 @@ func (rl *RateLimiter) prune(model string) {
// BackgroundPrune starts a goroutine that periodically prunes all model states.
// It returns a function to stop the pruner.
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
if interval <= 0 {
return func() {}
}
ctx, cancel := context.WithCancel(context.Background())
go func() {
ticker := time.NewTicker(interval)
@ -348,6 +361,10 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
// CanSend checks if a request can be sent without violating limits.
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
if estimatedTokens < 0 {
return false
}
rl.mu.Lock()
defer rl.mu.Unlock()
@ -380,11 +397,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
// Check TPM
if quota.MaxTPM > 0 {
currentTokens := 0
for _, t := range stats.Tokens {
currentTokens += t.Count
}
if currentTokens+estimatedTokens > quota.MaxTPM {
currentTokens := totalTokenCount(stats.Tokens)
if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens {
return false
}
}
@ -405,13 +419,18 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
}
now := time.Now()
totalTokens := safeTokenSum(promptTokens, outputTokens)
stats.Requests = append(stats.Requests, now)
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens})
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: totalTokens})
stats.DayCount++
}
// WaitForCapacity blocks until capacity is available or context is cancelled.
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
if tokens < 0 {
return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil)
}
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
@ -522,24 +541,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
result[m] = ModelStats{}
}
now := time.Now()
window := now.Add(-1 * time.Minute)
for m := range result {
// Prune inline
if s, ok := rl.State[m]; ok {
s.Requests = slices.DeleteFunc(s.Requests, func(t time.Time) bool {
return !t.After(window)
})
s.Tokens = slices.DeleteFunc(s.Tokens, func(t TokenEntry) bool {
return !t.Time.After(window)
})
if now.Sub(s.DayStart) >= 24*time.Hour {
s.DayStart = now
s.DayCount = 0
}
}
rl.prune(m)
ms := ModelStats{}
if q, ok := rl.Quotas[m]; ok {
@ -547,13 +550,11 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
ms.MaxTPM = q.MaxTPM
ms.MaxRPD = q.MaxRPD
}
if s, ok := rl.State[m]; ok {
if s, ok := rl.State[m]; ok && s != nil {
ms.RPM = len(s.Requests)
ms.RPD = s.DayCount
ms.DayStart = s.DayStart
for _, t := range s.Tokens {
ms.TPM += t.Count
}
ms.TPM = totalTokenCount(s.Tokens)
}
result[m] = ms
}
@ -579,25 +580,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
return nil, err
}
rl := &RateLimiter{
Quotas: make(map[string]ModelQuota),
State: make(map[string]*UsageStats),
sqlite: store,
}
// Load provider profiles.
profiles := DefaultProfiles()
providers := cfg.Providers
if len(providers) == 0 && len(cfg.Quotas) == 0 {
providers = []Provider{ProviderGemini}
}
for _, p := range providers {
if profile, ok := profiles[p]; ok {
maps.Copy(rl.Quotas, profile.Models)
}
}
maps.Copy(rl.Quotas, cfg.Quotas)
rl := newConfiguredRateLimiter(cfg)
rl.sqlite = store
return rl, nil
}
@ -633,22 +617,22 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
}
defer store.close()
if rl.Quotas != nil {
if err := store.saveQuotas(rl.Quotas); err != nil {
return err
}
}
if rl.State != nil {
if err := store.saveState(rl.State); err != nil {
return err
}
if err := store.saveSnapshot(rl.Quotas, rl.State); err != nil {
return err
}
return nil
}
// CountTokens calls the Google API to count tokens for a prompt.
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model)
return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text)
}
func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) {
requestURL, err := countTokensURL(baseURL, model)
if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "build url", err)
}
reqBody := map[string]any{
"contents": []any{
@ -665,30 +649,147 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody))
if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "new request", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-goog-api-key", apiKey)
resp, err := http.DefaultClient.Do(req)
if client == nil {
client = http.DefaultClient
}
resp, err := client.Do(req)
if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "do request", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("API error (status %d): %s", resp.StatusCode, string(body)), nil)
body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit)
if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "read error body", err)
}
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil)
}
var result struct {
TotalTokens int `json:"totalTokens"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "decode response", err)
}
return result.TotalTokens, nil
}
// newConfiguredRateLimiter allocates a RateLimiter with empty quota and
// state maps, then merges the supplied configuration into it.
func newConfiguredRateLimiter(cfg Config) *RateLimiter {
	limiter := &RateLimiter{
		Quotas: map[string]ModelQuota{},
		State:  map[string]*UsageStats{},
	}
	applyConfig(limiter, cfg)
	return limiter
}
// applyConfig merges provider default profiles and explicit quota
// overrides from cfg into rl.Quotas. When cfg names neither providers
// nor quotas, the Gemini defaults are applied as a fallback.
func applyConfig(rl *RateLimiter, cfg Config) {
	selected := cfg.Providers
	if len(selected) == 0 && len(cfg.Quotas) == 0 {
		selected = []Provider{ProviderGemini}
	}
	defaults := DefaultProfiles()
	for _, provider := range selected {
		profile, known := defaults[provider]
		if !known {
			// Unknown providers are silently ignored.
			continue
		}
		maps.Copy(rl.Quotas, profile.Models)
	}
	// Explicit quotas win over provider defaults.
	maps.Copy(rl.Quotas, cfg.Quotas)
}
// normaliseBackend canonicalises a backend selector string. An empty or
// "yaml" value selects the YAML backend, "sqlite" selects SQLite, and
// anything else is rejected with an error naming the bad value.
func normaliseBackend(backend string) (string, error) {
	cleaned := strings.ToLower(strings.TrimSpace(backend))
	if cleaned == "" || cleaned == backendYAML {
		return backendYAML, nil
	}
	if cleaned == backendSQLite {
		return backendSQLite, nil
	}
	return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil)
}
func defaultStatePath(backend string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
fileName := defaultYAMLStateFile
if backend == backendSQLite {
fileName = defaultSQLiteStateFile
}
return filepath.Join(home, defaultStateDirName, fileName), nil
}
// safeTokenSum adds two token counts with the same defensive semantics
// as safeTokenTotal: negative inputs are ignored and the result
// saturates at the maximum int instead of overflowing.
//
// Computed directly rather than by building a []TokenEntry, which
// avoided nothing and allocated a slice plus two structs per call.
func safeTokenSum(a, b int) int {
	const maxInt = int(^uint(0) >> 1)
	total := 0
	for _, count := range [2]int{a, b} {
		if count < 0 {
			// Negative counts are treated as "no usage".
			continue
		}
		if total > maxInt-count {
			return maxInt // saturate rather than wrap around
		}
		total += count
	}
	return total
}
// totalTokenCount sums the counts of all entries, skipping negative
// values and saturating at the maximum int rather than overflowing
// (the same contract as safeTokenTotal, inlined here).
func totalTokenCount(tokens []TokenEntry) int {
	const maxInt = int(^uint(0) >> 1)
	sum := 0
	for _, entry := range tokens {
		if entry.Count < 0 {
			continue
		}
		if sum > maxInt-entry.Count {
			return maxInt
		}
		sum += entry.Count
	}
	return sum
}
// safeTokenTotal sums token counts defensively: entries with negative
// counts are skipped entirely, and the running total saturates at the
// maximum int value instead of wrapping around on overflow.
func safeTokenTotal(tokens []TokenEntry) int {
	const maxInt = int(^uint(0) >> 1)
	sum := 0
	for i := range tokens {
		c := tokens[i].Count
		if c < 0 {
			continue
		}
		if sum > maxInt-c {
			return maxInt
		}
		sum += c
	}
	return sum
}
// countTokensURL builds the Gemini :countTokens endpoint URL for model.
// baseURL must be an absolute URL (scheme and host present) and model
// must be non-blank. The model name is path-escaped so separators and
// query characters cannot alter the request path.
func countTokensURL(baseURL, model string) (string, error) {
	if strings.TrimSpace(model) == "" {
		return "", fmt.Errorf("empty model")
	}
	parsed, err := url.Parse(baseURL)
	if err != nil {
		return "", err
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return "", fmt.Errorf("invalid base url")
	}
	var endpoint strings.Builder
	endpoint.WriteString(strings.TrimRight(parsed.String(), "/"))
	endpoint.WriteString("/v1beta/models/")
	endpoint.WriteString(url.PathEscape(model))
	endpoint.WriteString(":countTokens")
	return endpoint.String(), nil
}
func readLimitedBody(r io.Reader, limit int64) (string, error) {
body, err := io.ReadAll(io.LimitReader(r, limit+1))
if err != nil {
return "", err
}
truncated := int64(len(body)) > limit
if truncated {
body = body[:limit]
}
result := string(body)
if truncated {
result += "..."
}
return result, nil
}

View file

@ -4,10 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
@ -25,6 +27,18 @@ func newTestLimiter(t *testing.T) *RateLimiter {
return rl
}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
type errReader struct{}
func (errReader) Read([]byte) (int, error) {
return 0, io.ErrUnexpectedEOF
}
// --- Phase 0: CanSend boundary conditions ---
func TestCanSend(t *testing.T) {
@ -162,6 +176,15 @@ func TestCanSend(t *testing.T) {
rl.RecordUsage(model, 50, 50) // exactly 100 tokens
assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit")
})
t.Run("negative estimated tokens are rejected", func(t *testing.T) {
rl := newTestLimiter(t)
model := "negative-est"
rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 100}
rl.RecordUsage(model, 50, 50)
assert.False(t, rl.CanSend(model, -1), "negative estimated tokens should not bypass TPM limits")
})
}
// --- Phase 0: Sliding window / prune tests ---
@ -335,6 +358,19 @@ func TestRecordUsage(t *testing.T) {
assert.Len(t, stats.Tokens, 2)
assert.Equal(t, 6, stats.DayCount)
})
t.Run("negative token inputs are clamped to zero", func(t *testing.T) {
rl := newTestLimiter(t)
model := "record-negative"
rl.RecordUsage(model, -100, 25)
stats := rl.State[model]
require.NotNil(t, stats)
assert.Len(t, stats.Requests, 1)
assert.Len(t, stats.Tokens, 1)
assert.Equal(t, 25, stats.Tokens[0].Count)
})
}
// --- Phase 0: Reset ---
@ -422,6 +458,58 @@ func TestWaitForCapacity(t *testing.T) {
err := rl.WaitForCapacity(ctx, model, 10)
assert.Error(t, err, "should return error for already-cancelled context")
})
t.Run("negative tokens return error", func(t *testing.T) {
rl := newTestLimiter(t)
err := rl.WaitForCapacity(context.Background(), "wait-negative", -1)
assert.Error(t, err)
assert.Contains(t, err.Error(), "negative tokens")
})
}
// TestNilUsageStats verifies that a nil *UsageStats entry in State is
// tolerated by every read/write path instead of causing a panic.
func TestNilUsageStats(t *testing.T) {
	t.Run("CanSend replaces nil state without panicking", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "nil-cansend"
		rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 10}
		// Simulate a corrupted or zero-valued state entry.
		rl.State[model] = nil
		assert.NotPanics(t, func() {
			assert.True(t, rl.CanSend(model, 10))
		})
		// CanSend should have replaced the nil entry with a usable one.
		assert.NotNil(t, rl.State[model])
	})
	t.Run("RecordUsage replaces nil state without panicking", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "nil-record"
		rl.State[model] = nil
		assert.NotPanics(t, func() {
			rl.RecordUsage(model, 10, 10)
		})
		require.NotNil(t, rl.State[model])
		// The recorded request must still count toward the daily total.
		assert.Equal(t, 1, rl.State[model].DayCount)
	})
	t.Run("Stats and AllStats tolerate nil state entries", func(t *testing.T) {
		rl := newTestLimiter(t)
		rl.Quotas["nil-stats"] = ModelQuota{MaxRPM: 1, MaxTPM: 2, MaxRPD: 3}
		rl.State["nil-stats"] = nil
		rl.State["nil-all-stats"] = nil
		assert.NotPanics(t, func() {
			stats := rl.Stats("nil-stats")
			// Limits come from Quotas even when state is nil...
			assert.Equal(t, 1, stats.MaxRPM)
			// ...and usage reads back as zero.
			assert.Equal(t, 0, stats.TPM)
		})
		assert.NotPanics(t, func() {
			all := rl.AllStats()
			assert.Contains(t, all, "nil-stats")
			assert.Contains(t, all, "nil-all-stats")
		})
	})
}
// --- Phase 0: Stats ---
@ -738,6 +826,19 @@ func TestBackgroundPrune(t *testing.T) {
_, exists := rl.State[model]
return !exists
}, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned")
t.Run("non-positive interval is a safe no-op", func(t *testing.T) {
rl := newTestLimiter(t)
rl.State["still-here"] = &UsageStats{
Requests: []time.Time{time.Now().Add(-2 * time.Minute)},
}
assert.NotPanics(t, func() {
stop := rl.BackgroundPrune(0)
stop()
})
assert.Contains(t, rl.State, "still-here")
})
}
// --- Phase 0: CountTokens (with mock HTTP server) ---
@ -748,26 +849,149 @@ func TestCountTokens(t *testing.T) {
assert.Equal(t, http.MethodPost, r.Method)
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key"))
assert.Equal(t, "/v1beta/models/test-model:countTokens", r.URL.EscapedPath())
var body struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"contents"`
}
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
require.Len(t, body.Contents, 1)
require.Len(t, body.Contents[0].Parts, 1)
assert.Equal(t, "hello", body.Contents[0].Parts[0].Text)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42})
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}))
}))
defer server.Close()
// For testing purposes, we would need to make the base URL configurable.
// Since we're just checking the signature and basic logic, we test the error paths.
tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "test-model", "hello")
require.NoError(t, err)
assert.Equal(t, 42, tokens)
})
t.Run("API error returns error", func(t *testing.T) {
t.Run("model name is path escaped", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath())
assert.Empty(t, r.URL.RawQuery)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7}))
}))
defer server.Close()
tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "folder/model?debug=1", "hello")
require.NoError(t, err)
assert.Equal(t, 7, tokens)
})
t.Run("API error body is truncated", func(t *testing.T) {
largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
fmt.Fprint(w, `{"error": "invalid API key"}`)
_, err := fmt.Fprint(w, largeBody)
require.NoError(t, err)
}))
defer server.Close()
_, err := CountTokens(context.Background(), "fake-key", "test-model", "hello")
assert.Error(t, err, "should fail with invalid API endpoint")
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "api error status 401")
assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded")
assert.Contains(t, err.Error(), "...")
})
t.Run("empty model is rejected before request", func(t *testing.T) {
_, err := CountTokens(context.Background(), "fake-key", "", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "build url")
})
t.Run("invalid base URL returns error", func(t *testing.T) {
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "://bad-url", "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "build url")
})
t.Run("base URL without host returns error", func(t *testing.T) {
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "/relative", "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "build url")
})
t.Run("invalid JSON response returns error", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write([]byte(`{"totalTokens":`))
require.NoError(t, err)
}))
defer server.Close()
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "decode response")
})
t.Run("error body read failures are returned", func(t *testing.T) {
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusBadGateway,
Body: io.NopCloser(errReader{}),
Header: make(http.Header),
}, nil
}),
}
_, err := countTokensWithClient(context.Background(), client, "https://generativelanguage.googleapis.com", "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "read error body")
})
t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11}))
}))
defer server.Close()
originalClient := http.DefaultClient
http.DefaultClient = server.Client()
defer func() {
http.DefaultClient = originalClient
}()
tokens, err := countTokensWithClient(context.Background(), nil, server.URL, "fake-key", "test-model", "hello")
require.NoError(t, err)
assert.Equal(t, 11, tokens)
})
}
// TestPersistSkipsNilState checks that a nil state entry is dropped by
// Persist and is therefore absent after a round-trip through Load.
func TestPersistSkipsNilState(t *testing.T) {
	path := filepath.Join(t.TempDir(), "nil-state.yaml")
	rl, err := New()
	require.NoError(t, err)
	rl.filePath = path
	// Inject a nil entry that serialisation must skip, not crash on.
	rl.State["nil-model"] = nil
	require.NoError(t, rl.Persist())
	// Reload from disk into a fresh limiter.
	rl2, err := New()
	require.NoError(t, err)
	rl2.filePath = path
	require.NoError(t, rl2.Load())
	// The nil entry must not have been serialised.
	assert.NotContains(t, rl2.State, "nil-model")
}
// TestTokenTotals exercises the saturating token-sum helpers directly:
// negative inputs are ignored and the total clamps at the maximum int.
func TestTokenTotals(t *testing.T) {
	maxInt := int(^uint(0) >> 1)
	assert.Equal(t, 25, safeTokenSum(-100, 25))
	assert.Equal(t, maxInt, safeTokenSum(maxInt, 1))
	assert.Equal(t, 10, totalTokenCount([]TokenEntry{{Count: -5}, {Count: 10}}))
}
// --- Phase 0: Benchmarks ---
@ -983,6 +1207,25 @@ func TestNewWithConfig(t *testing.T) {
assert.Equal(t, 5, q.MaxRPM)
assert.Equal(t, 50000, q.MaxTPM)
})
t.Run("invalid backend returns error", func(t *testing.T) {
_, err := NewWithConfig(Config{
Backend: "bogus",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown backend")
})
t.Run("default YAML path uses home directory", func(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", "")
t.Setenv("home", "")
rl, err := NewWithConfig(Config{})
require.NoError(t, err)
assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
})
}
func TestNewBackwardCompatibility(t *testing.T) {

View file

@ -78,7 +78,7 @@ func createSchema(db *sql.DB) error {
return nil
}
// saveQuotas upserts all quotas into the quotas table.
// saveQuotas writes a complete quota snapshot to the quotas table.
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
tx, err := s.db.Begin()
if err != nil {
@ -86,24 +86,15 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
}
defer tx.Rollback()
stmt, err := tx.Prepare(`INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd)
VALUES (?, ?, ?, ?)
ON CONFLICT(model) DO UPDATE SET
max_rpm = excluded.max_rpm,
max_tpm = excluded.max_tpm,
max_rpd = excluded.max_rpd`)
if err != nil {
return coreerr.E("ratelimit.saveQuotas", "prepare", err)
}
defer stmt.Close()
for model, q := range quotas {
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
}
if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
return coreerr.E("ratelimit.saveQuotas", "clear", err)
}
return tx.Commit()
if err := insertQuotas(tx, quotas); err != nil {
return err
}
return commitTx(tx, "ratelimit.saveQuotas")
}
// loadQuotas reads all rows from the quotas table.
@ -129,6 +120,28 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
return result, nil
}
// saveSnapshot writes quotas and state as a single atomic snapshot.
// Both table groups are cleared and repopulated inside one transaction,
// so a failure at any step leaves the previously committed snapshot
// intact (the deferred Rollback undoes all partial work).
func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error {
	tx, err := s.db.Begin()
	if err != nil {
		return coreerr.E("ratelimit.saveSnapshot", "begin", err)
	}
	// Rollback is a no-op after a successful Commit.
	defer tx.Rollback()
	// Clear quotas as well as state tables (includeQuotas = true).
	if err := clearSnapshotTables(tx, true); err != nil {
		return err
	}
	if err := insertQuotas(tx, quotas); err != nil {
		return err
	}
	if err := insertState(tx, state); err != nil {
		return err
	}
	return commitTx(tx, "ratelimit.saveSnapshot")
}
// saveState writes all usage state to SQLite in a single transaction.
// It uses a truncate-and-insert approach for simplicity in this version,
// but ensures atomicity via a single transaction.
@ -139,7 +152,23 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
}
defer tx.Rollback()
// Clear existing state in the transaction.
if err := clearSnapshotTables(tx, false); err != nil {
return err
}
if err := insertState(tx, state); err != nil {
return err
}
return commitTx(tx, "ratelimit.saveState")
}
func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
if includeQuotas {
if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err)
}
}
if _, err := tx.Exec("DELETE FROM requests"); err != nil {
return coreerr.E("ratelimit.saveState", "clear requests", err)
}
@ -149,7 +178,25 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
if _, err := tx.Exec("DELETE FROM daily"); err != nil {
return coreerr.E("ratelimit.saveState", "clear daily", err)
}
return nil
}
// insertQuotas bulk-inserts every quota row through a single prepared
// statement within the caller's transaction. Callers are expected to
// have cleared the table first, so a plain INSERT (no upsert) suffices.
func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
	stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)")
	if err != nil {
		return coreerr.E("ratelimit.saveQuotas", "prepare", err)
	}
	defer stmt.Close()
	// Map iteration order is random; row order in the table is irrelevant.
	for model, q := range quotas {
		if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
			return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
		}
	}
	return nil
}
func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
if err != nil {
return coreerr.E("ratelimit.saveState", "prepare requests", err)
@ -169,6 +216,9 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
defer dayStmt.Close()
for model, stats := range state {
if stats == nil {
continue
}
for _, t := range stats.Requests {
if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err)
@ -183,9 +233,12 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err)
}
}
return nil
}
func commitTx(tx *sql.Tx, scope string) error {
if err := tx.Commit(); err != nil {
return coreerr.E("ratelimit.saveState", "commit", err)
return coreerr.E(scope, "commit", err)
}
return nil
}

View file

@ -435,6 +435,45 @@ func TestConfigBackendDefault_Good(t *testing.T) {
assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
}
// TestConfigBackendSQLite_Good verifies that selecting the SQLite
// backend wires up the sqlite store and persists usage to the
// configured database file path.
func TestConfigBackendSQLite_Good(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "config-backend.db")
	rl, err := NewWithConfig(Config{
		Backend:  backendSQLite,
		FilePath: dbPath,
		Quotas: map[string]ModelQuota{
			"backend-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
		},
	})
	require.NoError(t, err)
	defer rl.Close()
	// The sqlite store must be active for this backend.
	require.NotNil(t, rl.sqlite)
	rl.RecordUsage("backend-model", 10, 10)
	require.NoError(t, rl.Persist())
	// Persist must have created the DB file at the configured path.
	_, statErr := os.Stat(dbPath)
	assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path")
}
// TestConfigBackendSQLiteDefaultPath_Good checks that the SQLite
// backend falls back to the default DB file under the user's home
// directory when no FilePath is configured.
func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
	home := t.TempDir()
	// Redirect the home-directory lookup across platforms
	// (HOME on Unix, USERPROFILE on Windows, home on Plan 9).
	t.Setenv("HOME", home)
	t.Setenv("USERPROFILE", "")
	t.Setenv("home", "")
	rl, err := NewWithConfig(Config{
		Backend: backendSQLite,
	})
	require.NoError(t, err)
	defer rl.Close()
	require.NotNil(t, rl.sqlite)
	require.NoError(t, rl.Persist())
	// The DB must land at <home>/<stateDir>/<sqlite file>.
	_, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile))
	assert.NoError(t, statErr, "sqlite backend should use the default home DB path")
}
// --- Phase 2: MigrateYAMLToSQLite ---
func TestMigrateYAMLToSQLite_Good(t *testing.T) {
@ -492,6 +531,69 @@ func TestMigrateYAMLToSQLite_Bad(t *testing.T) {
})
}
// TestMigrateYAMLToSQLiteAtomic_Good proves YAML-to-SQLite migration is
// all-or-nothing: a trigger forces the daily-table insert to fail, and
// the pre-existing SQLite snapshot must survive completely untouched.
func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
	tmpDir := t.TempDir()
	yamlPath := filepath.Join(tmpDir, "atomic.yaml")
	sqlitePath := filepath.Join(tmpDir, "atomic.db")
	now := time.Now().UTC()
	// Seed the SQLite store with an "old" snapshot that must survive.
	store, err := newSQLiteStore(sqlitePath)
	require.NoError(t, err)
	originalQuotas := map[string]ModelQuota{
		"old-model": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3},
	}
	originalState := map[string]*UsageStats{
		"old-model": {
			Requests: []time.Time{now},
			Tokens:   []TokenEntry{{Time: now, Count: 9}},
			DayStart: now,
			DayCount: 1,
		},
	}
	require.NoError(t, store.saveSnapshot(originalQuotas, originalState))
	// Force any future INSERT into daily to abort, mid-migration.
	_, err = store.db.Exec(`CREATE TRIGGER fail_daily_migrate BEFORE INSERT ON daily
BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
	require.NoError(t, err)
	require.NoError(t, store.close())
	// Write a YAML snapshot containing only the "new" model.
	migrated := &RateLimiter{
		Quotas: map[string]ModelQuota{
			"new-model": {MaxRPM: 10, MaxTPM: 20, MaxRPD: 30},
		},
		State: map[string]*UsageStats{
			"new-model": {
				Requests: []time.Time{now.Add(5 * time.Second)},
				Tokens:   []TokenEntry{{Time: now.Add(5 * time.Second), Count: 99}},
				DayStart: now.Add(5 * time.Second),
				DayCount: 2,
			},
		},
	}
	data, err := yaml.Marshal(migrated)
	require.NoError(t, err)
	require.NoError(t, os.WriteFile(yamlPath, data, 0o644))
	// Migration must fail because of the trigger...
	err = MigrateYAMLToSQLite(yamlPath, sqlitePath)
	require.Error(t, err)
	// ...and the original snapshot must be fully intact afterwards.
	store, err = newSQLiteStore(sqlitePath)
	require.NoError(t, err)
	defer store.close()
	quotas, err := store.loadQuotas()
	require.NoError(t, err)
	assert.Equal(t, originalQuotas, quotas)
	state, err := store.loadState()
	require.NoError(t, err)
	require.Contains(t, state, "old-model")
	assert.Equal(t, originalState["old-model"].DayCount, state["old-model"].DayCount)
	assert.Equal(t, originalState["old-model"].Tokens[0].Count, state["old-model"].Tokens[0].Count)
	assert.NotContains(t, state, "new-model")
}
func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "full.yaml")
@ -650,6 +752,78 @@ func TestSQLiteEndToEnd_Good(t *testing.T) {
assert.Equal(t, 5, custom.MaxRPM)
}
// TestSQLiteLoadReplacesPersistedSnapshot_Good verifies that Load
// replaces in-memory quotas and state wholesale with the persisted
// snapshot, discarding both stale memory entries and provider defaults.
func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "replace.db")
	rl, err := NewWithSQLiteConfig(dbPath, Config{
		Quotas: map[string]ModelQuota{
			"model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
		},
	})
	require.NoError(t, err)
	rl.RecordUsage("model-a", 10, 10)
	require.NoError(t, rl.Persist())
	// Second persisted snapshot drops model-a in favour of model-b.
	delete(rl.Quotas, "model-a")
	rl.Quotas["model-b"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
	rl.Reset("")
	rl.RecordUsage("model-b", 5, 5)
	require.NoError(t, rl.Persist())
	require.NoError(t, rl.Close())
	// New limiter starts with Gemini defaults and a stale memory entry.
	rl2, err := NewWithSQLiteConfig(dbPath, Config{
		Providers: []Provider{ProviderGemini},
	})
	require.NoError(t, err)
	defer rl2.Close()
	rl2.State["stale-memory"] = &UsageStats{DayStart: time.Now(), DayCount: 99}
	require.NoError(t, rl2.Load())
	// The loaded snapshot must win over defaults and stale memory alike.
	assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview")
	assert.NotContains(t, rl2.Quotas, "model-a")
	assert.Contains(t, rl2.Quotas, "model-b")
	assert.NotContains(t, rl2.State, "stale-memory")
	assert.NotContains(t, rl2.State, "model-a")
	assert.Equal(t, 1, rl2.Stats("model-b").RPD)
}
// TestSQLitePersistAtomic_Good proves Persist is transactional: when a
// trigger forces the daily-table insert to fail, the previously
// committed snapshot ("old-model") must remain fully readable and the
// failed snapshot ("new-model") must leave no trace.
func TestSQLitePersistAtomic_Good(t *testing.T) {
	dbPath := filepath.Join(t.TempDir(), "persist-atomic.db")
	rl, err := NewWithSQLiteConfig(dbPath, Config{
		Quotas: map[string]ModelQuota{
			"old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
		},
	})
	require.NoError(t, err)
	rl.RecordUsage("old-model", 10, 10)
	require.NoError(t, rl.Persist())
	// Force any future INSERT into daily to abort mid-Persist.
	_, err = rl.sqlite.db.Exec(`CREATE TRIGGER fail_daily_persist BEFORE INSERT ON daily
BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
	require.NoError(t, err)
	// Build a second snapshot that should fail to persist.
	delete(rl.Quotas, "old-model")
	rl.Quotas["new-model"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
	rl.Reset("")
	rl.RecordUsage("new-model", 50, 50)
	err = rl.Persist()
	require.Error(t, err)
	require.NoError(t, rl.Close())
	// Reload: only the first committed snapshot may be visible.
	rl2, err := NewWithSQLiteConfig(dbPath, Config{})
	require.NoError(t, err)
	defer rl2.Close()
	require.NoError(t, rl2.Load())
	assert.Contains(t, rl2.Quotas, "old-model")
	assert.NotContains(t, rl2.Quotas, "new-model")
	assert.Equal(t, 1, rl2.Stats("old-model").RPD)
	assert.Equal(t, 0, rl2.Stats("new-model").RPD)
}
// --- Phase 2: Benchmark ---
func BenchmarkSQLitePersist(b *testing.B) {