go-ratelimit/ratelimit.go
Virgil d1c90b937d fix(ratelimit): harden audit edge cases
Co-Authored-By: Virgil <virgil@lethean.io>
2026-03-23 07:26:15 +00:00

795 lines
20 KiB
Go

package ratelimit
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"iter"
"maps"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"sync"
"time"
coreio "dappco.re/go/core/io"
coreerr "dappco.re/go/core/log"
"gopkg.in/yaml.v3"
)
// Provider identifies an LLM provider for quota profiles.
// It is serialised as a plain string (see ProviderProfile's yaml tags).
type Provider string

const (
	// ProviderGemini is Google's Gemini family (default).
	ProviderGemini Provider = "gemini"
	// ProviderOpenAI is OpenAI's GPT/o-series family.
	ProviderOpenAI Provider = "openai"
	// ProviderAnthropic is Anthropic's Claude family.
	ProviderAnthropic Provider = "anthropic"
	// ProviderLocal is for local inference (Ollama, MLX, llama.cpp).
	ProviderLocal Provider = "local"
)
const (
	// defaultStateDirName is the per-user directory (under $HOME) that
	// holds the default state files.
	defaultStateDirName = ".core"
	// defaultYAMLStateFile is the state file name for the YAML backend.
	defaultYAMLStateFile = "ratelimits.yaml"
	// defaultSQLiteStateFile is the state file name for the SQLite backend.
	defaultSQLiteStateFile = "ratelimits.db"
	// backendYAML and backendSQLite are the recognised Config.Backend values.
	backendYAML   = "yaml"
	backendSQLite = "sqlite"
	// countTokensErrorBodyLimit caps how much of a non-200 CountTokens
	// response body is read into the error message.
	countTokensErrorBodyLimit = 8 * 1024
	// countTokensSuccessBodyLimit caps how much of a successful CountTokens
	// response body is decoded.
	countTokensSuccessBodyLimit = 1 * 1024 * 1024
)
// ModelQuota defines the rate limits for a specific model.
// A zero value in any field means that dimension is unlimited.
type ModelQuota struct {
	MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
	MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited)
	MaxRPD int `yaml:"max_rpd"` // Requests per day (0 = unlimited)
}
// ProviderProfile bundles model quotas for a provider.
type ProviderProfile struct {
	Provider Provider              `yaml:"provider"` // provider these models belong to
	Models   map[string]ModelQuota `yaml:"models"`   // quotas keyed by model name
}
// Config controls RateLimiter initialisation.
// The zero value is valid and yields a YAML-backed limiter with Gemini
// defaults at the default state path.
type Config struct {
	// FilePath overrides the default state file location.
	// If empty, defaults to ~/.core/ratelimits.yaml (or ratelimits.db for
	// the sqlite backend).
	FilePath string `yaml:"file_path,omitempty"`
	// Backend selects the persistence backend: "yaml" (default) or "sqlite".
	// An empty string is treated as "yaml" for backward compatibility.
	Backend string `yaml:"backend,omitempty"`
	// Quotas sets per-model rate limits directly.
	// These are merged on top of any provider profile defaults.
	Quotas map[string]ModelQuota `yaml:"quotas,omitempty"`
	// Providers lists provider profiles to load.
	// If empty and Quotas is also empty, Gemini defaults are used.
	Providers []Provider `yaml:"providers,omitempty"`
}
// TokenEntry records a token usage event.
// Entries live in a one-minute sliding window (see UsageStats.Tokens and prune).
type TokenEntry struct {
	Time  time.Time `yaml:"time"`  // when the tokens were consumed
	Count int       `yaml:"count"` // number of tokens consumed
}
// UsageStats tracks usage history for a model.
// Requests and Tokens keep only the last minute of activity (older entries
// are removed by prune). DayCount accumulates requests since DayStart and is
// reset once 24 hours have elapsed.
type UsageStats struct {
	Requests []time.Time  `yaml:"requests"` // Sliding window (1m)
	Tokens   []TokenEntry `yaml:"tokens"`   // Sliding window (1m)
	DayStart time.Time    `yaml:"day_start"`
	DayCount int          `yaml:"day_count"`
}
// RateLimiter manages rate limits across multiple models.
// All exported methods are safe for concurrent use; mu guards Quotas and
// State. The struct embeds a mutex, so it must not be copied — always use
// the *RateLimiter returned by the constructors.
type RateLimiter struct {
	mu       sync.RWMutex
	Quotas   map[string]ModelQuota  `yaml:"quotas"` // per-model limits
	State    map[string]*UsageStats `yaml:"state"`  // per-model usage history
	filePath string                 // YAML state file location
	sqlite   *sqliteStore           // non-nil when backend is "sqlite"
}
// DefaultProfiles returns pre-configured quota profiles for each provider.
// Values are based on published rate limits as of Feb 2026. A fresh map is
// built on every call, so callers may mutate the result freely.
func DefaultProfiles() map[Provider]ProviderProfile {
	gemini := map[string]ModelQuota{
		"gemini-3-pro-preview":   {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000},
		"gemini-3-flash-preview": {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000},
		"gemini-2.5-pro":         {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000},
		"gemini-2.0-flash":       {MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0}, // Unlimited RPD
		"gemini-2.0-flash-lite":  {MaxRPM: 0, MaxTPM: 0, MaxRPD: 0},         // Fully unlimited
	}
	openai := map[string]ModelQuota{
		"gpt-4o":      {MaxRPM: 500, MaxTPM: 30000, MaxRPD: 0},
		"gpt-4o-mini": {MaxRPM: 500, MaxTPM: 200000, MaxRPD: 0},
		"gpt-4-turbo": {MaxRPM: 500, MaxTPM: 30000, MaxRPD: 0},
		"o1":          {MaxRPM: 500, MaxTPM: 30000, MaxRPD: 0},
		"o1-mini":     {MaxRPM: 500, MaxTPM: 200000, MaxRPD: 0},
		"o3-mini":     {MaxRPM: 500, MaxTPM: 200000, MaxRPD: 0},
	}
	anthropic := map[string]ModelQuota{
		"claude-opus-4":    {MaxRPM: 50, MaxTPM: 40000, MaxRPD: 0},
		"claude-sonnet-4":  {MaxRPM: 50, MaxTPM: 40000, MaxRPD: 0},
		"claude-haiku-3.5": {MaxRPM: 50, MaxTPM: 50000, MaxRPD: 0},
	}
	// Local inference has no external rate limits by default; users can
	// override per-model if their hardware requires throttling.
	local := map[string]ModelQuota{}
	return map[Provider]ProviderProfile{
		ProviderGemini:    {Provider: ProviderGemini, Models: gemini},
		ProviderOpenAI:    {Provider: ProviderOpenAI, Models: openai},
		ProviderAnthropic: {Provider: ProviderAnthropic, Models: anthropic},
		ProviderLocal:     {Provider: ProviderLocal, Models: local},
	}
}
// New creates a new RateLimiter with Gemini defaults.
// This preserves backward compatibility -- existing callers are unaffected.
func New() (*RateLimiter, error) {
	cfg := Config{Providers: []Provider{ProviderGemini}}
	return NewWithConfig(cfg)
}
// NewWithConfig creates a RateLimiter from explicit configuration.
// If no providers or quotas are specified, Gemini defaults are used.
// The backend string is normalised; an unknown backend is an error.
func NewWithConfig(cfg Config) (*RateLimiter, error) {
	backend, err := normaliseBackend(cfg.Backend)
	if err != nil {
		return nil, err
	}
	statePath := cfg.FilePath
	usingDefaultPath := statePath == ""
	if usingDefaultPath {
		if statePath, err = defaultStatePath(backend); err != nil {
			return nil, err
		}
	}
	if backend != backendSQLite {
		rl := newConfiguredRateLimiter(cfg)
		rl.filePath = statePath
		return rl, nil
	}
	// When we picked the default SQLite location ourselves, make sure its
	// parent directory exists; an explicit FilePath is the caller's duty.
	if usingDefaultPath {
		if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
			return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err)
		}
	}
	return NewWithSQLiteConfig(statePath, cfg)
}
// SetQuota sets or updates the quota for a specific model at runtime.
func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) {
	rl.mu.Lock()
	rl.Quotas[model] = quota
	rl.mu.Unlock()
}
// AddProvider loads all default quotas for a provider.
// Existing quotas for models in the profile are overwritten.
// Unknown providers are a no-op.
func (rl *RateLimiter) AddProvider(provider Provider) {
	profile, ok := DefaultProfiles()[provider]
	if !ok {
		return
	}
	rl.mu.Lock()
	defer rl.mu.Unlock()
	maps.Copy(rl.Quotas, profile.Models)
}
// Load reads the state from disk (YAML) or database (SQLite).
//
// A missing YAML state file is not an error: the limiter simply keeps its
// configured defaults. After a successful load, Quotas and State are
// guaranteed non-nil even if the persisted document explicitly nulled them,
// so later map writes in CanSend/RecordUsage cannot panic.
func (rl *RateLimiter) Load() error {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	if rl.sqlite != nil {
		return rl.loadSQLite()
	}
	content, err := coreio.Local.Read(rl.filePath)
	if os.IsNotExist(err) {
		return nil
	}
	if err != nil {
		return err
	}
	if err := yaml.Unmarshal([]byte(content), rl); err != nil {
		return err
	}
	// Guard against "quotas: null" / "state: null" in the file: yaml.v3
	// decodes an explicit null into a nil map, and the rest of the package
	// assumes these maps are always writable.
	if rl.Quotas == nil {
		rl.Quotas = make(map[string]ModelQuota)
	}
	if rl.State == nil {
		rl.State = make(map[string]*UsageStats)
	}
	return nil
}
// loadSQLite reads quotas and state from the SQLite backend.
// Caller must hold the lock.
func (rl *RateLimiter) loadSQLite() error {
	quotas, err := rl.sqlite.loadQuotas()
	if err != nil {
		return err
	}
	// Persisted quotas are authoritative when present; otherwise keep config defaults.
	if len(quotas) > 0 {
		rl.Quotas = maps.Clone(quotas)
	}
	state, err := rl.sqlite.loadState()
	if err != nil {
		return err
	}
	// Keep State writable even if the store returns a nil map for an empty
	// table — callers assume State is never nil. (NOTE(review): depends on
	// loadState's contract; the guard is harmless either way.)
	if state == nil {
		state = make(map[string]*UsageStats)
	}
	// Replace in-memory state with the persisted snapshot.
	rl.State = state
	return nil
}
// Persist writes a snapshot of the state to disk (YAML) or database (SQLite).
// It clones the state under a lock and performs I/O without blocking other callers.
func (rl *RateLimiter) Persist() error {
	// Snapshot everything we need under the lock, then release it before
	// any marshalling or I/O.
	rl.mu.Lock()
	quotas := maps.Clone(rl.Quotas)
	snapshot := make(map[string]*UsageStats, len(rl.State))
	for name, s := range rl.State {
		if s == nil {
			continue
		}
		cp := *s
		cp.Requests = slices.Clone(s.Requests)
		cp.Tokens = slices.Clone(s.Tokens)
		snapshot[name] = &cp
	}
	store := rl.sqlite
	path := rl.filePath
	rl.mu.Unlock()

	if store != nil {
		if err := store.saveSnapshot(quotas, snapshot); err != nil {
			return coreerr.E("ratelimit.Persist", "sqlite snapshot", err)
		}
		return nil
	}
	// Marshal a detached document rather than the RateLimiter itself so we
	// never serialise under the lock.
	type stateDoc struct {
		Quotas map[string]ModelQuota  `yaml:"quotas"`
		State  map[string]*UsageStats `yaml:"state"`
	}
	data, err := yaml.Marshal(stateDoc{Quotas: quotas, State: snapshot})
	if err != nil {
		return coreerr.E("ratelimit.Persist", "marshal", err)
	}
	if err := coreio.Local.Write(path, string(data)); err != nil {
		return coreerr.E("ratelimit.Persist", "write", err)
	}
	return nil
}
// prune removes entries older than the sliding window (1 minute), resets the
// daily counter after 24 hours, and deletes state for idle models so the map
// does not grow without bound.
// Caller must hold lock.
func (rl *RateLimiter) prune(model string) {
	stats, tracked := rl.State[model]
	if !tracked {
		return
	}
	if stats == nil {
		// A nil entry carries no information; drop it.
		delete(rl.State, model)
		return
	}
	now := time.Now()
	cutoff := now.Add(-time.Minute)
	// In-place filter: keep only events inside the one-minute window.
	liveReqs := stats.Requests[:0]
	for _, ts := range stats.Requests {
		if !ts.Before(cutoff) {
			liveReqs = append(liveReqs, ts)
		}
	}
	stats.Requests = liveReqs
	liveTokens := stats.Tokens[:0]
	for _, entry := range stats.Tokens {
		if !entry.Time.Before(cutoff) {
			liveTokens = append(liveTokens, entry)
		}
	}
	stats.Tokens = liveTokens
	// Roll over the daily counter once a full day has elapsed.
	if now.Sub(stats.DayStart) >= 24*time.Hour {
		stats.DayStart = now
		stats.DayCount = 0
	}
	// Completely idle (no window activity, no daily usage): remove the
	// entry to avoid leaking state for models that are no longer used.
	if len(stats.Requests) == 0 && len(stats.Tokens) == 0 && stats.DayCount == 0 {
		delete(rl.State, model)
	}
}
// BackgroundPrune starts a goroutine that periodically prunes all model states.
// It returns a function to stop the pruner; the function is safe to call more
// than once. A non-positive interval disables pruning and returns a no-op stop.
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
	if interval <= 0 {
		return func() {}
	}
	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				rl.mu.Lock()
				for name := range rl.State {
					rl.prune(name)
				}
				rl.mu.Unlock()
			case <-ctx.Done():
				return
			}
		}
	}()
	return cancel
}
// CanSend reports whether a request using estimatedTokens may be sent for
// model right now without violating any configured limit. Negative token
// estimates are always rejected; models with no quota entry are always
// allowed. Note that it takes the write lock because it prunes stale state
// and lazily creates a state entry as a side effect.
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
	if estimatedTokens < 0 {
		return false
	}
	rl.mu.Lock()
	defer rl.mu.Unlock()
	quota, known := rl.Quotas[model]
	switch {
	case !known:
		// Unknown models are allowed.
		return true
	case quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0:
		// Fully unlimited quota.
		return true
	}
	rl.prune(model)
	stats := rl.State[model]
	if stats == nil {
		stats = &UsageStats{DayStart: time.Now()}
		rl.State[model] = stats
	}
	// Requests-per-day.
	if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD {
		return false
	}
	// Requests-per-minute.
	if quota.MaxRPM > 0 && len(stats.Requests) >= quota.MaxRPM {
		return false
	}
	// Tokens-per-minute. The subtraction is safe: the first clause
	// guarantees estimatedTokens <= MaxTPM.
	if quota.MaxTPM > 0 {
		used := totalTokenCount(stats.Tokens)
		if estimatedTokens > quota.MaxTPM || used > quota.MaxTPM-estimatedTokens {
			return false
		}
	}
	return true
}
// RecordUsage records a successful API call: one request event, one token
// event (prompt + output, overflow-safe), and one daily-count increment.
func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	// Prune first so stale events never linger past a recording.
	rl.prune(model)
	now := time.Now()
	stats := rl.State[model]
	if stats == nil {
		stats = &UsageStats{DayStart: now}
		rl.State[model] = stats
	}
	stats.Requests = append(stats.Requests, now)
	stats.Tokens = append(stats.Tokens, TokenEntry{
		Time:  now,
		Count: safeTokenSum(promptTokens, outputTokens),
	})
	stats.DayCount++
}
// WaitForCapacity blocks until capacity is available or the context is
// cancelled. It polls CanSend once per second.
//
// If tokens permanently exceeds the model's MaxTPM, CanSend can never return
// true and this call blocks until ctx is cancelled — callers should bound
// the wait with a deadline.
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
	if tokens < 0 {
		return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil)
	}
	// Honour a context that is already cancelled before doing any work,
	// matching the context package's conventions.
	if err := ctx.Err(); err != nil {
		return err
	}
	ticker := time.NewTicker(1 * time.Second)
	defer ticker.Stop()
	for {
		if rl.CanSend(model, tokens) {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			// Poll again.
		}
	}
}
// Reset clears stats for a model (or all models if model is empty).
// Quotas are untouched; only usage history is discarded.
func (rl *RateLimiter) Reset(model string) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	switch model {
	case "":
		rl.State = make(map[string]*UsageStats)
	default:
		delete(rl.State, model)
	}
}
// ModelStats represents a snapshot of usage.
// The RPM/TPM/RPD fields are current usage; the Max fields echo the
// configured quota, where 0 means unlimited.
type ModelStats struct {
	RPM      int       // requests in the last minute
	MaxRPM   int       // quota: max requests per minute (0 = unlimited)
	TPM      int       // tokens in the last minute
	MaxTPM   int       // quota: max tokens per minute (0 = unlimited)
	RPD      int       // requests since DayStart
	MaxRPD   int       // quota: max requests per day (0 = unlimited)
	DayStart time.Time // start of the current daily window
}
// Models returns a sorted iterator over all model names tracked by the
// limiter — the union of quota keys and state keys, deduplicated.
func (rl *RateLimiter) Models() iter.Seq[string] {
	rl.mu.RLock()
	defer rl.mu.RUnlock()
	seen := make(map[string]struct{}, len(rl.Quotas)+len(rl.State))
	for name := range rl.Quotas {
		seen[name] = struct{}{}
	}
	for name := range rl.State {
		seen[name] = struct{}{}
	}
	return slices.Values(slices.Sorted(maps.Keys(seen)))
}
// Iter returns a sorted iterator over all model names and their current
// stats. The snapshot is taken lazily, when iteration starts.
func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] {
	return func(yield func(string, ModelStats) bool) {
		snapshot := rl.AllStats()
		names := slices.Sorted(maps.Keys(snapshot))
		for _, name := range names {
			if !yield(name, snapshot[name]) {
				return
			}
		}
	}
}
// Stats returns current stats for a model after pruning stale entries.
// A model with neither quota nor state yields the zero ModelStats.
func (rl *RateLimiter) Stats(model string) ModelStats {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	rl.prune(model)
	stats := ModelStats{}
	if quota, ok := rl.Quotas[model]; ok {
		stats.MaxRPM = quota.MaxRPM
		stats.MaxTPM = quota.MaxTPM
		stats.MaxRPD = quota.MaxRPD
	}
	// Guard s != nil for symmetry with AllStats (prune removes nil entries,
	// so this is defensive).
	if s, ok := rl.State[model]; ok && s != nil {
		stats.RPM = len(s.Requests)
		stats.RPD = s.DayCount
		stats.DayStart = s.DayStart
		// Overflow-safe token sum, consistent with AllStats.
		stats.TPM = totalTokenCount(s.Tokens)
	}
	return stats
}
// AllStats returns stats for all tracked models (anything with a quota or
// usage history). Each model is pruned before its snapshot is taken; models
// whose state was pruned away still appear with zeroed usage.
func (rl *RateLimiter) AllStats() map[string]ModelStats {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	// Union of quota and state keys.
	names := make(map[string]struct{}, len(rl.Quotas)+len(rl.State))
	for name := range rl.Quotas {
		names[name] = struct{}{}
	}
	for name := range rl.State {
		names[name] = struct{}{}
	}
	result := make(map[string]ModelStats, len(names))
	for name := range names {
		rl.prune(name)
		var ms ModelStats
		if q, ok := rl.Quotas[name]; ok {
			ms.MaxRPM = q.MaxRPM
			ms.MaxTPM = q.MaxTPM
			ms.MaxRPD = q.MaxRPD
		}
		if s := rl.State[name]; s != nil {
			ms.RPM = len(s.Requests)
			ms.RPD = s.DayCount
			ms.DayStart = s.DayStart
			ms.TPM = totalTokenCount(s.Tokens)
		}
		result[name] = ms
	}
	return result
}
// NewWithSQLite creates a SQLite-backed RateLimiter with Gemini defaults.
// The database is created at dbPath if it does not exist. Use Close() to
// release the database connection when finished.
func NewWithSQLite(dbPath string) (*RateLimiter, error) {
	cfg := Config{Providers: []Provider{ProviderGemini}}
	return NewWithSQLiteConfig(dbPath, cfg)
}
// NewWithSQLiteConfig creates a SQLite-backed RateLimiter with custom config.
// The Backend field in cfg is ignored (always "sqlite"). Use Close() to
// release the database connection when finished.
func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
	store, err := newSQLiteStore(dbPath)
	if err != nil {
		return nil, err
	}
	limiter := newConfiguredRateLimiter(cfg)
	limiter.sqlite = store
	return limiter, nil
}
// Close releases resources held by the RateLimiter. For YAML-backed
// limiters this is a no-op. For SQLite-backed limiters it closes the
// database connection.
func (rl *RateLimiter) Close() error {
	if rl.sqlite == nil {
		return nil
	}
	return rl.sqlite.close()
}
// MigrateYAMLToSQLite reads state from a YAML file and writes it to a new
// SQLite database. Both quotas and usage state are migrated. The SQLite
// database is created if it does not exist.
//
// The named return lets the deferred close report a failure (e.g. a failed
// flush) that would otherwise be silently discarded.
func MigrateYAMLToSQLite(yamlPath, sqlitePath string) (err error) {
	// Load from YAML.
	content, err := coreio.Local.Read(yamlPath)
	if err != nil {
		return coreerr.E("ratelimit.MigrateYAMLToSQLite", "read", err)
	}
	var rl RateLimiter
	if err := yaml.Unmarshal([]byte(content), &rl); err != nil {
		return coreerr.E("ratelimit.MigrateYAMLToSQLite", "unmarshal", err)
	}
	// Write to SQLite.
	store, err := newSQLiteStore(sqlitePath)
	if err != nil {
		return err
	}
	defer func() {
		// Propagate a close failure only when the migration itself
		// succeeded; never mask the primary error.
		if cerr := store.close(); cerr != nil && err == nil {
			err = coreerr.E("ratelimit.MigrateYAMLToSQLite", "close", cerr)
		}
	}()
	return store.saveSnapshot(rl.Quotas, rl.State)
}
// CountTokens calls the Google API to count tokens for a prompt, using the
// process-wide default HTTP client and the production endpoint.
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
	const base = "https://generativelanguage.googleapis.com"
	return countTokensWithClient(ctx, http.DefaultClient, base, apiKey, model, text)
}
// countTokensWithClient performs the countTokens HTTP call against baseURL
// using client (nil falls back to http.DefaultClient). It returns the
// totalTokens value from a 200 response; any other status becomes an error
// carrying up to countTokensErrorBodyLimit bytes of the response body.
func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) {
	endpoint, err := countTokensURL(baseURL, model)
	if err != nil {
		return 0, coreerr.E("ratelimit.CountTokens", "build url", err)
	}
	// Minimal countTokens payload: a single content with one text part.
	payload := map[string]any{
		"contents": []any{
			map[string]any{
				"parts": []any{
					map[string]string{"text": text},
				},
			},
		},
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return 0, coreerr.E("ratelimit.CountTokens", "new request", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", apiKey)
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		return 0, coreerr.E("ratelimit.CountTokens", "do request", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		snippet, readErr := readLimitedBody(resp.Body, countTokensErrorBodyLimit)
		if readErr != nil {
			return 0, coreerr.E("ratelimit.CountTokens", "read error body", readErr)
		}
		return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, snippet), nil)
	}
	var parsed struct {
		TotalTokens int `json:"totalTokens"`
	}
	// Bound the decode so a misbehaving server cannot exhaust memory.
	limited := io.LimitReader(resp.Body, countTokensSuccessBodyLimit)
	if err := json.NewDecoder(limited).Decode(&parsed); err != nil {
		return 0, coreerr.E("ratelimit.CountTokens", "decode response", err)
	}
	return parsed.TotalTokens, nil
}
// newConfiguredRateLimiter builds an in-memory limiter with empty maps and
// applies the configured provider profiles and quota overrides.
func newConfiguredRateLimiter(cfg Config) *RateLimiter {
	limiter := &RateLimiter{
		Quotas: map[string]ModelQuota{},
		State:  map[string]*UsageStats{},
	}
	applyConfig(limiter, cfg)
	return limiter
}
// applyConfig merges provider-profile defaults and explicit quotas into rl.
// With neither providers nor quotas configured, Gemini defaults apply.
// Explicit cfg.Quotas always win over profile values.
func applyConfig(rl *RateLimiter, cfg Config) {
	selected := cfg.Providers
	if len(selected) == 0 && len(cfg.Quotas) == 0 {
		selected = []Provider{ProviderGemini}
	}
	defaults := DefaultProfiles()
	for _, provider := range selected {
		profile, ok := defaults[provider]
		if !ok {
			continue // silently skip unknown providers
		}
		maps.Copy(rl.Quotas, profile.Models)
	}
	maps.Copy(rl.Quotas, cfg.Quotas)
}
// normaliseBackend canonicalises a Config.Backend value: trims whitespace,
// lowercases, and maps "" to the YAML default. Anything else is an error.
func normaliseBackend(backend string) (string, error) {
	normalised := strings.ToLower(strings.TrimSpace(backend))
	if normalised == "" {
		return backendYAML, nil
	}
	if normalised == backendYAML || normalised == backendSQLite {
		return normalised, nil
	}
	return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil)
}
func defaultStatePath(backend string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
fileName := defaultYAMLStateFile
if backend == backendSQLite {
fileName = defaultSQLiteStateFile
}
return filepath.Join(home, defaultStateDirName, fileName), nil
}
// safeTokenSum adds two token counts with the same semantics as
// safeTokenTotal: negative inputs are ignored and the result saturates at
// the maximum int instead of overflowing. Computing directly avoids
// constructing a throwaway TokenEntry slice on every RecordUsage call.
func safeTokenSum(a, b int) int {
	const maxInt = int(^uint(0) >> 1)
	total := 0
	if a > 0 {
		total = a
	}
	if b > 0 {
		if total > maxInt-b {
			return maxInt
		}
		total += b
	}
	return total
}
// totalTokenCount returns the overflow-safe sum of all token counts in the
// window: negative counts are ignored and the result saturates at max int.
func totalTokenCount(tokens []TokenEntry) int {
	return safeTokenTotal(tokens)
}
// safeTokenTotal sums token counts while ignoring negative entries and
// saturating at the maximum int rather than wrapping around.
func safeTokenTotal(tokens []TokenEntry) int {
	const maxInt = int(^uint(0) >> 1)
	sum := 0
	for i := range tokens {
		c := tokens[i].Count
		if c <= 0 {
			// Negative counts are invalid; zero contributes nothing.
			continue
		}
		if sum > maxInt-c {
			return maxInt
		}
		sum += c
	}
	return sum
}
// countTokensURL builds the :countTokens endpoint for model under baseURL.
// The model name is path-escaped; a blank model or a base URL without both
// scheme and host is rejected.
func countTokensURL(baseURL, model string) (string, error) {
	if strings.TrimSpace(model) == "" {
		return "", fmt.Errorf("empty model")
	}
	base, err := url.Parse(baseURL)
	if err != nil {
		return "", err
	}
	if base.Scheme == "" || base.Host == "" {
		return "", fmt.Errorf("invalid base url")
	}
	var b strings.Builder
	b.WriteString(strings.TrimRight(base.String(), "/"))
	b.WriteString("/v1beta/models/")
	b.WriteString(url.PathEscape(model))
	b.WriteString(":countTokens")
	return b.String(), nil
}
func readLimitedBody(r io.Reader, limit int64) (string, error) {
body, err := io.ReadAll(io.LimitReader(r, limit+1))
if err != nil {
return "", err
}
truncated := int64(len(body)) > limit
if truncated {
body = body[:limit]
}
result := string(body)
if truncated {
result += "..."
}
return result, nil
}