Compare commits
1 commit
| Author | SHA1 | Date |
|---|---|---|
|  | 1396bee176 |  |
18 changed files with 491 additions and 1094 deletions
CLAUDE.md (11 changed lines)
@@ -1,5 +1,3 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 # CLAUDE.md
 
 This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
@@ -8,7 +6,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
 
 Provider-agnostic sliding window rate limiter for LLM API calls. Single Go package (no sub-packages) with two persistence backends: YAML (single-process, default) and SQLite (multi-process, WAL mode). Enforces RPM, TPM, and RPD quotas per model. Ships default profiles for Gemini, OpenAI, Anthropic, and Local providers.
 
-Module: `dappco.re/go/core/go-ratelimit` — Go 1.26, no CGO required.
+Module: `forge.lthn.ai/core/go-ratelimit` — Go 1.26, no CGO required.
 
 ## Commands
@@ -30,7 +28,7 @@ Pre-commit gate: `go test -race ./...` and `go vet ./...` must both pass.
 - **Conventional commits**: `type(scope): description` — scopes: `ratelimit`, `sqlite`, `persist`, `config`
 - **Co-Author line** on every commit: `Co-Authored-By: Virgil <virgil@lethean.io>`
 - **Coverage** must not drop below 95%
-- **Error format**: `core.E("ratelimit.FunctionName", "what", err)` via `dappco.re/go/core` — lowercase, no trailing punctuation
+- **Error format**: `coreerr.E("ratelimit.FunctionName", "what", err)` via `go-log` — lowercase, no trailing punctuation
 - **No `init()` functions**, no global mutable state
 - **Mutex discipline**: lock at the top of public methods, never inside helpers. Helpers that need the lock document "Caller must hold the lock". `prune()` mutates state, so even "read-only" methods that call it take the write lock. Never call a public method from another public method while holding the lock.
@@ -62,9 +60,10 @@ SQLite tests use `_Good`/`_Bad`/`_Ugly` suffixes (happy path / expected errors /
 ## Dependencies
 
-Four direct dependencies — do not add more without justification:
+Five direct dependencies — do not add more without justification:
 
-- `dappco.re/go/core` — file I/O helpers, structured errors, JSON helpers, path/environment utilities
+- `forge.lthn.ai/core/go-io` — file I/O abstraction
+- `forge.lthn.ai/core/go-log` — structured error handling (`coreerr.E`)
 - `gopkg.in/yaml.v3` — YAML backend
 - `modernc.org/sqlite` — pure Go SQLite (no CGO)
 - `github.com/stretchr/testify` — test-only
@@ -1,21 +1,15 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 # Contributing
 
 Thank you for your interest in contributing!
 
 ## Requirements
 - **Go Version**: 1.26 or higher is required.
-- **Tools**: `golangci-lint` is recommended.
+- **Tools**: `golangci-lint` and `task` (Taskfile.dev) are recommended.
 
 ## Development Workflow
 1. **Testing**: Ensure all tests pass before submitting changes.
 ```bash
 go build ./...
 go test ./...
 go test -race ./...
 go test -cover ./...
 go mod tidy
 ```
 2. **Code Style**: All code must follow standard Go formatting.
 ```bash
@@ -28,22 +22,14 @@ Thank you for your interest in contributing!
 ```
 
 ## Commit Message Format
-We follow the [Conventional Commits](https://www.conventionalcommits.org/) specification using the repository format `type(scope): description`:
+We follow the [Conventional Commits](https://www.conventionalcommits.org/) specification:
 - `feat`: A new feature
 - `fix`: A bug fix
 - `docs`: Documentation changes
 - `refactor`: A code change that neither fixes a bug nor adds a feature
 - `chore`: Changes to the build process or auxiliary tools and libraries
 
-Common scopes: `ratelimit`, `sqlite`, `persist`, `config`
+Example: `feat: add new endpoint for health check`
-
-Example:
-
-```text
-fix(ratelimit): align module metadata with dappco.re
-
-Co-Authored-By: Virgil <virgil@lethean.io>
-```
 
-## Licence
+## License
 By contributing to this project, you agree that your contributions will be licensed under the **European Union Public Licence (EUPL-1.2)**.
README.md (44 changed lines)
@@ -1,50 +1,30 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
-[](https://pkg.go.dev/dappco.re/go/core/go-ratelimit)
+[](https://pkg.go.dev/forge.lthn.ai/core/go-ratelimit)
+[](LICENSE.md)
+[](go.mod)
 
 # go-ratelimit
 
 Provider-agnostic sliding window rate limiter for LLM API calls. Enforces requests per minute (RPM), tokens per minute (TPM), and requests per day (RPD) quotas per model using an in-memory sliding window. Ships with default quota profiles for Gemini, OpenAI, Anthropic, and a local inference provider. State persists across process restarts via YAML (single-process) or SQLite (multi-process, WAL mode). Includes a Gemini-specific token counting helper and a YAML-to-SQLite migration path.
 
-**Module**: `dappco.re/go/core/go-ratelimit`
+**Module**: `forge.lthn.ai/core/go-ratelimit`
 **Licence**: EUPL-1.2
-**Language**: Go 1.26
+**Language**: Go 1.25
 
 ## Quick Start
 
 ```go
-import "dappco.re/go/core/go-ratelimit"
+import "forge.lthn.ai/core/go-ratelimit"
 
 // YAML backend (default, single-process)
 rl, err := ratelimit.New()
 if err != nil {
 	log.Fatal(err)
 }
 
 // SQLite backend (multi-process)
-rl, err = ratelimit.NewWithSQLite("/tmp/ratelimits.db")
-if err != nil {
-	log.Fatal(err)
-}
+rl, err := ratelimit.NewWithSQLite("~/.core/ratelimits.db")
 defer rl.Close()
 
 if rl.CanSend("gemini-2.0-flash", 1500) {
 	rl.RecordUsage("gemini-2.0-flash", 1000, 500)
 }
 
 if err := rl.Persist(); err != nil {
 	log.Fatal(err)
 }
 ```
 
-For agent workflows, `Decide` returns a structured verdict with retry guidance:
-
-```go
-decision := rl.Decide("gemini-2.0-flash", 1500)
-if !decision.Allowed {
-	log.Printf("throttled (%s); retry after %s", decision.Code, decision.RetryAfter)
+ok, reason := rl.CanSend("gemini-2.0-flash", 1500)
+if ok {
+	rl.RecordUsage("gemini-2.0-flash", 1500)
 }
 ```
@@ -57,14 +37,12 @@ if !decision.Allowed {
 ## Build & Test
 
 ```bash
-go build ./...
-go test ./...
 go test -race ./...
 go vet ./...
-go test -cover ./...
 go mod tidy
+go build ./...
 ```
 
 ## Licence
 
-European Union Public Licence 1.2.
+European Union Public Licence 1.2 — see [LICENCE](LICENCE) for details.
@@ -1,5 +1,3 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 # API Contract
 
 Test coverage is marked `yes` when the symbol is exercised by the existing test suite in `ratelimit_test.go`, `sqlite_test.go`, `error_test.go`, or `iter_test.go`.
@@ -14,8 +12,6 @@ Test coverage is marked `yes` when the symbol is exercised by the existing test
 | Type | `UsageStats` | `type UsageStats struct { Requests []time.Time; Tokens []TokenEntry; DayStart time.Time; DayCount int }` | Stores per-model sliding-window request and token history plus rolling daily usage state. | yes |
 | Type | `RateLimiter` | `type RateLimiter struct { Quotas map[string]ModelQuota; State map[string]*UsageStats }` | Manages quotas, usage state, persistence, and concurrency across models. | yes |
 | Type | `ModelStats` | `type ModelStats struct { RPM int; MaxRPM int; TPM int; MaxTPM int; RPD int; MaxRPD int; DayStart time.Time }` | Represents a snapshot of current usage and configured limits for a model. | yes |
-| Type | `DecisionCode` | `type DecisionCode string` | Machine-readable allow/deny codes returned by `Decide` (e.g., `ok`, `rpm_exceeded`). | yes |
-| Type | `Decision` | `type Decision struct { Allowed bool; Code DecisionCode; Reason string; RetryAfter time.Duration; Stats ModelStats }` | Structured decision result with a code, human-readable reason, optional retry guidance, and a stats snapshot. | yes |
 | Function | `DefaultProfiles` | `func DefaultProfiles() map[Provider]ProviderProfile` | Returns the built-in quota profiles for the supported providers. | yes |
 | Function | `New` | `func New() (*RateLimiter, error)` | Creates a new limiter with Gemini defaults for backward-compatible YAML-backed usage. | yes |
 | Function | `NewWithConfig` | `func NewWithConfig(cfg Config) (*RateLimiter, error)` | Creates a YAML-backed limiter from explicit configuration, defaulting to Gemini when config is empty. | yes |
@@ -29,9 +25,8 @@ Test coverage is marked `yes` when the symbol is exercised by the existing test
 | Method | `Persist` | `func (rl *RateLimiter) Persist() error` | Persists a snapshot of quotas and usage state to YAML or SQLite. | yes |
 | Method | `BackgroundPrune` | `func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func()` | Starts periodic pruning of expired usage state and returns a stop function. | yes |
 | Method | `CanSend` | `func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool` | Reports whether a request with the estimated token count fits within current limits. | yes |
-| Method | `Decide` | `func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision` | Returns structured allow/deny information including code, reason, retry guidance, and stats snapshot without recording usage. | yes |
 | Method | `RecordUsage` | `func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)` | Records a successful request into the sliding-window and daily counters. | yes |
-| Method | `WaitForCapacity` | `func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error` | Blocks until `Decide` allows the request, sleeping according to `RetryAfter` hints or one-second polls. | yes |
+| Method | `WaitForCapacity` | `func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error` | Blocks until `CanSend` succeeds or the context is cancelled. | yes |
 | Method | `Reset` | `func (rl *RateLimiter) Reset(model string)` | Clears usage state for one model or for all models when `model` is empty. | yes |
 | Method | `Models` | `func (rl *RateLimiter) Models() iter.Seq[string]` | Returns a sorted iterator of all model names known from quotas or state. | yes |
 | Method | `Iter` | `func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats]` | Returns a sorted iterator of model names paired with current stats snapshots. | yes |
@@ -1,5 +1,3 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 ---
 title: Architecture
 description: Internals of go-ratelimit -- sliding window algorithm, provider quota system, persistence backends, and concurrency model.
@@ -12,7 +10,7 @@ three independent quota dimensions per model -- requests per minute (RPM), token
 per minute (TPM), and requests per day (RPD) -- using an in-memory sliding window
 that can be persisted across process restarts via YAML or SQLite.
 
-Module path: `dappco.re/go/core/go-ratelimit`
+Module path: `forge.lthn.ai/core/go-ratelimit`
 
 ---
 
@@ -119,12 +117,6 @@ The check order is: RPD, then RPM, then TPM. RPD is checked first because it
 is the cheapest comparison (a single integer). TPM is checked last because it
 requires summing the token counts in the sliding window.
 
-`Decide()` follows the same path as `CanSend()` but returns a structured
-`Decision` containing a machine-readable code, reason, `RetryAfter` guidance,
-and a `ModelStats` snapshot. It is agent-facing and does not record usage;
-`WaitForCapacity()` consumes its `RetryAfter` hint to avoid unnecessary
-one-second polling when limits are saturated.
-
 ### Daily Reset
 
 The daily counter resets automatically inside `prune()`. When
@@ -260,7 +252,7 @@ state:
 day_count: 42
 ```
 
-`Persist()` creates parent directories with the `core.Fs` helper before writing.
+`Persist()` creates parent directories with `os.MkdirAll` before writing.
 `Load()` treats a missing file as an empty state (no error). Corrupt or
 unreadable files return an error.
@@ -325,8 +317,8 @@ precision and allows efficient range queries using the composite indices.
 
 ### Save Strategy
 
-- **Quotas**: full snapshot replace inside a single transaction. `saveQuotas()`
-  clears the table and reinserts the current quota map.
+- **Quotas**: `INSERT ... ON CONFLICT(model) DO UPDATE` (upsert). Existing quota
+  rows are updated in place without deleting unrelated models.
 - **State**: Delete-then-insert inside a single transaction. All three state
   tables (`requests`, `tokens`, `daily`) are truncated and rewritten atomically.
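The RPD-then-RPM-then-TPM check order described in the architecture notes above can be sketched as follows. The types here are invented stand-ins for illustration, not the package's real implementation:

```go
package main

import "fmt"

// modelState is a minimal stand-in for per-model usage state: a daily
// counter plus the entries currently inside the minute window.
type modelState struct {
	dayCount int
	requests int   // requests inside the minute window
	tokens   []int // token counts inside the minute window
}

type quota struct {
	maxRPD, maxRPM, maxTPM int
}

// canSend mirrors the documented check order: RPD first (a single
// integer comparison), then RPM, then TPM last because it must sum
// the token counts in the sliding window.
func canSend(s modelState, q quota, estimatedTokens int) (bool, string) {
	if s.dayCount >= q.maxRPD {
		return false, "rpd_exceeded"
	}
	if s.requests >= q.maxRPM {
		return false, "rpm_exceeded"
	}
	sum := 0
	for _, t := range s.tokens {
		sum += t
	}
	if sum+estimatedTokens > q.maxTPM {
		return false, "tpm_exceeded"
	}
	return true, "ok"
}

func main() {
	q := quota{maxRPD: 1500, maxRPM: 10, maxTPM: 4000}
	s := modelState{dayCount: 3, requests: 2, tokens: []int{1000, 2500}}
	ok, code := canSend(s, q, 1500)
	fmt.Println(ok, code) // the estimate pushes the window over maxTPM
}
```

Ordering the cheapest rejection first means a request blocked on the daily quota never pays for the token-window summation.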
@@ -1,5 +1,3 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 ---
 title: Development Guide
 description: How to build, test, and contribute to go-ratelimit -- prerequisites, test patterns, coding standards, and commit conventions.
@@ -20,9 +18,6 @@ No C toolchain, no system SQLite library, no external build tools. A plain
 ## Build and Test
 
 ```bash
-# Compile all packages
 go build ./...
-
-# Run all tests
 go test ./...
 
@@ -47,16 +42,12 @@
 # Lint (requires golangci-lint)
 golangci-lint run ./...
 
-# Coverage check
 go test -cover ./...
 
-# Tidy dependencies
 go mod tidy
 ```
 
-Before a commit is pushed, `go build ./...`, `go test -race ./...`,
-`go vet ./...`, `go test -cover ./...`, and `go mod tidy` must all pass
-without errors, and coverage must remain at or above 95%.
+All three commands (`go test -race ./...`, `go vet ./...`, and `go mod tidy`)
+must produce no errors or warnings before a commit is pushed.
 
 ---
 
@@ -156,8 +147,17 @@ The following benchmarks are included:
 
 ### Coverage
 
-Maintain at least 95% statement coverage. Verify it with `go test -cover ./...`
-and document any justified exception in the commit or PR that introduces it.
+Current coverage: 95.1%. The remaining paths cannot be covered in unit tests
+without modifying production code:
+
+1. `CountTokens` success path -- the Google API URL is hardcoded; unit tests
+   cannot intercept the HTTP call without URL injection support.
+2. `yaml.Marshal` error path in `Persist()` -- `yaml.Marshal` does not fail on
+   valid Go structs.
+3. `os.UserHomeDir()` error path in `NewWithConfig()` -- triggered only when
+   `$HOME` is unset, which test infrastructure prevents.
+
+Do not lower coverage below 95% without a documented reason.
 
 ---
 
@@ -172,8 +172,8 @@ Do not use American spellings in identifiers, comments, or documentation.
 
 - All exported types, functions, and fields must have doc comments.
 - Error strings must be lowercase and not end with punctuation (Go convention).
-- Contextual errors use `core.E("ratelimit.Function", "what", err)` so errors
-  identify their origin clearly.
+- Contextual errors use `fmt.Errorf("ratelimit.Function: what: %w", err)` so
+  errors identify their origin clearly.
 - No `init()` functions.
 - No global mutable state. `DefaultProfiles()` returns a fresh map on each call.
@@ -196,7 +196,6 @@ Direct dependencies are intentionally minimal:
 
 | Dependency | Purpose |
 |------------|---------|
-| `dappco.re/go/core` | File I/O helpers, structured errors, JSON helpers, path/environment utilities |
 | `gopkg.in/yaml.v3` | YAML serialisation for legacy backend |
 | `modernc.org/sqlite` | Pure Go SQLite for persistent backend |
 | `github.com/stretchr/testify` | Test assertions (test-only) |
@@ -1,13 +1,10 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 # Project History
 
 ## Origin
 
 go-ratelimit was extracted from the `pkg/ratelimit` package inside
-`forge.lthn.ai/core/go` on 19 February 2026. The package now lives at
-`dappco.re/go/core/go-ratelimit`, with its own repository and independent
-development cadence.
+`forge.lthn.ai/core/go` on 19 February 2026. The extraction gave the package
+its own module path, repository, and independent development cadence.
 
 Initial commit: `fa1a6fc` — `feat: extract go-ratelimit from core/go pkg/ratelimit`
@@ -28,7 +25,7 @@ Commit: `3c63b10` — `feat(ratelimit): generalise beyond Gemini with provider p
 
 Supplementary commit: `db958f2` — `test: expand race coverage and benchmarks`
 
-Coverage increased from 77.1% to above the 95% floor. The test suite was rewritten using
+Coverage increased from 77.1% to 95.1%. The test suite was rewritten using
 testify with table-driven subtests throughout.
 
 ### Tests added
@@ -61,6 +58,18 @@ testify with table-driven subtests throughout.
 - `BenchmarkAllStats` — 5 models x 200 entries
 - `BenchmarkPersist` — YAML I/O
 
+### Remaining uncovered paths (5%)
+
+These three paths are structurally impossible to cover in unit tests without
+modifying production code:
+
+1. `CountTokens` success path — the Google API URL is hardcoded; unit tests
+   cannot intercept the HTTP call without URL injection support
+2. `yaml.Marshal` error path in `Persist()` — `yaml.Marshal` does not fail on
+   valid Go structs; the error branch exists for correctness only
+3. `os.UserHomeDir()` error path in `NewWithConfig()` — triggered only when
+   `$HOME` is unset, which test infrastructure prevents
+
 `go test -race ./...` passed clean. `go vet ./...` produced no warnings.
 
 ---
 
@@ -130,7 +139,7 @@ established elsewhere in the ecosystem.
 
 - `TestNewSQLiteStore_Good / _Bad` — creation and invalid path handling
 - `TestSQLiteQuotasRoundTrip_Good` — save/load round-trip
-- `TestSQLite_QuotasOverwrite_Good` — the latest quota snapshot replaces previous rows
+- `TestSQLiteQuotasUpsert_Good` — upsert replaces existing rows
 - `TestSQLiteStateRoundTrip_Good` — multi-model state with nanosecond precision
 - `TestSQLiteStateOverwrite_Good` — delete-then-insert atomicity
 - `TestSQLiteEmptyState_Good` — fresh database returns empty maps
@@ -159,10 +168,11 @@ Not yet implemented. Intended downstream integrations:
 
 ## Known Limitations
 
-**CountTokens URL is hardcoded.** The exported `CountTokens` helper calls
-`generativelanguage.googleapis.com` directly. Callers cannot redirect it to
-Gemini-compatible proxies or alternate endpoints without going through an
-internal helper or refactoring the API to accept a base URL or `http.Client`.
+**CountTokens URL is hardcoded.** The `CountTokens` helper calls
+`generativelanguage.googleapis.com` directly. There is no way to override the
+base URL, which prevents testing the success path in unit tests and prevents
+use with Gemini-compatible proxies. A future refactor would accept a base URL
+parameter or an `http.Client`.
 
 **saveState is a full table replace.** On every `Persist()` call, the `requests`,
 `tokens`, and `daily` tables are truncated and rewritten. For a limiter tracking
@@ -176,10 +186,10 @@ SQLite on `Persist()`. The database does not grow unboundedly between persist
 cycles because `saveState` replaces all rows, but if `Persist()` is called
 frequently the WAL file can grow transiently.
 
-**WaitForCapacity now sleeps using `Decide`’s `RetryAfter` hint** (with a
-one-second fallback when no hint exists). This reduces busy looping on long
-windows but remains coarse for sub-second smoothing; callers that need
-sub-second pacing should implement their own loop.
+**WaitForCapacity polling interval is fixed at 1 second.** This is appropriate
+for RPM-scale limits but is coarse for sub-second limits. If a caller needs
+finer-grained waiting (e.g., smoothing requests within a minute), they must
+implement their own loop.
 
 **No automatic persistence.** `Persist()` must be called explicitly. If a
 process exits without calling `Persist()`, any usage recorded since the last
@@ -1,5 +1,3 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 ---
 title: go-ratelimit
 description: Provider-agnostic sliding window rate limiter for LLM API calls, with YAML and SQLite persistence backends.
@@ -7,7 +5,7 @@ description: Provider-agnostic sliding window rate limiter for LLM API calls, wi
 
 # go-ratelimit
 
-**Module**: `dappco.re/go/core/go-ratelimit`
+**Module**: `forge.lthn.ai/core/go-ratelimit`
 **Licence**: EUPL-1.2
 **Go version**: 1.26+
 
@@ -21,7 +19,7 @@ migration helper is included.
 ## Quick Start
 
 ```go
-import "dappco.re/go/core/go-ratelimit"
+import "forge.lthn.ai/core/go-ratelimit"
 
 // Create a limiter with Gemini defaults (YAML backend).
 rl, err := ratelimit.New()
@@ -86,8 +84,6 @@ if err := rl.WaitForCapacity(ctx, "claude-opus-4", 2000); err != nil {
 	return
 }
 // Capacity is available; proceed with the API call.
-
-// WaitForCapacity uses Decide's RetryAfter hint to avoid tight polling.
 ```
 
 ## Package Layout
@@ -107,7 +103,6 @@ The module is a single package with no sub-packages.
 
 | Dependency | Purpose | Category |
 |------------|---------|----------|
-| `dappco.re/go/core` | File I/O helpers, structured errors, JSON helpers, path/environment utilities | Direct |
 | `gopkg.in/yaml.v3` | YAML serialisation for the legacy persistence backend | Direct |
 | `modernc.org/sqlite` | Pure Go SQLite driver (no CGO required) | Direct |
 | `github.com/stretchr/testify` | Test assertions (`assert`, `require`) | Test only |
@@ -1,5 +1,3 @@
-<!-- SPDX-License-Identifier: EUPL-1.2 -->
-
 # Security Attack Vector Mapping
 
 Scope: external inputs that cross into this package from callers, persisted storage, or the network. This is a mapping only; it does not propose or apply fixes.
@@ -16,7 +14,7 @@ Note: `CODEX.md` was not present anywhere under `/workspace` during this scan, s
 | `(*RateLimiter).BackgroundPrune(interval time.Duration)` | `ratelimit.go:328` | Caller-controlled `interval` | Passed to `time.NewTicker(interval)` and drives a background goroutine that repeatedly locks and prunes state | None | `interval <= 0` causes a panic; very small intervals can create CPU and lock-contention DoS; repeated calls without using the returned cancel function leak goroutines |
 | `(*RateLimiter).CanSend(model string, estimatedTokens int)` | `ratelimit.go:350` | Caller-controlled `model` and `estimatedTokens` | `model` indexes `rl.Quotas` / `rl.State`; `estimatedTokens` is added to the current token total before the TPM comparison | Unknown models are allowed immediately; no non-negative or range checks on `estimatedTokens` | Passing an unconfigured model name bypasses throttling entirely; negative or overflowed token values can undercount the TPM check and permit oversend |
 | `(*RateLimiter).RecordUsage(model string, promptTokens, outputTokens int)` | `ratelimit.go:396` | Caller-controlled `model`, `promptTokens`, `outputTokens` | Creates or updates `rl.State[model]`; stores `promptTokens + outputTokens` in the token window and increments `DayCount` | None | Arbitrary model names create unbounded state that will later persist to YAML/SQLite; negative or overflowed token totals poison accounting and can reduce future TPM totals below the real usage |
-| `(*RateLimiter).WaitForCapacity(ctx context.Context, model string, tokens int)` | `ratelimit.go:429` | Caller-controlled `ctx`, `model`, `tokens` | Calls `Decide(model, tokens)` in a loop and sleeps for the returned `RetryAfter` (or 1s fallback) until allowed or `ctx.Done()` fires | No direct validation beyond negative-token guard; relies on downstream `Decide()` and caller-supplied context cancellation | Long `RetryAfter` values can delay rechecks; repeated calls with long-lived contexts can still accumulate goroutines and lock pressure |
+| `(*RateLimiter).WaitForCapacity(ctx context.Context, model string, tokens int)` | `ratelimit.go:414` | Caller-controlled `ctx`, `model`, `tokens` | Calls `CanSend(model, tokens)` once per second until capacity is available or `ctx.Done()` fires | No direct validation; relies on downstream `CanSend()` and caller-supplied context cancellation | Inherits the unknown-model and negative-token bypasses from `CanSend()`; repeated calls with long-lived contexts can accumulate goroutines and lock pressure |
 | `(*RateLimiter).Reset(model string)` | `ratelimit.go:433` | Caller-controlled `model` | `model == ""` replaces the entire `rl.State` map; otherwise `delete(rl.State, model)` | Empty string is treated as a wildcard reset | If reachable by an untrusted actor, an empty string clears all rate-limit history and targeted resets erase throttling state for chosen models |
 | `(*RateLimiter).Stats(model string)` | `ratelimit.go:484` | Caller-controlled `model` | Prunes `rl.State[model]`, reads `rl.Quotas[model]`, and returns a usage snapshot | None | If exposed through a service boundary, it discloses per-model quota ceilings and live usage counts that can help an attacker tune evasion or timing |
 | `NewWithSQLite(dbPath string)` | `ratelimit.go:567` | Caller-controlled `dbPath` | Thin wrapper that forwards `dbPath` into `NewWithSQLiteConfig()` and then `newSQLiteStore()` | No additional validation in the wrapper | Untrusted `dbPath` can steer database creation/opening to unintended local filesystem locations, including companion `-wal` and `-shm` files |
error_test.go (158 changed lines)
@ -1,9 +1,8 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -11,8 +10,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestError_SQLiteErrorPaths_Bad(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "error.db")
|
||||
func TestSQLiteErrorPaths(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "error.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -40,17 +39,17 @@ func TestError_SQLiteErrorPaths_Bad(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestError_SQLiteInitErrors_Bad(t *testing.T) {
|
||||
func TestSQLiteInitErrors(t *testing.T) {
|
||||
t.Run("WAL pragma failure", func(t *testing.T) {
|
||||
// This is hard to trigger without mocking sql.DB, but we can try an invalid connection string
|
||||
// modernc.org/sqlite doesn't support all DSN options that might cause PRAGMA to fail but connection to succeed.
|
||||
})
|
||||
}
|
||||
|
||||
func TestError_PersistYAML_Good(t *testing.T) {
|
||||
func TestPersistYAML(t *testing.T) {
|
||||
t.Run("successful YAML persist and load", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := testPath(tmpDir, "ratelimits.yaml")
|
||||
path := filepath.Join(tmpDir, "ratelimits.yaml")
|
||||
rl, _ := New()
|
||||
rl.filePath = path
|
||||
rl.Quotas["test"] = ModelQuota{MaxRPM: 1}
|
||||
|
|
@ -66,9 +65,9 @@ func TestError_PersistYAML_Good(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestError_SQLiteLoadViaLimiter_Bad(t *testing.T) {
|
||||
func TestSQLiteLoadViaLimiter(t *testing.T) {
|
||||
t.Run("Load returns error when SQLite DB is closed", func(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "load-err.db")
|
||||
dbPath := filepath.Join(t.TempDir(), "load-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -80,7 +79,7 @@ func TestError_SQLiteLoadViaLimiter_Bad(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Load returns error when loadState fails", func(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "load-state-err.db")
|
||||
dbPath := filepath.Join(t.TempDir(), "load-state-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -97,9 +96,9 @@ func TestError_SQLiteLoadViaLimiter_Bad(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestError_SQLitePersistViaLimiter_Bad(t *testing.T) {
|
||||
func TestSQLitePersistViaLimiter(t *testing.T) {
|
||||
t.Run("Persist returns error when SQLite saveQuotas fails", func(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "persist-err.db")
|
||||
dbPath := filepath.Join(t.TempDir(), "persist-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -114,7 +113,7 @@ func TestError_SQLitePersistViaLimiter_Bad(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Persist returns error when SQLite saveState fails", func(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "persist-state-err.db")
|
||||
dbPath := filepath.Join(t.TempDir(), "persist-state-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -131,7 +130,7 @@ func TestError_SQLitePersistViaLimiter_Bad(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestError_NewWithSQLite_Bad(t *testing.T) {
|
||||
func TestNewWithSQLiteErrors(t *testing.T) {
|
||||
t.Run("NewWithSQLite with invalid path", func(t *testing.T) {
|
||||
_, err := NewWithSQLite("/nonexistent/deep/nested/dir/test.db")
|
||||
assert.Error(t, err, "should fail with invalid path")
|
||||
|
|
@@ -145,9 +144,9 @@ func TestError_NewWithSQLite_Bad(t *testing.T) {
 	})
 }

-func TestError_SQLiteSaveState_Bad(t *testing.T) {
+func TestSQLiteSaveStateErrors(t *testing.T) {
 	t.Run("saveState fails when tokens table is dropped", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "tokens-err.db")
+		dbPath := filepath.Join(t.TempDir(), "tokens-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -167,7 +166,7 @@ func TestError_SQLiteSaveState_Bad(t *testing.T) {
 	})

 	t.Run("saveState fails when daily table is dropped", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "daily-err.db")
+		dbPath := filepath.Join(t.TempDir(), "daily-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -186,7 +185,7 @@ func TestError_SQLiteSaveState_Bad(t *testing.T) {
 	})

 	t.Run("saveState fails on request insert with renamed column", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "req-insert-err.db")
+		dbPath := filepath.Join(t.TempDir(), "req-insert-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -207,7 +206,7 @@ func TestError_SQLiteSaveState_Bad(t *testing.T) {
 	})

 	t.Run("saveState fails on token insert with renamed column", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "tok-insert-err.db")
+		dbPath := filepath.Join(t.TempDir(), "tok-insert-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -228,7 +227,7 @@ func TestError_SQLiteSaveState_Bad(t *testing.T) {
 	})

 	t.Run("saveState fails on daily insert with renamed column", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "day-insert-err.db")
+		dbPath := filepath.Join(t.TempDir(), "day-insert-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -248,9 +247,9 @@ func TestError_SQLiteSaveState_Bad(t *testing.T) {
 	})
 }

-func TestError_SQLiteLoadState_Bad(t *testing.T) {
+func TestSQLiteLoadStateErrors(t *testing.T) {
 	t.Run("loadState fails when requests table is dropped", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "req-err.db")
+		dbPath := filepath.Join(t.TempDir(), "req-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()
@@ -272,7 +271,7 @@ func TestError_SQLiteLoadState_Bad(t *testing.T) {
 	})

 	t.Run("loadState fails when tokens table is dropped", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "tok-err.db")
+		dbPath := filepath.Join(t.TempDir(), "tok-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -294,7 +293,7 @@ func TestError_SQLiteLoadState_Bad(t *testing.T) {
 	})

 	t.Run("loadState fails when daily table is dropped", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "daily-load-err.db")
+		dbPath := filepath.Join(t.TempDir(), "daily-load-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -315,9 +314,9 @@ func TestError_SQLiteLoadState_Bad(t *testing.T) {
 	})
 }

-func TestError_SQLiteSaveQuotasExec_Bad(t *testing.T) {
+func TestSQLiteSaveQuotasExecError(t *testing.T) {
 	t.Run("saveQuotas fails with renamed column at prepare", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "quota-exec-err.db")
+		dbPath := filepath.Join(t.TempDir(), "quota-exec-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -333,7 +332,7 @@ func TestError_SQLiteSaveQuotasExec_Bad(t *testing.T) {
 	})

 	t.Run("saveQuotas fails at exec via trigger", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "quota-trigger.db")
+		dbPath := filepath.Join(t.TempDir(), "quota-trigger.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -351,9 +350,9 @@ func TestError_SQLiteSaveQuotasExec_Bad(t *testing.T) {
 	})
 }

-func TestError_SQLiteSaveStateExec_Bad(t *testing.T) {
+func TestSQLiteSaveStateExecErrors(t *testing.T) {
 	t.Run("request insert exec fails via trigger", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "trigger-req.db")
+		dbPath := filepath.Join(t.TempDir(), "trigger-req.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()
@@ -376,7 +375,7 @@ func TestError_SQLiteSaveStateExec_Bad(t *testing.T) {
 	})

 	t.Run("token insert exec fails via trigger", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "trigger-tok.db")
+		dbPath := filepath.Join(t.TempDir(), "trigger-tok.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -399,7 +398,7 @@ func TestError_SQLiteSaveStateExec_Bad(t *testing.T) {
 	})

 	t.Run("daily insert exec fails via trigger", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "trigger-day.db")
+		dbPath := filepath.Join(t.TempDir(), "trigger-day.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -421,9 +420,9 @@ func TestError_SQLiteSaveStateExec_Bad(t *testing.T) {
 	})
 }

-func TestError_SQLiteLoadQuotasScan_Bad(t *testing.T) {
+func TestSQLiteLoadQuotasScanError(t *testing.T) {
 	t.Run("loadQuotas fails with renamed column", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "quota-scan-err.db")
+		dbPath := filepath.Join(t.TempDir(), "quota-scan-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()
@@ -442,29 +441,26 @@ func TestError_SQLiteLoadQuotasScan_Bad(t *testing.T) {
 	})
 }

-func TestError_NewSQLiteStoreInReadOnlyDir_Bad(t *testing.T) {
-	if isRootUser() {
+func TestNewSQLiteStoreInReadOnlyDir(t *testing.T) {
+	if os.Getuid() == 0 {
 		t.Skip("chmod restrictions do not apply to root")
 	}

 	t.Run("fails when parent directory is read-only", func(t *testing.T) {
 		tmpDir := t.TempDir()
-		readonlyDir := testPath(tmpDir, "readonly")
-		ensureTestDir(t, readonlyDir)
-		setPathMode(t, readonlyDir, 0o555)
-		defer func() {
-			_ = syscall.Chmod(readonlyDir, 0o755)
-		}()
+		readonlyDir := filepath.Join(tmpDir, "readonly")
+		require.NoError(t, os.MkdirAll(readonlyDir, 0555))
+		defer os.Chmod(readonlyDir, 0755)

-		dbPath := testPath(readonlyDir, "test.db")
+		dbPath := filepath.Join(readonlyDir, "test.db")
 		_, err := newSQLiteStore(dbPath)
 		assert.Error(t, err, "should fail when directory is read-only")
 	})
 }

-func TestError_SQLiteCreateSchema_Bad(t *testing.T) {
+func TestSQLiteCreateSchemaError(t *testing.T) {
 	t.Run("createSchema fails on closed DB", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "schema-err.db")
+		dbPath := filepath.Join(t.TempDir(), "schema-err.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
@@ -477,9 +473,9 @@ func TestError_SQLiteCreateSchema_Bad(t *testing.T) {
 	})
 }

-func TestError_SQLiteLoadStateScan_Bad(t *testing.T) {
+func TestSQLiteLoadStateScanErrors(t *testing.T) {
 	t.Run("scan daily fails with NULL values", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "scan-daily.db")
+		dbPath := filepath.Join(t.TempDir(), "scan-daily.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -499,7 +495,7 @@ func TestError_SQLiteLoadStateScan_Bad(t *testing.T) {
 	})

 	t.Run("scan requests fails with NULL ts", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "scan-req.db")
+		dbPath := filepath.Join(t.TempDir(), "scan-req.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -524,7 +520,7 @@ func TestError_SQLiteLoadStateScan_Bad(t *testing.T) {
 	})

 	t.Run("scan tokens fails with NULL values", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "scan-tok.db")
+		dbPath := filepath.Join(t.TempDir(), "scan-tok.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()

@@ -549,9 +545,9 @@ func TestError_SQLiteLoadStateScan_Bad(t *testing.T) {
 	})
 }

-func TestError_SQLiteLoadQuotasScanWithBadSchema_Bad(t *testing.T) {
+func TestSQLiteLoadQuotasScanWithBadSchema(t *testing.T) {
 	t.Run("scan fails with NULL quota values", func(t *testing.T) {
-		dbPath := testPath(t.TempDir(), "scan-quota.db")
+		dbPath := filepath.Join(t.TempDir(), "scan-quota.db")
 		store, err := newSQLiteStore(dbPath)
 		require.NoError(t, err)
 		defer store.close()
@@ -570,11 +566,11 @@ func TestError_SQLiteLoadQuotasScanWithBadSchema_Bad(t *testing.T) {
 	})
 }

-func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) {
+func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
 	t.Run("saveQuotas failure during migration via trigger", func(t *testing.T) {
 		tmpDir := t.TempDir()
-		yamlPath := testPath(tmpDir, "with-quotas.yaml")
-		sqlitePath := testPath(tmpDir, "migrate-quota-err.db")
+		yamlPath := filepath.Join(tmpDir, "with-quotas.yaml")
+		sqlitePath := filepath.Join(tmpDir, "migrate-quota-err.db")

 		// Write a YAML file with quotas.
 		yamlData := `quotas:

@@ -583,7 +579,7 @@ func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) {
     max_tpm: 100
     max_rpd: 50
 `
-		writeTestFile(t, yamlPath, yamlData)
+		require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644))

 		// Pre-create DB with a trigger that aborts quota inserts.
 		store, err := newSQLiteStore(sqlitePath)

@@ -600,8 +596,8 @@ func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) {

 	t.Run("saveState failure during migration via trigger", func(t *testing.T) {
 		tmpDir := t.TempDir()
-		yamlPath := testPath(tmpDir, "with-state.yaml")
-		sqlitePath := testPath(tmpDir, "migrate-state-err.db")
+		yamlPath := filepath.Join(tmpDir, "with-state.yaml")
+		sqlitePath := filepath.Join(tmpDir, "migrate-state-err.db")

 		// Write YAML with state.
 		yamlData := `state:

@@ -611,7 +607,7 @@ func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) {
     day_start: 2026-01-01T00:00:00Z
     day_count: 1
 `
-		writeTestFile(t, yamlPath, yamlData)
+		require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644))

 		// Pre-create DB with a trigger that aborts daily inserts.
 		store, err := newSQLiteStore(sqlitePath)
@@ -626,13 +622,13 @@ func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) {
 	})
 }

-func TestError_MigrateYAMLToSQLiteNilQuotasAndState_Good(t *testing.T) {
+func TestMigrateYAMLToSQLiteNilQuotasAndState(t *testing.T) {
 	t.Run("YAML with empty quotas and state migrates cleanly", func(t *testing.T) {
 		tmpDir := t.TempDir()
-		yamlPath := testPath(tmpDir, "empty.yaml")
-		writeTestFile(t, yamlPath, "{}")
+		yamlPath := filepath.Join(tmpDir, "empty.yaml")
+		require.NoError(t, os.WriteFile(yamlPath, []byte("{}"), 0644))

-		sqlitePath := testPath(tmpDir, "empty.db")
+		sqlitePath := filepath.Join(tmpDir, "empty.db")
 		require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath))

 		store, err := newSQLiteStore(sqlitePath)

@@ -649,18 +645,30 @@ func TestError_MigrateYAMLToSQLiteNilQuotasAndState_Good(t *testing.T) {
 	})
 }

-func TestError_NewWithConfigHomeUnavailable_Bad(t *testing.T) {
-	// Clear all supported home env vars so defaultStatePath cannot resolve a home directory.
-	t.Setenv("CORE_HOME", "")
-	t.Setenv("HOME", "")
-	t.Setenv("home", "")
-	t.Setenv("USERPROFILE", "")
+func TestNewWithConfigUserHomeDirError(t *testing.T) {
+	// Unset HOME to trigger os.UserHomeDir() error.
+	home := os.Getenv("HOME")
+	os.Unsetenv("HOME")
+	// Also unset fallback env vars that UserHomeDir checks.
+	plan9Home := os.Getenv("home")
+	os.Unsetenv("home")
+	userProfile := os.Getenv("USERPROFILE")
+	os.Unsetenv("USERPROFILE")
+	defer func() {
+		os.Setenv("HOME", home)
+		if plan9Home != "" {
+			os.Setenv("home", plan9Home)
+		}
+		if userProfile != "" {
+			os.Setenv("USERPROFILE", userProfile)
+		}
+	}()

 	_, err := NewWithConfig(Config{})
 	assert.Error(t, err, "should fail when HOME is unset")
 }

-func TestError_PersistMarshal_Good(t *testing.T) {
+func TestPersistMarshalError(t *testing.T) {
 	// yaml.Marshal on a struct with map[string]ModelQuota and map[string]*UsageStats
 	// should not fail in practice. We test the error path by using a type that
 	// yaml.Marshal cannot handle: a channel.
@@ -672,20 +680,20 @@ func TestError_PersistMarshal_Good(t *testing.T) {
 	assert.NoError(t, rl.Persist(), "valid persist should succeed")
 }

-func TestError_MigrateErrorsExtended_Bad(t *testing.T) {
+func TestMigrateErrorsExtended(t *testing.T) {
 	t.Run("unmarshal failure", func(t *testing.T) {
 		tmpDir := t.TempDir()
-		path := testPath(tmpDir, "bad.yaml")
-		writeTestFile(t, path, "invalid: yaml: [")
-		err := MigrateYAMLToSQLite(path, testPath(tmpDir, "out.db"))
+		path := filepath.Join(tmpDir, "bad.yaml")
+		require.NoError(t, os.WriteFile(path, []byte("invalid: yaml: ["), 0644))
+		err := MigrateYAMLToSQLite(path, filepath.Join(tmpDir, "out.db"))
 		assert.Error(t, err)
 		assert.Contains(t, err.Error(), "ratelimit.MigrateYAMLToSQLite: unmarshal")
 	})

 	t.Run("sqlite open failure", func(t *testing.T) {
 		tmpDir := t.TempDir()
-		yamlPath := testPath(tmpDir, "ok.yaml")
-		writeTestFile(t, yamlPath, "quotas: {}")
+		yamlPath := filepath.Join(tmpDir, "ok.yaml")
+		require.NoError(t, os.WriteFile(yamlPath, []byte("quotas: {}"), 0644))
 		// Use an invalid sqlite path (dir where file should be)
 		err := MigrateYAMLToSQLite(yamlPath, "/dev/null/not-a-db")
 		assert.Error(t, err)
12	go.mod

@@ -1,24 +1,24 @@
-// SPDX-License-Identifier: EUPL-1.2
-
-module dappco.re/go/core/go-ratelimit
+module forge.lthn.ai/core/go-ratelimit

 go 1.26.0

 require (
-	dappco.re/go/core v0.8.0-alpha.1
+	dappco.re/go/core/io v0.2.0
+	dappco.re/go/core/log v0.1.0
 	gopkg.in/yaml.v3 v3.0.1
 	modernc.org/sqlite v1.47.0
 )

 require (
+	forge.lthn.ai/core/go-log v0.0.4 // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
 	github.com/google/uuid v1.6.0 // indirect
 	github.com/kr/text v0.2.0 // indirect
 	github.com/mattn/go-isatty v0.0.20 // indirect
 	github.com/ncruces/go-strftime v1.0.0 // indirect
 	github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
 	golang.org/x/mod v0.34.0 // indirect
 	golang.org/x/sync v0.20.0 // indirect
 	golang.org/x/sys v0.42.0 // indirect
 	golang.org/x/tools v0.43.0 // indirect
 	modernc.org/libc v1.70.0 // indirect
 	modernc.org/mathutil v1.7.1 // indirect
 	modernc.org/memory v1.11.0 // indirect
9	go.sum

@@ -1,6 +1,9 @@
-dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
-dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
-github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+dappco.re/go/core/io v0.2.0 h1:zuudgIiTsQQ5ipVt97saWdGLROovbEB/zdVyy9/l+I4=
+dappco.re/go/core/io v0.2.0/go.mod h1:1QnQV6X9LNgFKfm8SkOtR9LLaj3bDcsOIeJOOyjbL5E=
+dappco.re/go/core/log v0.1.0 h1:pa71Vq2TD2aoEUQWFKwNcaJ3GBY8HbaNGqtE688Unyc=
+dappco.re/go/core/log v0.1.0/go.mod h1:Nkqb8gsXhZAO8VLpx7B8i1iAmohhzqA20b9Zr8VUcJs=
+forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
+forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
@@ -1,5 +1,3 @@
-// SPDX-License-Identifier: EUPL-1.2
-
 package ratelimit

 import (

@@ -12,7 +10,7 @@ import (
 	"github.com/stretchr/testify/require"
 )

-func TestIter_Iterators_Good(t *testing.T) {
+func TestIterators(t *testing.T) {
 	rl, err := NewWithConfig(Config{
 		Quotas: map[string]ModelQuota{
 			"model-c": {MaxRPM: 10},

@@ -79,7 +77,7 @@ func TestIter_Iterators_Good(t *testing.T) {
 	})
 }

-func TestIter_IterEarlyBreak_Good(t *testing.T) {
+func TestIterEarlyBreak(t *testing.T) {
 	rl, err := NewWithConfig(Config{
 		Quotas: map[string]ModelQuota{
 			"model-a": {MaxRPM: 10},

@@ -112,7 +110,7 @@ func TestIter_IterEarlyBreak_Good(t *testing.T) {
 	})
 }

-func TestIter_CountTokensFull_Ugly(t *testing.T) {
+func TestCountTokensFull(t *testing.T) {
 	t.Run("empty model is rejected", func(t *testing.T) {
 		_, err := CountTokens(context.Background(), "key", "", "text")
 		assert.Error(t, err)
495	ratelimit.go
@@ -1,26 +1,28 @@
-// SPDX-License-Identifier: EUPL-1.2
-
 package ratelimit

 import (
 	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
 	"io"
 	"io/fs"
 	"iter"
 	"maps"
 	"net/http"
 	"net/url"
+	"os"
+	"path/filepath"
 	"slices"
 	"strings"
 	"sync"
 	"time"

-	core "dappco.re/go/core"
+	coreio "dappco.re/go/core/io"
+	coreerr "dappco.re/go/core/log"
 	"gopkg.in/yaml.v3"
 )

 // Provider identifies an LLM provider for quota profiles.
 //
 //	provider := ProviderOpenAI
 type Provider string

 const (
@@ -45,8 +47,6 @@ const (
 )

 // ModelQuota defines the rate limits for a specific model.
-//
-//	quota := ModelQuota{MaxRPM: 60, MaxTPM: 90000, MaxRPD: 1000}
 type ModelQuota struct {
 	MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
 	MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited)

@@ -54,18 +54,12 @@ type ModelQuota struct {
 }

 // ProviderProfile bundles model quotas for a provider.
-//
-//	profile := ProviderProfile{Provider: ProviderGemini, Models: DefaultProfiles()[ProviderGemini].Models}
 type ProviderProfile struct {
-	// Provider identifies the provider that owns the profile.
-	Provider Provider `yaml:"provider"`
-	// Models maps model names to quotas.
-	Models map[string]ModelQuota `yaml:"models"`
+	Provider Provider              `yaml:"provider"`
+	Models   map[string]ModelQuota `yaml:"models"`
 }

 // Config controls RateLimiter initialisation.
-//
-//	cfg := Config{Providers: []Provider{ProviderGemini}, FilePath: "/tmp/ratelimits.yaml"}
 type Config struct {
 	// FilePath overrides the default state file location.
 	// If empty, defaults to ~/.core/ratelimits.yaml.

@@ -85,35 +79,23 @@ type Config struct {
 }

 // TokenEntry records a token usage event.
-//
-//	entry := TokenEntry{Time: time.Now(), Count: 512}
 type TokenEntry struct {
 	Time  time.Time `yaml:"time"`
 	Count int       `yaml:"count"`
 }

 // UsageStats tracks usage history for a model.
-//
-//	stats := UsageStats{DayStart: time.Now(), DayCount: 1}
 type UsageStats struct {
 	Requests []time.Time  `yaml:"requests"` // Sliding window (1m)
 	Tokens   []TokenEntry `yaml:"tokens"`   // Sliding window (1m)
-	// DayStart is the start of the rolling 24-hour window.
-	DayStart time.Time `yaml:"day_start"`
-	// DayCount is the number of requests recorded in the rolling 24-hour window.
-	DayCount int `yaml:"day_count"`
+	DayStart time.Time    `yaml:"day_start"`
+	DayCount int          `yaml:"day_count"`
 }

 // RateLimiter manages rate limits across multiple models.
-//
-//	rl, err := New()
-//	if err != nil { /* handle error */ }
-//	defer rl.Close()
 type RateLimiter struct {
-	mu sync.RWMutex
-	// Quotas holds the configured per-model limits.
-	Quotas map[string]ModelQuota `yaml:"quotas"`
-	// State holds per-model usage windows.
+	mu     sync.RWMutex
+	Quotas map[string]ModelQuota `yaml:"quotas"`
 	State map[string]*UsageStats `yaml:"state"`
 	filePath string
 	sqlite   *sqliteStore // non-nil when backend is "sqlite"
@@ -121,9 +103,6 @@ type RateLimiter struct {

 // DefaultProfiles returns pre-configured quota profiles for each provider.
 // Values are based on published rate limits as of Feb 2026.
-//
-//	profiles := DefaultProfiles()
-//	openAI := profiles[ProviderOpenAI]
 func DefaultProfiles() map[Provider]ProviderProfile {
 	return map[Provider]ProviderProfile{
 		ProviderGemini: {

@@ -167,8 +146,6 @@ func DefaultProfiles() map[Provider]ProviderProfile {

 // New creates a new RateLimiter with Gemini defaults.
 // This preserves backward compatibility -- existing callers are unaffected.
-//
-//	rl, err := New()
 func New() (*RateLimiter, error) {
 	return NewWithConfig(Config{
 		Providers: []Provider{ProviderGemini},

@@ -177,8 +154,6 @@ func New() (*RateLimiter, error) {

 // NewWithConfig creates a RateLimiter from explicit configuration.
 // If no providers or quotas are specified, Gemini defaults are used.
-//
-//	rl, err := NewWithConfig(Config{Providers: []Provider{ProviderAnthropic}})
 func NewWithConfig(cfg Config) (*RateLimiter, error) {
 	backend, err := normaliseBackend(cfg.Backend)
 	if err != nil {

@@ -195,8 +170,8 @@ func NewWithConfig(cfg Config) (*RateLimiter, error) {

 	if backend == backendSQLite {
 		if cfg.FilePath == "" {
-			if err := ensureDir(core.PathDir(filePath)); err != nil {
-				return nil, core.E("ratelimit.NewWithConfig", "mkdir", err)
+			if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
+				return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err)
 			}
 		}
 		return NewWithSQLiteConfig(filePath, cfg)

@@ -208,8 +183,6 @@ func NewWithConfig(cfg Config) (*RateLimiter, error) {
 }

 // SetQuota sets or updates the quota for a specific model at runtime.
-//
-//	rl.SetQuota("gpt-4o-mini", ModelQuota{MaxRPM: 60, MaxTPM: 200000})
 func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) {
 	rl.mu.Lock()
 	defer rl.mu.Unlock()

@@ -218,8 +191,6 @@ func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) {

 // AddProvider loads all default quotas for a provider.
 // Existing quotas for models in the profile are overwritten.
-//
-//	rl.AddProvider(ProviderOpenAI)
 func (rl *RateLimiter) AddProvider(provider Provider) {
 	rl.mu.Lock()
 	defer rl.mu.Unlock()

@@ -231,8 +202,6 @@ func (rl *RateLimiter) AddProvider(provider Provider) {
 }

 // Load reads the state from disk (YAML) or database (SQLite).
-//
-//	if err := rl.Load(); err != nil { /* handle error */ }
 func (rl *RateLimiter) Load() error {
 	rl.mu.Lock()
 	defer rl.mu.Unlock()
@@ -241,20 +210,15 @@ func (rl *RateLimiter) Load() error {
 		return rl.loadSQLite()
 	}

-	content, err := readLocalFile(rl.filePath)
-	if core.Is(err, fs.ErrNotExist) {
+	content, err := coreio.Local.Read(rl.filePath)
+	if os.IsNotExist(err) {
 		return nil
 	}
 	if err != nil {
 		return err
 	}

-	if err := yaml.Unmarshal([]byte(content), rl); err != nil {
-		return err
-	}
-
-	ensureMaps(rl)
-	return nil
+	return yaml.Unmarshal([]byte(content), rl)
 }

 // loadSQLite reads quotas and state from the SQLite backend.

@@ -280,8 +244,6 @@ func (rl *RateLimiter) loadSQLite() error {

 // Persist writes a snapshot of the state to disk (YAML) or database (SQLite).
 // It clones the state under a lock and performs I/O without blocking other callers.
-//
-//	if err := rl.Persist(); err != nil { /* handle error */ }
 func (rl *RateLimiter) Persist() error {
 	rl.mu.Lock()
 	quotas := maps.Clone(rl.Quotas)

@@ -303,7 +265,7 @@ func (rl *RateLimiter) Persist() error {

 	if sqlite != nil {
 		if err := sqlite.saveSnapshot(quotas, state); err != nil {
-			return core.E("ratelimit.Persist", "sqlite snapshot", err)
+			return coreerr.E("ratelimit.Persist", "sqlite snapshot", err)
 		}
 		return nil
 	}

@@ -318,11 +280,11 @@ func (rl *RateLimiter) Persist() error {
 		State: state,
 	})
 	if err != nil {
-		return core.E("ratelimit.Persist", "marshal", err)
+		return coreerr.E("ratelimit.Persist", "marshal", err)
 	}

-	if err := writeLocalFile(filePath, string(data)); err != nil {
-		return core.E("ratelimit.Persist", "write", err)
+	if err := coreio.Local.Write(filePath, string(data)); err != nil {
+		return coreerr.E("ratelimit.Persist", "write", err)
 	}
 	return nil
 }

@@ -372,9 +334,6 @@ func (rl *RateLimiter) prune(model string) {

 // BackgroundPrune starts a goroutine that periodically prunes all model states.
 // It returns a function to stop the pruner.
-//
-//	stop := rl.BackgroundPrune(30 * time.Second)
-//	defer stop()
 func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
 	if interval <= 0 {
 		return func() {}
@@ -401,15 +360,53 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
 }

 // CanSend checks if a request can be sent without violating limits.
-//
-//	ok := rl.CanSend("gemini-3-pro-preview", 1200)
 func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
-	return rl.Decide(model, estimatedTokens).Allowed
+	if estimatedTokens < 0 {
+		return false
+	}
+
+	rl.mu.Lock()
+	defer rl.mu.Unlock()
+
+	quota, ok := rl.Quotas[model]
+	if !ok {
+		return true // Unknown models are allowed
+	}
+
+	// Unlimited check
+	if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 {
+		return true
+	}
+
+	rl.prune(model)
+	stats, ok := rl.State[model]
+	if !ok {
+		stats = &UsageStats{DayStart: time.Now()}
+		rl.State[model] = stats
+	}
+
+	// Check RPD
+	if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD {
+		return false
+	}
+
+	// Check RPM
+	if quota.MaxRPM > 0 && len(stats.Requests) >= quota.MaxRPM {
+		return false
+	}
+
+	// Check TPM
+	if quota.MaxTPM > 0 {
+		currentTokens := totalTokenCount(stats.Tokens)
+		if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens {
+			return false
+		}
+	}
+
+	return true
 }

 // RecordUsage records a successful API call.
-//
-//	rl.RecordUsage("gemini-3-pro-preview", 900, 300)
 func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) {
 	rl.mu.Lock()
 	defer rl.mu.Unlock()
@@ -429,38 +426,29 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
 }

 // WaitForCapacity blocks until capacity is available or context is cancelled.
-//
-//	err := rl.WaitForCapacity(ctx, "gemini-3-pro-preview", 1200)
 func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
 	if tokens < 0 {
-		return core.E("ratelimit.WaitForCapacity", "negative tokens", nil)
+		return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil)
 	}

+	ticker := time.NewTicker(1 * time.Second)
+	defer ticker.Stop()
+
 	for {
-		decision := rl.Decide(model, tokens)
-		if decision.Allowed {
+		if rl.CanSend(model, tokens) {
 			return nil
 		}

-		sleep := decision.RetryAfter
-		if sleep <= 0 {
-			sleep = time.Second
-		}
-
-		timer := time.NewTimer(sleep)
 		select {
 		case <-ctx.Done():
-			timer.Stop()
 			return ctx.Err()
-		case <-timer.C:
-			timer.Stop()
+		case <-ticker.C:
+			// check again
 		}
 	}
 }

 // Reset clears stats for a model (or all if model is empty).
-//
-//	rl.Reset("gemini-3-pro-preview")
 func (rl *RateLimiter) Reset(model string) {
 	rl.mu.Lock()
 	defer rl.mu.Unlock()
@ -473,58 +461,17 @@ func (rl *RateLimiter) Reset(model string) {
|
|||
}
|
||||
|
||||
// ModelStats represents a snapshot of usage.
|
||||
//
|
||||
// stats := rl.Stats("gemini-3-pro-preview")
|
||||
type ModelStats struct {
|
||||
// RPM is the current requests-per-minute usage in the sliding window.
|
||||
RPM int
|
||||
// MaxRPM is the configured requests-per-minute limit.
|
||||
MaxRPM int
|
||||
// TPM is the current tokens-per-minute usage in the sliding window.
|
||||
TPM int
|
||||
// MaxTPM is the configured tokens-per-minute limit.
|
||||
MaxTPM int
|
||||
// RPD is the current requests-per-day usage in the rolling 24-hour window.
|
||||
RPD int
|
||||
// MaxRPD is the configured requests-per-day limit.
|
||||
MaxRPD int
|
||||
// DayStart is the start of the current rolling 24-hour window.
|
||||
RPM int
|
||||
MaxRPM int
|
||||
TPM int
|
||||
MaxTPM int
|
||||
RPD int
|
||||
MaxRPD int
|
||||
DayStart time.Time
|
||||
}
|
||||
|
||||
// DecisionCode identifies the reason for an allow or deny outcome from Decide.
type DecisionCode string

const (
	// DecisionAllowed means the request fits within all configured limits.
	DecisionAllowed DecisionCode = "ok"
	// DecisionUnknownModel means the model has no configured quotas and is therefore allowed.
	DecisionUnknownModel DecisionCode = "unknown_model"
	// DecisionUnlimited means the model is configured with no limits.
	DecisionUnlimited DecisionCode = "unlimited"
	// DecisionInvalidTokens means a negative token estimate was provided.
	DecisionInvalidTokens DecisionCode = "invalid_tokens"
	// DecisionRPDLimit means the rolling 24-hour request limit has been reached.
	DecisionRPDLimit DecisionCode = "rpd_exceeded"
	// DecisionRPMLimit means the per-minute request limit has been reached.
	DecisionRPMLimit DecisionCode = "rpm_exceeded"
	// DecisionTPMLimit means the per-minute token limit would be exceeded.
	DecisionTPMLimit DecisionCode = "tpm_exceeded"
)

// Decision captures an allow/deny decision with context for agents.
// RetryAfter is zero when the request is allowed or when no meaningful wait time exists.
type Decision struct {
	Allowed    bool
	Code       DecisionCode
	Reason     string
	RetryAfter time.Duration
	Stats      ModelStats
}

// Models returns a sorted iterator over all model names tracked by the limiter.
//
// for model := range rl.Models() { println(model) }
func (rl *RateLimiter) Models() iter.Seq[string] {
	rl.mu.RLock()
	defer rl.mu.RUnlock()

@@ -541,8 +488,6 @@ func (rl *RateLimiter) Models() iter.Seq[string] {
}

// Iter returns a sorted iterator over all model names and their current stats.
//
// for model, stats := range rl.Iter() { _ = stats; println(model) }
func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] {
	return func(yield func(string, ModelStats) bool) {
		stats := rl.AllStats()

@@ -555,20 +500,33 @@ func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] {
}

// Stats returns current stats for a model.
//
// stats := rl.Stats("gemini-3-pro-preview")
func (rl *RateLimiter) Stats(model string) ModelStats {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	rl.prune(model)

	return rl.snapshotLocked(model)
	stats := ModelStats{}
	quota, ok := rl.Quotas[model]
	if ok {
		stats.MaxRPM = quota.MaxRPM
		stats.MaxTPM = quota.MaxTPM
		stats.MaxRPD = quota.MaxRPD
	}

	if s, ok := rl.State[model]; ok {
		stats.RPM = len(s.Requests)
		stats.RPD = s.DayCount
		stats.DayStart = s.DayStart
		for _, t := range s.Tokens {
			stats.TPM += t.Count
		}
	}

	return stats
}

// AllStats returns stats for all tracked models.
//
// all := rl.AllStats()
func (rl *RateLimiter) AllStats() map[string]ModelStats {
	rl.mu.Lock()
	defer rl.mu.Unlock()

@@ -586,112 +544,27 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
	for m := range result {
		rl.prune(m)

		result[m] = rl.snapshotLocked(m)
		ms := ModelStats{}
		if q, ok := rl.Quotas[m]; ok {
			ms.MaxRPM = q.MaxRPM
			ms.MaxTPM = q.MaxTPM
			ms.MaxRPD = q.MaxRPD
		}
		if s, ok := rl.State[m]; ok && s != nil {
			ms.RPM = len(s.Requests)
			ms.RPD = s.DayCount
			ms.DayStart = s.DayStart
			ms.TPM = totalTokenCount(s.Tokens)
		}
		result[m] = ms
	}

	return result
}

// Decide returns structured allow/deny information for an estimated request.
// It never records usage; call RecordUsage after a successful decision.
func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	decision := Decision{}

	if estimatedTokens < 0 {
		decision.Allowed = false
		decision.Code = DecisionInvalidTokens
		decision.Reason = "estimated tokens must be non-negative"
		decision.Stats = rl.snapshotLocked(model)
		return decision
	}

	quota, ok := rl.Quotas[model]
	if !ok {
		decision.Allowed = true
		decision.Code = DecisionUnknownModel
		decision.Reason = "model has no configured quota"
		decision.Stats = rl.snapshotLocked(model)
		return decision
	}

	if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 {
		decision.Allowed = true
		decision.Code = DecisionUnlimited
		decision.Reason = "all limits are unlimited"
		decision.Stats = rl.snapshotLocked(model)
		return decision
	}

	rl.prune(model)
	stats, ok := rl.State[model]
	if !ok || stats == nil {
		stats = &UsageStats{DayStart: now}
		rl.State[model] = stats
	}

	decision.Stats = rl.snapshotLocked(model)

	if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD {
		decision.Code = DecisionRPDLimit
		decision.Reason = "daily request limit reached"
		decision.RetryAfter = nonNegativeDuration(stats.DayStart.Add(24 * time.Hour).Sub(now))
		return decision
	}

	if quota.MaxRPM > 0 && len(stats.Requests) >= quota.MaxRPM {
		decision.Code = DecisionRPMLimit
		decision.Reason = "per-minute request limit reached"
		if len(stats.Requests) > 0 {
			decision.RetryAfter = nonNegativeDuration(stats.Requests[0].Add(time.Minute).Sub(now))
		}
		return decision
	}

	if quota.MaxTPM > 0 {
		currentTokens := totalTokenCount(stats.Tokens)
		if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens {
			decision.Code = DecisionTPMLimit
			decision.Reason = "per-minute token limit reached"
			decision.RetryAfter = retryAfterForTokens(now, stats.Tokens, quota.MaxTPM, estimatedTokens)
			return decision
		}
	}

	decision.Allowed = true
	decision.Code = DecisionAllowed
	decision.Reason = "within quota"
	return decision
}

// snapshotLocked builds ModelStats for the provided model.
// Caller must hold rl.mu.
func (rl *RateLimiter) snapshotLocked(model string) ModelStats {
	stats := ModelStats{}

	if q, ok := rl.Quotas[model]; ok {
		stats.MaxRPM = q.MaxRPM
		stats.MaxTPM = q.MaxTPM
		stats.MaxRPD = q.MaxRPD
	}

	if s, ok := rl.State[model]; ok && s != nil {
		stats.RPM = len(s.Requests)
		stats.RPD = s.DayCount
		stats.DayStart = s.DayStart
		stats.TPM = totalTokenCount(s.Tokens)
	}
	return stats
}

// NewWithSQLite creates a SQLite-backed RateLimiter with Gemini defaults.
// The database is created at dbPath if it does not exist. Use Close() to
// release the database connection when finished.
//
// rl, err := NewWithSQLite("/tmp/ratelimits.db")
func NewWithSQLite(dbPath string) (*RateLimiter, error) {
	return NewWithSQLiteConfig(dbPath, Config{
		Providers: []Provider{ProviderGemini},

@@ -701,8 +574,6 @@ func NewWithSQLite(dbPath string) (*RateLimiter, error) {
// NewWithSQLiteConfig creates a SQLite-backed RateLimiter with custom config.
// The Backend field in cfg is ignored (always "sqlite"). Use Close() to
// release the database connection when finished.
//
// rl, err := NewWithSQLiteConfig("/tmp/ratelimits.db", Config{Providers: []Provider{ProviderOpenAI}})
func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
	store, err := newSQLiteStore(dbPath)
	if err != nil {

@@ -717,8 +588,6 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
// Close releases resources held by the RateLimiter. For YAML-backed
// limiters this is a no-op. For SQLite-backed limiters it closes the
// database connection.
//
// defer rl.Close()
func (rl *RateLimiter) Close() error {
	if rl.sqlite != nil {
		return rl.sqlite.close()

@@ -729,18 +598,16 @@ func (rl *RateLimiter) Close() error {
// MigrateYAMLToSQLite reads state from a YAML file and writes it to a new
// SQLite database. Both quotas and usage state are migrated. The SQLite
// database is created if it does not exist.
//
// err := MigrateYAMLToSQLite("ratelimits.yaml", "ratelimits.db")
func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
	// Load from YAML.
	content, err := readLocalFile(yamlPath)
	content, err := coreio.Local.Read(yamlPath)
	if err != nil {
		return core.E("ratelimit.MigrateYAMLToSQLite", "read", err)
		return coreerr.E("ratelimit.MigrateYAMLToSQLite", "read", err)
	}

	var rl RateLimiter
	if err := yaml.Unmarshal([]byte(content), &rl); err != nil {
		return core.E("ratelimit.MigrateYAMLToSQLite", "unmarshal", err)
		return coreerr.E("ratelimit.MigrateYAMLToSQLite", "unmarshal", err)
	}

	// Write to SQLite.

@@ -757,8 +624,6 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
}

// CountTokens calls the Google API to count tokens for a prompt.
//
// tokens, err := CountTokens(ctx, apiKey, "gemini-3-pro-preview", prompt)
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
	return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text)
}

@@ -766,7 +631,7 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) {
	requestURL, err := countTokensURL(baseURL, model)
	if err != nil {
		return 0, core.E("ratelimit.CountTokens", "build url", err)
		return 0, coreerr.E("ratelimit.CountTokens", "build url", err)
	}

	reqBody := map[string]any{

@@ -779,14 +644,14 @@ func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, ap
		},
	}

	jsonBody := core.JSONMarshal(reqBody)
	if !jsonBody.OK {
		return 0, core.E("ratelimit.CountTokens", "marshal request", resultError(jsonBody))
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, core.NewReader(string(jsonBody.Value.([]byte))))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody))
	if err != nil {
		return 0, core.E("ratelimit.CountTokens", "new request", err)
		return 0, coreerr.E("ratelimit.CountTokens", "new request", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", apiKey)

@@ -797,29 +662,23 @@ func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, ap

	resp, err := client.Do(req)
	if err != nil {
		return 0, core.E("ratelimit.CountTokens", "do request", err)
		return 0, coreerr.E("ratelimit.CountTokens", "do request", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit)
		if err != nil {
			return 0, core.E("ratelimit.CountTokens", "read error body", err)
			return 0, coreerr.E("ratelimit.CountTokens", "read error body", err)
		}
		return 0, core.E("ratelimit.CountTokens", core.Sprintf("api error status %d: %s", resp.StatusCode, body), nil)
	}

	body, err := readLimitedBody(resp.Body, countTokensSuccessBodyLimit)
	if err != nil {
		return 0, core.E("ratelimit.CountTokens", "decode response", err)
		return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil)
	}

	var result struct {
		TotalTokens int `json:"totalTokens"`
	}
	decode := core.JSONUnmarshalString(body, &result)
	if !decode.OK {
		return 0, core.E("ratelimit.CountTokens", "decode response", resultError(decode))
	if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil {
		return 0, coreerr.E("ratelimit.CountTokens", "decode response", err)
	}

	return result.TotalTokens, nil

@@ -834,15 +693,6 @@ func newConfiguredRateLimiter(cfg Config) *RateLimiter {
	return rl
}

func ensureMaps(rl *RateLimiter) {
	if rl.Quotas == nil {
		rl.Quotas = make(map[string]ModelQuota)
	}
	if rl.State == nil {
		rl.State = make(map[string]*UsageStats)
	}
}

func applyConfig(rl *RateLimiter, cfg Config) {
	profiles := DefaultProfiles()
	providers := cfg.Providers

@@ -861,20 +711,20 @@ func applyConfig(rl *RateLimiter, cfg Config) {
}

func normaliseBackend(backend string) (string, error) {
	switch core.Lower(core.Trim(backend)) {
	switch strings.ToLower(strings.TrimSpace(backend)) {
	case "", backendYAML:
		return backendYAML, nil
	case backendSQLite:
		return backendSQLite, nil
	default:
		return "", core.E("ratelimit.NewWithConfig", core.Sprintf("unknown backend %q", backend), nil)
		return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil)
	}
}

func defaultStatePath(backend string) (string, error) {
	home := currentHomeDir()
	if home == "" {
		return "", core.E("ratelimit.defaultStatePath", "home dir unavailable", nil)
	home, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}

	fileName := defaultYAMLStateFile

@@ -882,16 +732,7 @@ func defaultStatePath(backend string) (string, error) {
		fileName = defaultSQLiteStateFile
	}

	return core.Path(home, defaultStateDirName, fileName), nil
}

func currentHomeDir() string {
	for _, key := range []string{"CORE_HOME", "HOME", "home", "USERPROFILE"} {
		if value := core.Trim(core.Env(key)); value != "" {
			return value
		}
	}
	return ""
	return filepath.Join(home, defaultStateDirName, fileName), nil
}

func safeTokenSum(a, b int) int {

@@ -919,40 +760,9 @@ func safeTokenTotal(tokens []TokenEntry) int {
	return total
}

func retryAfterForTokens(now time.Time, tokens []TokenEntry, maxTPM, estimatedTokens int) time.Duration {
	if maxTPM <= 0 {
		return 0
	}

	deficit := totalTokenCount(tokens) + estimatedTokens - maxTPM
	if deficit <= 0 {
		return 0
	}

	remaining := deficit
	for _, entry := range tokens {
		if entry.Count < 0 {
			continue
		}
		remaining -= entry.Count
		if remaining <= 0 {
			return nonNegativeDuration(entry.Time.Add(time.Minute).Sub(now))
		}
	}

	return 0
}

func nonNegativeDuration(value time.Duration) time.Duration {
	if value < 0 {
		return 0
	}
	return value
}

func countTokensURL(baseURL, model string) (string, error) {
	if core.Trim(model) == "" {
		return "", core.E("ratelimit.countTokensURL", "empty model", nil)
	if strings.TrimSpace(model) == "" {
		return "", fmt.Errorf("empty model")
	}

	parsed, err := url.Parse(baseURL)

@@ -960,10 +770,10 @@ func countTokensURL(baseURL, model string) (string, error) {
		return "", err
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return "", core.E("ratelimit.countTokensURL", "invalid base url", nil)
		return "", fmt.Errorf("invalid base url")
	}

	return core.Concat(core.TrimSuffix(parsed.String(), "/"), "/v1beta/models/", url.PathEscape(model), ":countTokens"), nil
	return strings.TrimRight(parsed.String(), "/") + "/v1beta/models/" + url.PathEscape(model) + ":countTokens", nil
}

func readLimitedBody(r io.Reader, limit int64) (string, error) {

@@ -983,40 +793,3 @@ func readLimitedBody(r io.Reader, limit int64) (string, error) {
	}
	return result, nil
}

func readLocalFile(path string) (string, error) {
	var fs core.Fs
	result := fs.Read(path)
	if !result.OK {
		return "", resultError(result)
	}

	content, ok := result.Value.(string)
	if !ok {
		return "", core.E("ratelimit.readLocalFile", "read returned non-string", nil)
	}
	return content, nil
}

func writeLocalFile(path, content string) error {
	var fs core.Fs
	return resultError(fs.Write(path, content))
}

func ensureDir(path string) error {
	var fs core.Fs
	return resultError(fs.EnsureDir(path))
}

func resultError(result core.Result) error {
	if result.OK {
		return nil
	}
	if err, ok := result.Value.(error); ok {
		return err
	}
	if result.Value == nil {
		return nil
	}
	return core.E("ratelimit.resultError", core.Sprint(result.Value), nil)
}

@@ -1,95 +1,29 @@
// SPDX-License-Identifier: EUPL-1.2

package ratelimit

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"syscall"
	"testing"
	"time"

	core "dappco.re/go/core"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func testPath(parts ...string) string {
	return core.Path(parts...)
}

func pathExists(path string) bool {
	var fs core.Fs
	return fs.Exists(path)
}

func writeTestFile(tb testing.TB, path, content string) {
	tb.Helper()
	require.NoError(tb, writeLocalFile(path, content))
}

func ensureTestDir(tb testing.TB, path string) {
	tb.Helper()
	require.NoError(tb, ensureDir(path))
}

func setPathMode(tb testing.TB, path string, mode uint32) {
	tb.Helper()
	require.NoError(tb, syscall.Chmod(path, mode))
}

func overwriteTestFile(tb testing.TB, path, content string) {
	tb.Helper()

	var fs core.Fs
	writer := fs.Create(path)
	require.NoError(tb, resultError(writer))
	require.NoError(tb, resultError(core.WriteAll(writer.Value, content)))
}

func isRootUser() bool {
	return syscall.Geteuid() == 0
}

func repeatString(part string, count int) string {
	builder := core.NewBuilder()
	for i := 0; i < count; i++ {
		builder.WriteString(part)
	}
	return builder.String()
}

func substringCount(s, substr string) int {
	if substr == "" {
		return 0
	}
	return len(core.Split(s, substr)) - 1
}

func decodeJSONBody(tb testing.TB, r io.Reader, target any) {
	tb.Helper()

	data, err := io.ReadAll(r)
	require.NoError(tb, err)
	require.NoError(tb, resultError(core.JSONUnmarshal(data, target)))
}

func writeJSONBody(tb testing.TB, w io.Writer, value any) {
	tb.Helper()

	_, err := io.WriteString(w, core.JSONMarshalString(value))
	require.NoError(tb, err)
}

// newTestLimiter returns a RateLimiter with file path set to a temp directory.
func newTestLimiter(t *testing.T) *RateLimiter {
	t.Helper()
	rl, err := New()
	require.NoError(t, err)
	rl.filePath = testPath(t.TempDir(), "ratelimits.yaml")
	rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml")
	return rl
}

@@ -107,7 +41,7 @@ func (errReader) Read([]byte) (int, error) {

// --- Phase 0: CanSend boundary conditions ---

func TestRatelimit_CanSend_Good(t *testing.T) {
func TestCanSend(t *testing.T) {
	t.Run("fresh state allows send", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "test-model"

@@ -253,133 +187,9 @@ func TestRatelimit_CanSend_Good(t *testing.T) {
	})
}

// --- Phase 0: Decide surface area ---

func TestRatelimit_Decide_Good(t *testing.T) {
	t.Run("unknown model remains allowed with unknown code", func(t *testing.T) {
		rl := newTestLimiter(t)

		decision := rl.Decide("unknown-model", 50)

		assert.True(t, decision.Allowed)
		assert.Equal(t, DecisionUnknownModel, decision.Code)
		assert.Zero(t, decision.RetryAfter)
	})

	t.Run("unlimited quota reports unlimited decision", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "unlimited"
		rl.Quotas[model] = ModelQuota{}

		decision := rl.Decide(model, 100)

		assert.True(t, decision.Allowed)
		assert.Equal(t, DecisionUnlimited, decision.Code)
		assert.Equal(t, 0, decision.Stats.MaxRPM)
		assert.Equal(t, 0, decision.Stats.MaxTPM)
		assert.Equal(t, 0, decision.Stats.MaxRPD)
	})

	t.Run("rpd limit returns retry window", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "rpd-limit"
		now := time.Now()
		rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 2}
		rl.State[model] = &UsageStats{DayStart: now.Add(-23 * time.Hour), DayCount: 2}

		decision := rl.Decide(model, 10)

		assert.False(t, decision.Allowed)
		assert.Equal(t, DecisionRPDLimit, decision.Code)
		assert.InDelta(t, time.Hour.Seconds(), decision.RetryAfter.Seconds(), 2)
		assert.Equal(t, 2, decision.Stats.MaxRPD)
		assert.Equal(t, 2, decision.Stats.RPD)
	})

	t.Run("rpm limit includes retry-after estimate", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "rpm-limit"
		now := time.Now()
		rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000, MaxRPD: 5}
		rl.State[model] = &UsageStats{
			Requests: []time.Time{now.Add(-10 * time.Second)},
			Tokens:   []TokenEntry{{Time: now.Add(-10 * time.Second), Count: 10}},
			DayStart: now,
			DayCount: 1,
		}

		decision := rl.Decide(model, 5)

		assert.False(t, decision.Allowed)
		assert.Equal(t, DecisionRPMLimit, decision.Code)
		assert.InDelta(t, 50, decision.RetryAfter.Seconds(), 1)
	})

	t.Run("tpm limit surfaces earliest expiry", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "tpm-limit"
		now := time.Now()
		rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 10}
		rl.State[model] = &UsageStats{
			Requests: []time.Time{now.Add(-30 * time.Second)},
			Tokens: []TokenEntry{
				{Time: now.Add(-50 * time.Second), Count: 70},
				{Time: now.Add(-10 * time.Second), Count: 20},
			},
			DayStart: now,
			DayCount: 2,
		}

		decision := rl.Decide(model, 20)

		assert.False(t, decision.Allowed)
		assert.Equal(t, DecisionTPMLimit, decision.Code)
		assert.InDelta(t, 10, decision.RetryAfter.Seconds(), 1)
	})

	t.Run("allowed decision carries stats snapshot", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "decide-allowed"
		rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 200, MaxRPD: 3}
		now := time.Now()
		rl.State[model] = &UsageStats{
			Requests: []time.Time{now.Add(-5 * time.Second)},
			Tokens:   []TokenEntry{{Time: now.Add(-5 * time.Second), Count: 30}},
			DayStart: now,
			DayCount: 1,
		}

		decision := rl.Decide(model, 20)

		assert.True(t, decision.Allowed)
		assert.Equal(t, DecisionAllowed, decision.Code)
		assert.Equal(t, 1, decision.Stats.RPM)
		assert.Equal(t, 30, decision.Stats.TPM)
		assert.Equal(t, 1, decision.Stats.RPD)
		assert.Equal(t, 5, decision.Stats.MaxRPM)
		assert.Equal(t, 200, decision.Stats.MaxTPM)
		assert.Equal(t, 3, decision.Stats.MaxRPD)
	})

	t.Run("negative estimate returns invalid decision", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "neg"
		rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 50, MaxRPD: 5}

		decision := rl.Decide(model, -5)

		assert.False(t, decision.Allowed)
		assert.Equal(t, DecisionInvalidTokens, decision.Code)
		assert.Zero(t, decision.RetryAfter)
		require.Contains(t, rl.State, model)
		require.NotNil(t, rl.State[model])
		assert.Equal(t, 0, rl.State[model].DayCount)
	})
}

// --- Phase 0: Sliding window / prune tests ---

func TestRatelimit_Prune_Good(t *testing.T) {
func TestPrune(t *testing.T) {
	t.Run("removes old entries", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "test-prune"

@@ -494,7 +304,7 @@ func TestRatelimit_Prune_Good(t *testing.T) {

// --- Phase 0: RecordUsage ---

func TestRatelimit_RecordUsage_Good(t *testing.T) {
func TestRecordUsage(t *testing.T) {
	t.Run("records into fresh state", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "record-fresh"

@@ -565,7 +375,7 @@ func TestRatelimit_RecordUsage_Good(t *testing.T) {

// --- Phase 0: Reset ---

func TestRatelimit_Reset_Good(t *testing.T) {
func TestReset(t *testing.T) {
	t.Run("reset single model", func(t *testing.T) {
		rl := newTestLimiter(t)
		rl.RecordUsage("model-a", 10, 10)

@@ -599,7 +409,7 @@ func TestRatelimit_Reset_Good(t *testing.T) {

// --- Phase 0: WaitForCapacity ---

func TestRatelimit_WaitForCapacity_Good(t *testing.T) {
func TestWaitForCapacity(t *testing.T) {
	t.Run("context cancelled returns error", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "wait-cancel"

@@ -657,7 +467,7 @@ func TestRatelimit_WaitForCapacity_Good(t *testing.T) {
	})
}

func TestRatelimit_NilUsageStats_Ugly(t *testing.T) {
func TestNilUsageStats(t *testing.T) {
	t.Run("CanSend replaces nil state without panicking", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "nil-cansend"

@@ -704,7 +514,7 @@ func TestRatelimit_NilUsageStats_Ugly(t *testing.T) {

// --- Phase 0: Stats ---

func TestRatelimit_Stats_Good(t *testing.T) {
func TestStats(t *testing.T) {
	t.Run("returns stats for known model with usage", func(t *testing.T) {
		rl := newTestLimiter(t)
		model := "stats-test"

@@ -744,7 +554,7 @@ func TestRatelimit_Stats_Good(t *testing.T) {

// --- Phase 0: AllStats ---

func TestRatelimit_AllStats_Good(t *testing.T) {
func TestAllStats(t *testing.T) {
	t.Run("includes all default quotas plus state-only models", func(t *testing.T) {
		rl := newTestLimiter(t)
		rl.RecordUsage("gemini-3-pro-preview", 1000, 500)

@@ -802,10 +612,10 @@ func TestRatelimit_AllStats_Good(t *testing.T) {

// --- Phase 0: Persist and Load ---

func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {
func TestPersistAndLoad(t *testing.T) {
	t.Run("round-trip preserves state", func(t *testing.T) {
		tmpDir := t.TempDir()
		path := testPath(tmpDir, "ratelimits.yaml")
		path := filepath.Join(tmpDir, "ratelimits.yaml")

		rl1, err := New()
		require.NoError(t, err)

@@ -828,7 +638,7 @@ func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {

	t.Run("load from non-existent file is not an error", func(t *testing.T) {
		rl := newTestLimiter(t)
		rl.filePath = testPath(t.TempDir(), "does-not-exist.yaml")
		rl.filePath = filepath.Join(t.TempDir(), "does-not-exist.yaml")

		err := rl.Load()
		assert.NoError(t, err, "loading non-existent file should not error")

@@ -836,8 +646,8 @@ func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {

	t.Run("load from corrupt YAML returns error", func(t *testing.T) {
		tmpDir := t.TempDir()
		path := testPath(tmpDir, "corrupt.yaml")
		writeTestFile(t, path, "{{{{invalid yaml!!!!")
		path := filepath.Join(tmpDir, "corrupt.yaml")
		require.NoError(t, os.WriteFile(path, []byte("{{{{invalid yaml!!!!"), 0644))

		rl := newTestLimiter(t)
		rl.filePath = path

@@ -847,13 +657,13 @@ func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {
	})

	t.Run("load from unreadable file returns error", func(t *testing.T) {
		if isRootUser() {
		if os.Getuid() == 0 {
			t.Skip("chmod 000 does not restrict root")
		}
		tmpDir := t.TempDir()
		path := testPath(tmpDir, "unreadable.yaml")
		writeTestFile(t, path, "quotas: {}")
		setPathMode(t, path, 0o000)
		path := filepath.Join(tmpDir, "unreadable.yaml")
		require.NoError(t, os.WriteFile(path, []byte("quotas: {}"), 0644))
		require.NoError(t, os.Chmod(path, 0000))

		rl := newTestLimiter(t)
		rl.filePath = path

@@ -862,12 +672,12 @@ func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {
		assert.Error(t, err, "unreadable file should produce an error")

		// Clean up permissions for temp dir cleanup
		_ = syscall.Chmod(path, 0o644)
		_ = os.Chmod(path, 0644)
	})

	t.Run("persist to nested non-existent directory creates it", func(t *testing.T) {
		tmpDir := t.TempDir()
		path := testPath(tmpDir, "nested", "deep", "ratelimits.yaml")
		path := filepath.Join(tmpDir, "nested", "deep", "ratelimits.yaml")

		rl := newTestLimiter(t)
		rl.filePath = path

@@ -876,32 +686,32 @@ func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {
		err := rl.Persist()
		assert.NoError(t, err, "should create nested directories")

		assert.True(t, pathExists(path), "file should exist")
		_, statErr := os.Stat(path)
		assert.NoError(t, statErr, "file should exist")
	})

	t.Run("persist to unwritable directory returns error", func(t *testing.T) {
		if isRootUser() {
		if os.Getuid() == 0 {
			t.Skip("chmod 0555 does not restrict root")
		}
		tmpDir := t.TempDir()
		unwritable := testPath(tmpDir, "readonly")
		ensureTestDir(t, unwritable)
		setPathMode(t, unwritable, 0o555)
		unwritable := filepath.Join(tmpDir, "readonly")
		require.NoError(t, os.MkdirAll(unwritable, 0555))

		rl := newTestLimiter(t)
		rl.filePath = testPath(unwritable, "sub", "ratelimits.yaml")
		rl.filePath = filepath.Join(unwritable, "sub", "ratelimits.yaml")

		err := rl.Persist()
		assert.Error(t, err, "should fail when directory is unwritable")

		// Clean up
		_ = syscall.Chmod(unwritable, 0o755)
		_ = os.Chmod(unwritable, 0755)
	})
}

// --- Phase 0: Default quotas ---

func TestRatelimit_DefaultQuotas_Good(t *testing.T) {
func TestDefaultQuotas(t *testing.T) {
	rl := newTestLimiter(t)

	tests := []struct {

@@ -930,7 +740,7 @@ func TestRatelimit_DefaultQuotas_Good(t *testing.T) {

// --- Phase 0: Concurrent access (race test) ---

func TestRatelimit_ConcurrentAccess_Good(t *testing.T) {
func TestConcurrentAccess(t *testing.T) {
	rl := newTestLimiter(t)
	model := "concurrent-test"
	rl.Quotas[model] = ModelQuota{MaxRPM: 1000, MaxTPM: 10000000, MaxRPD: 10000}

@@ -956,7 +766,7 @@ func TestRatelimit_ConcurrentAccess_Good(t *testing.T) {
	assert.Equal(t, expected, stats.RPD, "all recordings should be counted")
}

func TestRatelimit_ConcurrentResetAndRecord_Ugly(t *testing.T) {
func TestConcurrentResetAndRecord(t *testing.T) {
	rl := newTestLimiter(t)
	model := "concurrent-reset"
	rl.Quotas[model] = ModelQuota{MaxRPM: 10000, MaxTPM: 100000000, MaxRPD: 100000}

@@ -994,7 +804,7 @@ func TestRatelimit_ConcurrentResetAndRecord_Ugly(t *testing.T) {
	// No assertion needed -- if we get here without -race flagging, mutex is sound
}

func TestRatelimit_BackgroundPrune_Good(t *testing.T) {
func TestBackgroundPrune(t *testing.T) {
	rl := newTestLimiter(t)
	model := "prune-me"
	rl.Quotas[model] = ModelQuota{MaxRPM: 100}

@@ -1033,7 +843,7 @@ func TestRatelimit_BackgroundPrune_Good(t *testing.T) {

// --- Phase 0: CountTokens (with mock HTTP server) ---

func TestRatelimit_CountTokens_Ugly(t *testing.T) {
func TestCountTokens(t *testing.T) {
	t.Run("successful token count", func(t *testing.T) {
		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			assert.Equal(t, http.MethodPost, r.Method)

@@ -1048,13 +858,13 @@ func TestRatelimit_CountTokens_Ugly(t *testing.T) {
				} `json:"parts"`
			} `json:"contents"`
		}
		decodeJSONBody(t, r.Body, &body)
		require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
		require.Len(t, body.Contents, 1)
		require.Len(t, body.Contents[0].Parts, 1)
		assert.Equal(t, "hello", body.Contents[0].Parts[0].Text)

		w.Header().Set("Content-Type", "application/json")
		writeJSONBody(t, w, map[string]int{"totalTokens": 42})
		require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}))
	}))
	defer server.Close()

@@ -1068,7 +878,7 @@ func TestRatelimit_CountTokens_Ugly(t *testing.T) {
		assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath())
		assert.Empty(t, r.URL.RawQuery)
		w.Header().Set("Content-Type", "application/json")
		writeJSONBody(t, w, map[string]int{"totalTokens": 7})
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7}))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -1078,10 +888,10 @@ func TestRatelimit_CountTokens_Ugly(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("API error body is truncated", func(t *testing.T) {
|
||||
largeBody := repeatString("x", countTokensErrorBodyLimit+256)
|
||||
largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := io.WriteString(w, largeBody)
|
||||
_, err := fmt.Fprint(w, largeBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
|
@ -1089,7 +899,7 @@ func TestRatelimit_CountTokens_Ugly(t *testing.T) {
|
|||
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "api error status 401")
|
||||
assert.True(t, substringCount(err.Error(), "x") < len(largeBody), "error body should be bounded")
|
||||
assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded")
|
||||
assert.Contains(t, err.Error(), "...")
|
||||
})
|
||||
|
||||
|
|
```diff
@@ -1143,7 +953,7 @@ func TestRatelimit_CountTokens_Ugly(t *testing.T) {
 	t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) {
 		server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			w.Header().Set("Content-Type", "application/json")
-			writeJSONBody(t, w, map[string]int{"totalTokens": 11})
+			require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11}))
 		}))
 		defer server.Close()

@@ -1159,8 +969,8 @@ func TestRatelimit_CountTokens_Ugly(t *testing.T) {
 	})
 }

-func TestRatelimit_PersistSkipsNilState_Good(t *testing.T) {
-	path := testPath(t.TempDir(), "nil-state.yaml")
+func TestPersistSkipsNilState(t *testing.T) {
+	path := filepath.Join(t.TempDir(), "nil-state.yaml")

 	rl, err := New()
 	require.NoError(t, err)
@@ -1176,7 +986,7 @@ func TestRatelimit_PersistSkipsNilState_Good(t *testing.T) {
 	assert.NotContains(t, rl2.State, "nil-model")
 }

-func TestRatelimit_TokenTotals_Good(t *testing.T) {
+func TestTokenTotals(t *testing.T) {
 	maxInt := int(^uint(0) >> 1)

 	assert.Equal(t, 25, safeTokenSum(-100, 25))
@@ -1246,7 +1056,7 @@ func BenchmarkCanSendConcurrent(b *testing.B) {

 // --- Phase 1: Provider profiles and NewWithConfig ---

-func TestRatelimit_DefaultProfiles_Good(t *testing.T) {
+func TestDefaultProfiles(t *testing.T) {
 	profiles := DefaultProfiles()

 	t.Run("contains all four providers", func(t *testing.T) {
@@ -1287,10 +1097,10 @@ func TestRatelimit_DefaultProfiles_Good(t *testing.T) {
 	})
 }

-func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
+func TestNewWithConfig(t *testing.T) {
 	t.Run("empty config defaults to Gemini", func(t *testing.T) {
 		rl, err := NewWithConfig(Config{
-			FilePath: testPath(t.TempDir(), "test.yaml"),
+			FilePath: filepath.Join(t.TempDir(), "test.yaml"),
 		})
 		require.NoError(t, err)

@@ -1300,7 +1110,7 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 	t.Run("single provider loads only its models", func(t *testing.T) {
 		rl, err := NewWithConfig(Config{
-			FilePath: testPath(t.TempDir(), "test.yaml"),
+			FilePath: filepath.Join(t.TempDir(), "test.yaml"),
 			Providers: []Provider{ProviderOpenAI},
 		})
 		require.NoError(t, err)
@@ -1314,7 +1124,7 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 	t.Run("multiple providers merge models", func(t *testing.T) {
 		rl, err := NewWithConfig(Config{
-			FilePath: testPath(t.TempDir(), "test.yaml"),
+			FilePath: filepath.Join(t.TempDir(), "test.yaml"),
 			Providers: []Provider{ProviderGemini, ProviderAnthropic},
 		})
 		require.NoError(t, err)
@@ -1330,7 +1140,7 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 	t.Run("explicit quotas override provider defaults", func(t *testing.T) {
 		rl, err := NewWithConfig(Config{
-			FilePath: testPath(t.TempDir(), "test.yaml"),
+			FilePath: filepath.Join(t.TempDir(), "test.yaml"),
 			Providers: []Provider{ProviderGemini},
 			Quotas: map[string]ModelQuota{
 				"gemini-3-pro-preview": {MaxRPM: 999, MaxTPM: 888, MaxRPD: 777},
@@ -1346,7 +1156,7 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 	t.Run("explicit quotas without providers", func(t *testing.T) {
 		rl, err := NewWithConfig(Config{
-			FilePath: testPath(t.TempDir(), "test.yaml"),
+			FilePath: filepath.Join(t.TempDir(), "test.yaml"),
 			Quotas: map[string]ModelQuota{
 				"my-custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
 			},
@@ -1359,7 +1169,7 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 	})

 	t.Run("custom file path is respected", func(t *testing.T) {
-		customPath := testPath(t.TempDir(), "custom", "limits.yaml")
+		customPath := filepath.Join(t.TempDir(), "custom", "limits.yaml")
 		rl, err := NewWithConfig(Config{
 			FilePath: customPath,
 			Providers: []Provider{ProviderLocal},
@@ -1369,12 +1179,13 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 		rl.RecordUsage("test", 1, 1)
 		require.NoError(t, rl.Persist())

-		assert.True(t, pathExists(customPath), "file should be created at custom path")
+		_, statErr := os.Stat(customPath)
+		assert.NoError(t, statErr, "file should be created at custom path")
 	})

 	t.Run("unknown provider is silently skipped", func(t *testing.T) {
 		rl, err := NewWithConfig(Config{
-			FilePath: testPath(t.TempDir(), "test.yaml"),
+			FilePath: filepath.Join(t.TempDir(), "test.yaml"),
 			Providers: []Provider{"nonexistent-provider"},
 		})
 		require.NoError(t, err)
@@ -1383,7 +1194,7 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 	t.Run("local provider with custom quotas", func(t *testing.T) {
 		rl, err := NewWithConfig(Config{
-			FilePath: testPath(t.TempDir(), "test.yaml"),
+			FilePath: filepath.Join(t.TempDir(), "test.yaml"),
 			Providers: []Provider{ProviderLocal},
 			Quotas: map[string]ModelQuota{
 				"llama-3.3-70b": {MaxRPM: 5, MaxTPM: 50000, MaxRPD: 0},
```
```diff
@@ -1413,11 +1224,11 @@ func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
 		rl, err := NewWithConfig(Config{})
 		require.NoError(t, err)
-		assert.Equal(t, testPath(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
+		assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
 	})
 }

-func TestRatelimit_NewBackwardCompatibility_Good(t *testing.T) {
+func TestNewBackwardCompatibility(t *testing.T) {
 	// New() should produce the exact same result as before Phase 1
 	rl, err := New()
 	require.NoError(t, err)
@@ -1440,7 +1251,7 @@ func TestRatelimit_NewBackwardCompatibility_Good(t *testing.T) {
 	}
 }

-func TestRatelimit_SetQuota_Good(t *testing.T) {
+func TestSetQuota(t *testing.T) {
 	t.Run("adds new model quota", func(t *testing.T) {
 		rl := newTestLimiter(t)
 		rl.SetQuota("custom-model", ModelQuota{MaxRPM: 42, MaxTPM: 9999, MaxRPD: 100})
@@ -1468,7 +1279,7 @@ func TestRatelimit_SetQuota_Good(t *testing.T) {
 		wg.Add(1)
 		go func(n int) {
 			defer wg.Done()
-			model := core.Sprintf("model-%d", n)
+			model := fmt.Sprintf("model-%d", n)
 			rl.SetQuota(model, ModelQuota{MaxRPM: n, MaxTPM: n * 100, MaxRPD: n * 10})
 		}(i)
 	}
@@ -1478,7 +1289,7 @@ func TestRatelimit_SetQuota_Good(t *testing.T) {
 	})
 }

-func TestRatelimit_AddProvider_Good(t *testing.T) {
+func TestAddProvider(t *testing.T) {
 	t.Run("adds OpenAI models to existing limiter", func(t *testing.T) {
 		rl := newTestLimiter(t) // starts with Gemini defaults
 		geminiCount := len(rl.Quotas)
@@ -1540,7 +1351,7 @@ func TestRatelimit_AddProvider_Good(t *testing.T) {
 	})
 }

-func TestRatelimit_ProviderConstants_Good(t *testing.T) {
+func TestProviderConstants(t *testing.T) {
 	// Verify the string values are stable (they may be used in YAML configs)
 	assert.Equal(t, Provider("gemini"), ProviderGemini)
 	assert.Equal(t, Provider("openai"), ProviderOpenAI)
@@ -1550,7 +1361,7 @@ func TestRatelimit_ProviderConstants_Good(t *testing.T) {

 // --- Phase 0 addendum: Additional concurrent and multi-model race tests ---

-func TestRatelimit_ConcurrentMultipleModels_Good(t *testing.T) {
+func TestConcurrentMultipleModels(t *testing.T) {
 	rl := newTestLimiter(t)
 	models := []string{"model-a", "model-b", "model-c", "model-d", "model-e"}
 	for _, m := range models {
@@ -1580,9 +1391,9 @@ func TestRatelimit_ConcurrentMultipleModels_Good(t *testing.T) {
 	}
 }

-func TestRatelimit_ConcurrentPersistAndLoad_Ugly(t *testing.T) {
+func TestConcurrentPersistAndLoad(t *testing.T) {
 	tmpDir := t.TempDir()
-	path := testPath(tmpDir, "concurrent.yaml")
+	path := filepath.Join(tmpDir, "concurrent.yaml")

 	rl := newTestLimiter(t)
 	rl.filePath = path
@@ -1614,7 +1425,7 @@ func TestRatelimit_ConcurrentPersistAndLoad_Ugly(t *testing.T) {
 	// No panics or data races = pass
 }

-func TestRatelimit_ConcurrentAllStatsAndRecordUsage_Good(t *testing.T) {
+func TestConcurrentAllStatsAndRecordUsage(t *testing.T) {
 	rl := newTestLimiter(t)
 	models := []string{"stats-a", "stats-b", "stats-c"}
 	for _, m := range models {
@@ -1645,7 +1456,7 @@ func TestRatelimit_ConcurrentAllStatsAndRecordUsage_Good(t *testing.T) {
 	wg.Wait()
 }

-func TestRatelimit_ConcurrentWaitForCapacityAndRecordUsage_Good(t *testing.T) {
+func TestConcurrentWaitForCapacityAndRecordUsage(t *testing.T) {
 	rl := newTestLimiter(t)
 	model := "race-wait"
 	rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 10000000, MaxRPD: 10000}
@@ -1742,7 +1553,7 @@ func BenchmarkAllStats(b *testing.B) {

 func BenchmarkPersist(b *testing.B) {
 	tmpDir := b.TempDir()
-	path := testPath(tmpDir, "bench.yaml")
+	path := filepath.Join(tmpDir, "bench.yaml")

 	rl, _ := New()
 	rl.filePath = path
@@ -1763,10 +1574,10 @@ func BenchmarkPersist(b *testing.B) {
 	}
 }

-func TestRatelimit_EndToEndMultiProvider_Good(t *testing.T) {
+func TestEndToEndMultiProvider(t *testing.T) {
 	// Simulate a real-world scenario: limiter for both Gemini and Anthropic
 	rl, err := NewWithConfig(Config{
-		FilePath: testPath(t.TempDir(), "multi.yaml"),
+		FilePath: filepath.Join(t.TempDir(), "multi.yaml"),
 		Providers: []Provider{ProviderGemini, ProviderAnthropic},
 	})
 	require.NoError(t, err)
```
## specs/RFC.md (154 lines deleted)

The entire file is removed by this commit; its former content follows.

@@ -1,154 +0,0 @@
# ratelimit

**Import:** `dappco.re/go/core/go-ratelimit`
**Files:** 2

## Types

### `Provider`

`type Provider string`

`Provider` identifies an LLM provider and is used to select built-in quota profiles. The package defines four exported provider values: `ProviderGemini`, `ProviderOpenAI`, `ProviderAnthropic`, and `ProviderLocal`.

### `ModelQuota`

`type ModelQuota struct`

`ModelQuota` defines the rate limits for a single model. A value of `0` means the corresponding limit is unlimited.

- `MaxRPM int`: requests per minute.
- `MaxTPM int`: tokens per minute.
- `MaxRPD int`: requests per rolling 24-hour window.

### `ProviderProfile`

`type ProviderProfile struct`

`ProviderProfile` bundles a provider identifier with the default quota table for that provider.

- `Provider Provider`: the provider that owns the profile.
- `Models map[string]ModelQuota`: built-in quotas keyed by model name.

### `Config`

`type Config struct`

`Config` controls `RateLimiter` initialisation, backend selection, and default quotas.

- `FilePath string`: overrides the default persistence path. When empty, `NewWithConfig` resolves a default path under `~/.core`, using `ratelimits.yaml` for the YAML backend and `ratelimits.db` for the SQLite backend.
- `Backend string`: selects the persistence backend. `NewWithConfig` accepts `""` or `"yaml"` for YAML and `"sqlite"` for SQLite. `NewWithSQLiteConfig` ignores this field and always uses SQLite.
- `Quotas map[string]ModelQuota`: explicit per-model quotas, merged on top of any provider defaults loaded from `Providers`.
- `Providers []Provider`: provider profiles to load from `DefaultProfiles`. If both `Providers` and `Quotas` are empty, Gemini defaults are used.

### `TokenEntry`

`type TokenEntry struct`

`TokenEntry` records a single token-usage event.

- `Time time.Time`: when the token event was recorded.
- `Count int`: how many tokens were counted for that event.

### `UsageStats`

`type UsageStats struct`

`UsageStats` stores the in-memory usage history for one model.

- `Requests []time.Time`: request timestamps inside the sliding one-minute window.
- `Tokens []TokenEntry`: token usage entries inside the sliding one-minute window.
- `DayStart time.Time`: the start of the current rolling 24-hour window.
- `DayCount int`: the number of requests recorded in the current rolling 24-hour window.

### `RateLimiter`

`type RateLimiter struct`

`RateLimiter` is the package's main concurrency-safe limiter. It stores quotas, tracks usage state per model, supports YAML or SQLite persistence, and prunes expired state as part of normal operations.

- `Quotas map[string]ModelQuota`: configured per-model limits. If a model has no quota entry, `CanSend` allows it.
- `State map[string]*UsageStats`: tracked usage windows keyed by model name.

### `ModelStats`

`type ModelStats struct`

`ModelStats` is the read-only snapshot returned by `Stats`, `AllStats`, and `Iter`.

- `RPM int`: current requests counted in the one-minute window.
- `MaxRPM int`: configured requests-per-minute limit.
- `TPM int`: current tokens counted in the one-minute window.
- `MaxTPM int`: configured tokens-per-minute limit.
- `RPD int`: current requests counted in the rolling 24-hour window.
- `MaxRPD int`: configured requests-per-day limit.
- `DayStart time.Time`: start of the current rolling 24-hour window; zero if the model has no recorded state.

### `DecisionCode`

`type DecisionCode string`

`DecisionCode` enumerates machine-readable allow/deny codes returned by `Decide`. Defined values: `ok`, `unknown_model`, `unlimited`, `invalid_tokens`, `rpd_exceeded`, `rpm_exceeded`, and `tpm_exceeded`.

### `Decision`

`type Decision struct`

`Decision` bundles the outcome from `Decide`: whether the request is allowed, a `DecisionCode`, a human-readable `Reason`, an optional `RetryAfter` duration when throttled, and a `ModelStats` snapshot taken at evaluation time.

## Functions

### `DefaultProfiles() map[Provider]ProviderProfile`

Returns a fresh map of built-in quota profiles for the supported providers. The returned map currently contains Gemini, OpenAI, Anthropic, and Local profiles. Because a new map is built on each call, callers can modify the result without mutating shared package state.

### `New() (*RateLimiter, error)`

Creates a new YAML-backed `RateLimiter` with Gemini defaults, equivalent to calling `NewWithConfig(Config{Providers: []Provider{ProviderGemini}})`. It initialises in-memory state only; it does not automatically restore persisted data, so callers that want previous state must call `Load()`.

### `NewWithConfig(cfg Config) (*RateLimiter, error)`

Creates a `RateLimiter` from explicit configuration. If `cfg.Backend` is empty it uses the YAML backend for backward compatibility. If both `cfg.Providers` and `cfg.Quotas` are empty, Gemini defaults are loaded. When `cfg.FilePath` is empty, the constructor resolves a default path under `~/.core`; for the implicit SQLite path it also ensures the parent directory exists. Like `New`, it does not call `Load()` automatically.

### `func (rl *RateLimiter) SetQuota(model string, quota ModelQuota)`

Adds or replaces the quota for `model` in memory. The change affects later `CanSend`, `Stats`, and related calls immediately, but it is not persisted until `Persist()` is called.

### `func (rl *RateLimiter) AddProvider(provider Provider)`

Loads the built-in quota profile for `provider` and copies its model quotas into `rl.Quotas`. Any existing quota entries for matching model names are overwritten. Unknown provider values are ignored.

### `func (rl *RateLimiter) Load() error`

Loads persisted state into the limiter. For the YAML backend, it reads the configured file and unmarshals the stored quotas and state; a missing file is treated as empty state and returns `nil`. For the SQLite backend, it loads persisted quotas and usage state from the database. If the database has stored quotas, those quotas replace the in-memory configuration; if no stored quotas exist, the current in-memory quotas are retained. In both cases, the loaded usage state replaces the current in-memory state.

### `func (rl *RateLimiter) Persist() error`

Writes the current quotas and usage state to the configured backend. The method clones the in-memory snapshot while holding the lock, then performs I/O after releasing it. YAML persistence serialises the quota and state maps into the state file. SQLite persistence writes a full snapshot transactionally so quotas and usage move together.

### `func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func()`

Starts a background goroutine that prunes expired entries from every tracked model on the supplied interval and returns a stop function. If `interval <= 0`, it returns a no-op stop function and does not start a goroutine.
### `func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool`

Reports whether a request for `model` can be sent without violating the configured limits. Negative token estimates are rejected. Models with no configured quota are allowed. If all three limits for a known model are `0`, the model is treated as unlimited. Before evaluating the request, the limiter prunes entries older than one minute and resets the rolling daily counter once its 24-hour window has elapsed. The method then checks requests per day, requests per minute, and tokens per minute against the estimated token count.
### `func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision`

Returns a structured allow/deny decision for the estimated request. The result includes a `DecisionCode`, a human-readable `Reason`, optional `RetryAfter` guidance when throttled, and a `ModelStats` snapshot. It prunes expired state and initialises empty state for configured models, but does not record usage.

### `func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)`

Records a successful request for `model`. The limiter prunes stale entries first, creates state for the model if needed, appends the current timestamp to the request window, appends a token entry containing the combined prompt and output token count, and increments the rolling daily counter. Negative token values are ignored by the internal token summation logic rather than reducing the recorded total.
### `func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error`

Blocks until `Decide(model, tokens)` allows the request or `ctx` is cancelled. The method uses the `RetryAfter` hint from `Decide` to sleep between checks, falling back to one-second polling when no hint is available. If `tokens` is negative, it returns an error immediately.
### `func (rl *RateLimiter) Reset(model string)`

Clears usage state without changing quotas. If `model` is empty, it drops all tracked state; otherwise it removes state only for the named model.

### `func (rl *RateLimiter) Models() iter.Seq[string]`

Returns a sorted iterator of all model names currently known to the limiter. The result is the union of model names present in `rl.Quotas` and `rl.State`, so it includes models that only have stored state as well as models that only have configured quotas.

### `func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats]`

Returns a sorted iterator of model names paired with their current `ModelStats` snapshots. Internally it builds the snapshot via `AllStats()` and yields entries in lexical model-name order.

### `func (rl *RateLimiter) Stats(model string) ModelStats`

Returns the current snapshot for a single model after pruning expired entries. The result includes both current usage and configured maxima. If the model has no configured quota, the maximum fields are zero. If the model has no recorded state, the usage counters are zero and `DayStart` is the zero time.

### `func (rl *RateLimiter) AllStats() map[string]ModelStats`

Returns a snapshot for every tracked model. The returned map includes model names found in either `rl.Quotas` or `rl.State`. Each model is pruned before its snapshot is computed, so expired one-minute entries are removed and stale daily windows are reset as part of the call.

### `NewWithSQLite(dbPath string) (*RateLimiter, error)`

Creates a SQLite-backed `RateLimiter` with Gemini defaults and opens or creates the database at `dbPath`. Like the YAML constructors, it initialises in-memory configuration but does not automatically call `Load()`. Callers should `defer rl.Close()` when they are done with the limiter.

### `NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error)`

Creates a SQLite-backed `RateLimiter` using `cfg` for provider and quota configuration. The `Backend` field in `cfg` is ignored because this constructor always uses SQLite. The database is opened or created at `dbPath`, and callers should `defer rl.Close()` to release the connection. Existing persisted data is not loaded until `Load()` is called.

### `func (rl *RateLimiter) Close() error`

Releases resources held by the limiter. For YAML-backed limiters this is a no-op that returns `nil`. For SQLite-backed limiters it closes the underlying database connection.

### `MigrateYAMLToSQLite(yamlPath, sqlitePath string) error`

Reads a YAML state file into a temporary `RateLimiter` and writes its quotas and usage state into a SQLite database. The SQLite database is created if it does not exist. The migration writes a complete snapshot, so any existing SQLite snapshot tables are replaced by the imported data.

### `CountTokens(ctx context.Context, apiKey, model, text string) (int, error)`

Calls Google's Gemini `countTokens` API for `model` and returns the `totalTokens` value from the response. The function uses `http.DefaultClient`, posts to the Generative Language API base URL, and sends the API key through the `x-goog-api-key` header. It validates that `model` is non-empty, truncates oversized response bodies when building error messages, and wraps transport, request-building, and decoding failures with package-scoped errors.
## sqlite.go (71 lines changed)
```diff
@@ -1,12 +1,11 @@
 // SPDX-License-Identifier: EUPL-1.2

 package ratelimit

 import (
 	"database/sql"
 	"fmt"
 	"time"

-	core "dappco.re/go/core"
+	coreerr "dappco.re/go/core/log"
 	_ "modernc.org/sqlite"
 )

@@ -21,7 +20,7 @@ type sqliteStore struct {
 func newSQLiteStore(dbPath string) (*sqliteStore, error) {
 	db, err := sql.Open("sqlite", dbPath)
 	if err != nil {
-		return nil, core.E("ratelimit.newSQLiteStore", "open", err)
+		return nil, coreerr.E("ratelimit.newSQLiteStore", "open", err)
 	}

 	// Single connection for PRAGMA consistency.
@@ -29,11 +28,11 @@ func newSQLiteStore(dbPath string) (*sqliteStore, error) {

 	if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
 		db.Close()
-		return nil, core.E("ratelimit.newSQLiteStore", "WAL", err)
+		return nil, coreerr.E("ratelimit.newSQLiteStore", "WAL", err)
 	}
 	if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil {
 		db.Close()
-		return nil, core.E("ratelimit.newSQLiteStore", "busy_timeout", err)
+		return nil, coreerr.E("ratelimit.newSQLiteStore", "busy_timeout", err)
 	}

 	if err := createSchema(db); err != nil {
@@ -73,7 +72,7 @@ func createSchema(db *sql.DB) error {

 	for _, stmt := range stmts {
 		if _, err := db.Exec(stmt); err != nil {
-			return core.E("ratelimit.createSchema", "exec", err)
+			return coreerr.E("ratelimit.createSchema", "exec", err)
 		}
 	}
 	return nil
@@ -83,12 +82,12 @@ func createSchema(db *sql.DB) error {
 func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
 	tx, err := s.db.Begin()
 	if err != nil {
-		return core.E("ratelimit.saveQuotas", "begin", err)
+		return coreerr.E("ratelimit.saveQuotas", "begin", err)
 	}
 	defer tx.Rollback()

 	if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
-		return core.E("ratelimit.saveQuotas", "clear", err)
+		return coreerr.E("ratelimit.saveQuotas", "clear", err)
 	}

 	if err := insertQuotas(tx, quotas); err != nil {
@@ -102,7 +101,7 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
 func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
 	rows, err := s.db.Query("SELECT model, max_rpm, max_tpm, max_rpd FROM quotas")
 	if err != nil {
-		return nil, core.E("ratelimit.loadQuotas", "query", err)
+		return nil, coreerr.E("ratelimit.loadQuotas", "query", err)
 	}
 	defer rows.Close()

@@ -111,12 +110,12 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
 		var model string
 		var q ModelQuota
 		if err := rows.Scan(&model, &q.MaxRPM, &q.MaxTPM, &q.MaxRPD); err != nil {
-			return nil, core.E("ratelimit.loadQuotas", "scan", err)
+			return nil, coreerr.E("ratelimit.loadQuotas", "scan", err)
 		}
 		result[model] = q
 	}
 	if err := rows.Err(); err != nil {
-		return nil, core.E("ratelimit.loadQuotas", "rows", err)
+		return nil, coreerr.E("ratelimit.loadQuotas", "rows", err)
 	}
 	return result, nil
 }
@@ -125,7 +124,7 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
 func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error {
 	tx, err := s.db.Begin()
 	if err != nil {
-		return core.E("ratelimit.saveSnapshot", "begin", err)
+		return coreerr.E("ratelimit.saveSnapshot", "begin", err)
 	}
 	defer tx.Rollback()

@@ -149,7 +148,7 @@ func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[strin
 func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
 	tx, err := s.db.Begin()
 	if err != nil {
-		return core.E("ratelimit.saveState", "begin", err)
+		return coreerr.E("ratelimit.saveState", "begin", err)
 	}
 	defer tx.Rollback()

@@ -167,17 +166,17 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
 func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
 	if includeQuotas {
 		if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
-			return core.E("ratelimit.saveSnapshot", "clear quotas", err)
+			return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err)
 		}
 	}
 	if _, err := tx.Exec("DELETE FROM requests"); err != nil {
-		return core.E("ratelimit.saveState", "clear requests", err)
+		return coreerr.E("ratelimit.saveState", "clear requests", err)
 	}
 	if _, err := tx.Exec("DELETE FROM tokens"); err != nil {
-		return core.E("ratelimit.saveState", "clear tokens", err)
+		return coreerr.E("ratelimit.saveState", "clear tokens", err)
 	}
 	if _, err := tx.Exec("DELETE FROM daily"); err != nil {
-		return core.E("ratelimit.saveState", "clear daily", err)
+		return coreerr.E("ratelimit.saveState", "clear daily", err)
 	}
 	return nil
 }
@@ -185,13 +184,13 @@ func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
 func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
 	stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)")
 	if err != nil {
-		return core.E("ratelimit.saveQuotas", "prepare", err)
+		return coreerr.E("ratelimit.saveQuotas", "prepare", err)
 	}
 	defer stmt.Close()

 	for model, q := range quotas {
 		if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
-			return core.E("ratelimit.saveQuotas", core.Concat("exec ", model), err)
+			return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
 		}
 	}
 	return nil
@@ -200,19 +199,19 @@ func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
 func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
 	reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
 	if err != nil {
-		return core.E("ratelimit.saveState", "prepare requests", err)
+		return coreerr.E("ratelimit.saveState", "prepare requests", err)
 	}
 	defer reqStmt.Close()

 	tokStmt, err := tx.Prepare("INSERT INTO tokens (model, ts, count) VALUES (?, ?, ?)")
 	if err != nil {
-		return core.E("ratelimit.saveState", "prepare tokens", err)
+		return coreerr.E("ratelimit.saveState", "prepare tokens", err)
 	}
 	defer tokStmt.Close()

 	dayStmt, err := tx.Prepare("INSERT INTO daily (model, day_start, day_count) VALUES (?, ?, ?)")
 	if err != nil {
-		return core.E("ratelimit.saveState", "prepare daily", err)
+		return coreerr.E("ratelimit.saveState", "prepare daily", err)
 	}
 	defer dayStmt.Close()

@@ -222,16 +221,16 @@ func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
 	}
 	for _, t := range stats.Requests {
```
|
||||
if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
|
||||
return core.E("ratelimit.saveState", core.Concat("insert request ", model), err)
|
||||
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err)
|
||||
}
|
||||
}
|
||||
for _, te := range stats.Tokens {
|
||||
if _, err := tokStmt.Exec(model, te.Time.UnixNano(), te.Count); err != nil {
|
||||
return core.E("ratelimit.saveState", core.Concat("insert token ", model), err)
|
||||
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert token %s", model), err)
|
||||
}
|
||||
}
|
||||
if _, err := dayStmt.Exec(model, stats.DayStart.UnixNano(), stats.DayCount); err != nil {
|
||||
return core.E("ratelimit.saveState", core.Concat("insert daily ", model), err)
|
||||
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
|
@ -239,7 +238,7 @@ func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
|
|||
|
||||
func commitTx(tx *sql.Tx, scope string) error {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return core.E(scope, "commit", err)
|
||||
return coreerr.E(scope, "commit", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@@ -251,7 +250,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
 	// Load daily counters first (these define which models have state).
 	rows, err := s.db.Query("SELECT model, day_start, day_count FROM daily")
 	if err != nil {
-		return nil, core.E("ratelimit.loadState", "query daily", err)
+		return nil, coreerr.E("ratelimit.loadState", "query daily", err)
 	}
 	defer rows.Close()

@@ -260,7 +259,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
 		var dayStartNano int64
 		var dayCount int
 		if err := rows.Scan(&model, &dayStartNano, &dayCount); err != nil {
-			return nil, core.E("ratelimit.loadState", "scan daily", err)
+			return nil, coreerr.E("ratelimit.loadState", "scan daily", err)
 		}
 		result[model] = &UsageStats{
 			DayStart: time.Unix(0, dayStartNano),
@@ -268,13 +267,13 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
 		}
 	}
 	if err := rows.Err(); err != nil {
-		return nil, core.E("ratelimit.loadState", "daily rows", err)
+		return nil, coreerr.E("ratelimit.loadState", "daily rows", err)
 	}

 	// Load requests.
 	reqRows, err := s.db.Query("SELECT model, ts FROM requests ORDER BY ts")
 	if err != nil {
-		return nil, core.E("ratelimit.loadState", "query requests", err)
+		return nil, coreerr.E("ratelimit.loadState", "query requests", err)
 	}
 	defer reqRows.Close()

@@ -282,7 +281,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
 		var model string
 		var tsNano int64
 		if err := reqRows.Scan(&model, &tsNano); err != nil {
-			return nil, core.E("ratelimit.loadState", "scan requests", err)
+			return nil, coreerr.E("ratelimit.loadState", "scan requests", err)
 		}
 		if _, ok := result[model]; !ok {
 			result[model] = &UsageStats{}
@@ -290,13 +289,13 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
 		result[model].Requests = append(result[model].Requests, time.Unix(0, tsNano))
 	}
 	if err := reqRows.Err(); err != nil {
-		return nil, core.E("ratelimit.loadState", "request rows", err)
+		return nil, coreerr.E("ratelimit.loadState", "request rows", err)
 	}

 	// Load tokens.
 	tokRows, err := s.db.Query("SELECT model, ts, count FROM tokens ORDER BY ts")
 	if err != nil {
-		return nil, core.E("ratelimit.loadState", "query tokens", err)
+		return nil, coreerr.E("ratelimit.loadState", "query tokens", err)
 	}
 	defer tokRows.Close()

@@ -305,7 +304,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
 		var tsNano int64
 		var count int
 		if err := tokRows.Scan(&model, &tsNano, &count); err != nil {
-			return nil, core.E("ratelimit.loadState", "scan tokens", err)
+			return nil, coreerr.E("ratelimit.loadState", "scan tokens", err)
 		}
 		if _, ok := result[model]; !ok {
 			result[model] = &UsageStats{}
@@ -316,7 +315,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
 		})
 	}
 	if err := tokRows.Err(); err != nil {
-		return nil, core.E("ratelimit.loadState", "token rows", err)
+		return nil, coreerr.E("ratelimit.loadState", "token rows", err)
 	}

 	return result, nil
159 sqlite_test.go
@@ -1,8 +1,8 @@
 // SPDX-License-Identifier: EUPL-1.2

 package ratelimit

 import (
 	"os"
 	"path/filepath"
 	"sync"
 	"testing"
 	"time"

@@ -14,17 +14,18 @@ import (

 // --- Phase 2: SQLite basic tests ---

-func TestSQLite_NewSQLiteStore_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "test.db")
+func TestNewSQLiteStore_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "test.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)
 	defer store.close()

 	// Verify the database file was created.
-	assert.True(t, pathExists(dbPath), "database file should exist")
+	_, statErr := os.Stat(dbPath)
+	assert.NoError(t, statErr, "database file should exist")
 }

-func TestSQLite_NewSQLiteStore_Bad(t *testing.T) {
+func TestNewSQLiteStore_Bad(t *testing.T) {
 	t.Run("invalid path returns error", func(t *testing.T) {
 		// Path inside a non-existent directory with no parent.
 		_, err := newSQLiteStore("/nonexistent/deep/nested/dir/test.db")
@@ -32,8 +33,8 @@ func TestSQLite_NewSQLiteStore_Bad(t *testing.T) {
 	})
 }

-func TestSQLite_QuotasRoundTrip_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "quotas.db")
+func TestSQLiteQuotasRoundTrip_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "quotas.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)
 	defer store.close()
@@ -59,8 +60,8 @@ func TestSQLite_QuotasRoundTrip_Good(t *testing.T) {
 	}
 }

-func TestSQLite_QuotasOverwrite_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "overwrite.db")
+func TestSQLiteQuotasUpsert_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "upsert.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)
 	defer store.close()
@@ -70,7 +71,7 @@ func TestSQLite_QuotasOverwrite_Good(t *testing.T) {
 		"model-a": {MaxRPM: 100, MaxTPM: 50000, MaxRPD: 1000},
 	}))

-	// Save a second snapshot with updated values.
+	// Upsert with updated values.
 	require.NoError(t, store.saveQuotas(map[string]ModelQuota{
 		"model-a": {MaxRPM: 999, MaxTPM: 888, MaxRPD: 777},
 	}))
@@ -84,8 +85,8 @@ func TestSQLite_QuotasOverwrite_Good(t *testing.T) {
 	assert.Equal(t, 777, q.MaxRPD, "should have updated RPD")
 }

-func TestSQLite_StateRoundTrip_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "state.db")
+func TestSQLiteStateRoundTrip_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "state.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)
 	defer store.close()
@@ -143,8 +144,8 @@ func TestSQLite_StateRoundTrip_Good(t *testing.T) {
 	}
 }

-func TestSQLite_StateOverwrite_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "overwrite.db")
+func TestSQLiteStateOverwrite_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "overwrite.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)
 	defer store.close()
@@ -181,8 +182,8 @@ func TestSQLite_StateOverwrite_Good(t *testing.T) {
 	assert.Len(t, b.Requests, 1)
 }

-func TestSQLite_EmptyState_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "empty.db")
+func TestSQLiteEmptyState_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "empty.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)
 	defer store.close()
@@ -197,8 +198,8 @@ func TestSQLite_EmptyState_Good(t *testing.T) {
 	assert.Empty(t, state, "should return empty state from fresh DB")
 }

-func TestSQLite_Close_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "close.db")
+func TestSQLiteClose_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "close.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)

@@ -207,8 +208,8 @@ func TestSQLite_Close_Good(t *testing.T) {

 // --- Phase 2: SQLite integration tests ---

-func TestSQLite_NewWithSQLite_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "limiter.db")
+func TestNewWithSQLite_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "limiter.db")
 	rl, err := NewWithSQLite(dbPath)
 	require.NoError(t, err)
 	defer rl.Close()
@@ -221,8 +222,8 @@ func TestSQLite_NewWithSQLite_Good(t *testing.T) {
 	assert.NotNil(t, rl.sqlite, "SQLite store should be initialised")
 }

-func TestSQLite_NewWithSQLiteConfig_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "config.db")
+func TestNewWithSQLiteConfig_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "config.db")
 	rl, err := NewWithSQLiteConfig(dbPath, Config{
 		Providers: []Provider{ProviderAnthropic},
 		Quotas: map[string]ModelQuota{
@@ -242,8 +243,8 @@ func TestSQLite_NewWithSQLiteConfig_Good(t *testing.T) {
 	assert.False(t, hasGemini, "should not have Gemini models")
 }

-func TestSQLite_PersistAndLoad_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "persist.db")
+func TestSQLitePersistAndLoad_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "persist.db")
 	rl, err := NewWithSQLite(dbPath)
 	require.NoError(t, err)

@@ -271,8 +272,8 @@ func TestSQLite_PersistAndLoad_Good(t *testing.T) {
 	assert.Equal(t, 500, stats.MaxRPD)
 }

-func TestSQLite_PersistMultipleModels_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "multi.db")
+func TestSQLitePersistMultipleModels_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "multi.db")
 	rl, err := NewWithSQLiteConfig(dbPath, Config{
 		Providers: []Provider{ProviderGemini, ProviderAnthropic},
 	})
@@ -301,8 +302,8 @@ func TestSQLite_PersistMultipleModels_Good(t *testing.T) {
 	assert.Equal(t, 400, claude.TPM)
 }

-func TestSQLite_RecordUsageThenPersistReload_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "record.db")
+func TestSQLiteRecordUsageThenPersistReload_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "record.db")
 	rl, err := NewWithSQLite(dbPath)
 	require.NoError(t, err)

@@ -339,7 +340,7 @@ func TestSQLite_RecordUsageThenPersistReload_Good(t *testing.T) {
 	assert.Equal(t, 1000, stats2.TPM, "TPM should survive reload")
 }

-func TestSQLite_CloseNoOp_Good(t *testing.T) {
+func TestSQLiteClose_Good_NoOp(t *testing.T) {
 	// Close on YAML-backed limiter is a no-op.
 	rl := newTestLimiter(t)
 	assert.NoError(t, rl.Close(), "Close on YAML limiter should be no-op")
@@ -347,8 +348,8 @@ func TestSQLite_CloseNoOp_Good(t *testing.T) {

 // --- Phase 2: Concurrent SQLite ---

-func TestSQLite_Concurrent_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "concurrent.db")
+func TestSQLiteConcurrent_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "concurrent.db")
 	rl, err := NewWithSQLite(dbPath)
 	require.NoError(t, err)
 	defer rl.Close()
@@ -397,10 +398,10 @@ func TestSQLite_Concurrent_Good(t *testing.T) {

 // --- Phase 2: YAML backward compatibility ---

-func TestSQLite_YAMLBackwardCompat_Good(t *testing.T) {
+func TestYAMLBackwardCompat_Good(t *testing.T) {
 	// Verify that the default YAML backend still works after SQLite additions.
 	tmpDir := t.TempDir()
-	path := testPath(tmpDir, "compat.yaml")
+	path := filepath.Join(tmpDir, "compat.yaml")

 	rl1, err := New()
 	require.NoError(t, err)
@@ -424,18 +425,18 @@ func TestSQLite_YAMLBackwardCompat_Good(t *testing.T) {
 	assert.Equal(t, 200, stats.TPM)
 }

-func TestSQLite_ConfigBackendDefault_Good(t *testing.T) {
+func TestConfigBackendDefault_Good(t *testing.T) {
 	// Empty Backend string should default to YAML behaviour.
 	rl, err := NewWithConfig(Config{
-		FilePath: testPath(t.TempDir(), "default.yaml"),
+		FilePath: filepath.Join(t.TempDir(), "default.yaml"),
 	})
 	require.NoError(t, err)

 	assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
 }

-func TestSQLite_ConfigBackendSQLite_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "config-backend.db")
+func TestConfigBackendSQLite_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "config-backend.db")
 	rl, err := NewWithConfig(Config{
 		Backend:  backendSQLite,
 		FilePath: dbPath,
@@ -450,10 +451,11 @@ func TestSQLite_ConfigBackendSQLite_Good(t *testing.T) {
 	rl.RecordUsage("backend-model", 10, 10)
 	require.NoError(t, rl.Persist())

-	assert.True(t, pathExists(dbPath), "sqlite backend should persist to the configured DB path")
+	_, statErr := os.Stat(dbPath)
+	assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path")
 }

-func TestSQLite_ConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
+func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
 	home := t.TempDir()
 	t.Setenv("HOME", home)
 	t.Setenv("USERPROFILE", "")
@@ -468,15 +470,16 @@ func TestSQLite_ConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
 	require.NotNil(t, rl.sqlite)
 	require.NoError(t, rl.Persist())

-	assert.True(t, pathExists(testPath(home, defaultStateDirName, defaultSQLiteStateFile)), "sqlite backend should use the default home DB path")
+	_, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile))
+	assert.NoError(t, statErr, "sqlite backend should use the default home DB path")
 }

 // --- Phase 2: MigrateYAMLToSQLite ---

-func TestSQLite_MigrateYAMLToSQLite_Good(t *testing.T) {
+func TestMigrateYAMLToSQLite_Good(t *testing.T) {
 	tmpDir := t.TempDir()
-	yamlPath := testPath(tmpDir, "state.yaml")
-	sqlitePath := testPath(tmpDir, "migrated.db")
+	yamlPath := filepath.Join(tmpDir, "state.yaml")
+	sqlitePath := filepath.Join(tmpDir, "migrated.db")

 	// Create a YAML-backed limiter with state.
 	rl, err := New()
@@ -512,26 +515,26 @@ func TestSQLite_MigrateYAMLToSQLite_Good(t *testing.T) {
 	assert.Equal(t, 2, stats.RPD, "should have 2 daily requests")
 }

-func TestSQLite_MigrateYAMLToSQLite_Bad(t *testing.T) {
+func TestMigrateYAMLToSQLite_Bad(t *testing.T) {
 	t.Run("non-existent YAML file", func(t *testing.T) {
-		err := MigrateYAMLToSQLite("/nonexistent/state.yaml", testPath(t.TempDir(), "out.db"))
+		err := MigrateYAMLToSQLite("/nonexistent/state.yaml", filepath.Join(t.TempDir(), "out.db"))
 		assert.Error(t, err, "should fail with non-existent YAML file")
 	})

 	t.Run("corrupt YAML file", func(t *testing.T) {
 		tmpDir := t.TempDir()
-		yamlPath := testPath(tmpDir, "corrupt.yaml")
-		writeTestFile(t, yamlPath, "{{{{not yaml!")
+		yamlPath := filepath.Join(tmpDir, "corrupt.yaml")
+		require.NoError(t, os.WriteFile(yamlPath, []byte("{{{{not yaml!"), 0644))

-		err := MigrateYAMLToSQLite(yamlPath, testPath(tmpDir, "out.db"))
+		err := MigrateYAMLToSQLite(yamlPath, filepath.Join(tmpDir, "out.db"))
 		assert.Error(t, err, "should fail with corrupt YAML")
 	})
 }

-func TestSQLite_MigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
+func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
 	tmpDir := t.TempDir()
-	yamlPath := testPath(tmpDir, "atomic.yaml")
-	sqlitePath := testPath(tmpDir, "atomic.db")
+	yamlPath := filepath.Join(tmpDir, "atomic.yaml")
+	sqlitePath := filepath.Join(tmpDir, "atomic.db")
 	now := time.Now().UTC()

 	store, err := newSQLiteStore(sqlitePath)
@@ -570,7 +573,7 @@ func TestSQLite_MigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
 	}
 	data, err := yaml.Marshal(migrated)
 	require.NoError(t, err)
-	writeTestFile(t, yamlPath, string(data))
+	require.NoError(t, os.WriteFile(yamlPath, data, 0o644))

 	err = MigrateYAMLToSQLite(yamlPath, sqlitePath)
 	require.Error(t, err)
@@ -591,10 +594,10 @@ func TestSQLite_MigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
 	assert.NotContains(t, state, "new-model")
 }

-func TestSQLite_MigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
+func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
 	tmpDir := t.TempDir()
-	yamlPath := testPath(tmpDir, "full.yaml")
-	sqlitePath := testPath(tmpDir, "full.db")
+	yamlPath := filepath.Join(tmpDir, "full.yaml")
+	sqlitePath := filepath.Join(tmpDir, "full.db")

 	// Create a full YAML state with all Gemini models.
 	rl, err := New()
@@ -623,12 +626,12 @@ func TestSQLite_MigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {

 // --- Phase 2: Corrupt DB recovery ---

-func TestSQLite_CorruptDB_Ugly(t *testing.T) {
+func TestSQLiteCorruptDB_Ugly(t *testing.T) {
 	tmpDir := t.TempDir()
-	dbPath := testPath(tmpDir, "corrupt.db")
+	dbPath := filepath.Join(tmpDir, "corrupt.db")

 	// Write garbage to the DB file.
-	writeTestFile(t, dbPath, "THIS IS NOT A SQLITE DATABASE")
+	require.NoError(t, os.WriteFile(dbPath, []byte("THIS IS NOT A SQLITE DATABASE"), 0644))

 	// Opening a corrupt DB may succeed (sqlite is lazy about validation),
 	// but operations on it should fail gracefully.
@@ -645,9 +648,9 @@ func TestSQLite_CorruptDB_Ugly(t *testing.T) {
 	assert.Error(t, err, "loading from corrupt DB should return an error")
 }

-func TestSQLite_TruncatedDB_Ugly(t *testing.T) {
+func TestSQLiteTruncatedDB_Ugly(t *testing.T) {
 	tmpDir := t.TempDir()
-	dbPath := testPath(tmpDir, "truncated.db")
+	dbPath := filepath.Join(tmpDir, "truncated.db")

 	// Create a valid DB first.
 	store, err := newSQLiteStore(dbPath)
@@ -658,7 +661,11 @@ func TestSQLite_TruncatedDB_Ugly(t *testing.T) {
 	require.NoError(t, store.close())

 	// Truncate the file to simulate corruption.
-	overwriteTestFile(t, dbPath, "TRUNC")
+	f, err := os.OpenFile(dbPath, os.O_WRONLY|os.O_TRUNC, 0644)
+	require.NoError(t, err)
+	_, err = f.Write([]byte("TRUNC"))
+	require.NoError(t, err)
+	require.NoError(t, f.Close())

 	// Opening should either fail or operations should fail.
 	store2, err := newSQLiteStore(dbPath)
@@ -672,9 +679,9 @@ func TestSQLite_TruncatedDB_Ugly(t *testing.T) {
 	assert.Error(t, err, "loading from truncated DB should return an error")
 }

-func TestSQLite_EmptyModelState_Good(t *testing.T) {
+func TestSQLiteEmptyModelState_Good(t *testing.T) {
 	// State with no requests or tokens but with a daily counter.
-	dbPath := testPath(t.TempDir(), "empty-state.db")
+	dbPath := filepath.Join(t.TempDir(), "empty-state.db")
 	store, err := newSQLiteStore(dbPath)
 	require.NoError(t, err)
 	defer store.close()
@@ -701,8 +708,8 @@ func TestSQLite_EmptyModelState_Good(t *testing.T) {

 // --- Phase 2: End-to-end with persist cycle ---

-func TestSQLite_EndToEnd_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "e2e.db")
+func TestSQLiteEndToEnd_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "e2e.db")

 	// Session 1: Create limiter, record usage, persist.
 	rl1, err := NewWithSQLiteConfig(dbPath, Config{
@@ -745,8 +752,8 @@ func TestSQLite_EndToEnd_Good(t *testing.T) {
 	assert.Equal(t, 5, custom.MaxRPM)
 }

-func TestSQLite_LoadReplacesPersistedSnapshot_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "replace.db")
+func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "replace.db")
 	rl, err := NewWithSQLiteConfig(dbPath, Config{
 		Quotas: map[string]ModelQuota{
 			"model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
@@ -781,8 +788,8 @@ func TestSQLite_LoadReplacesPersistedSnapshot_Good(t *testing.T) {
 	assert.Equal(t, 1, rl2.Stats("model-b").RPD)
 }

-func TestSQLite_PersistAtomic_Good(t *testing.T) {
-	dbPath := testPath(t.TempDir(), "persist-atomic.db")
+func TestSQLitePersistAtomic_Good(t *testing.T) {
+	dbPath := filepath.Join(t.TempDir(), "persist-atomic.db")
 	rl, err := NewWithSQLiteConfig(dbPath, Config{
 		Quotas: map[string]ModelQuota{
 			"old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
@@ -820,7 +827,7 @@ func TestSQLite_PersistAtomic_Good(t *testing.T) {
 // --- Phase 2: Benchmark ---

 func BenchmarkSQLitePersist(b *testing.B) {
-	dbPath := testPath(b.TempDir(), "bench.db")
+	dbPath := filepath.Join(b.TempDir(), "bench.db")
 	rl, err := NewWithSQLite(dbPath)
 	if err != nil {
 		b.Fatal(err)
@@ -845,7 +852,7 @@ func BenchmarkSQLitePersist(b *testing.B) {
 }

 func BenchmarkSQLiteLoad(b *testing.B) {
-	dbPath := testPath(b.TempDir(), "bench-load.db")
+	dbPath := filepath.Join(b.TempDir(), "bench-load.db")
 	rl, err := NewWithSQLite(dbPath)
 	if err != nil {
 		b.Fatal(err)
@@ -876,10 +883,10 @@ func BenchmarkSQLiteLoad(b *testing.B) {

 // TestMigrateYAMLToSQLiteWithFullState tests migration of a realistic YAML
 // file that contains the full serialised RateLimiter struct.
-func TestSQLite_MigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
+func TestMigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
 	tmpDir := t.TempDir()
-	yamlPath := testPath(tmpDir, "realistic.yaml")
-	sqlitePath := testPath(tmpDir, "realistic.db")
+	yamlPath := filepath.Join(tmpDir, "realistic.yaml")
+	sqlitePath := filepath.Join(tmpDir, "realistic.db")

 	now := time.Now()

@@ -912,7 +919,7 @@ func TestSQLite_MigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
 	data, err := yaml.Marshal(rl)
 	require.NoError(t, err)
-	writeTestFile(t, yamlPath, string(data))
+	require.NoError(t, os.WriteFile(yamlPath, data, 0644))

 	// Migrate.
 	require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath))