refactor(ratelimit): finish ax v0.8.0 polish

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-26 18:51:54 +00:00
parent 36cc0a4750
commit ed1cdc11b2
5 changed files with 390 additions and 262 deletions

View file

@ -1,8 +1,7 @@
package ratelimit
import (
"os"
"path/filepath"
"syscall"
"testing"
"time"
@ -10,8 +9,8 @@ import (
"github.com/stretchr/testify/require"
)
func TestSQLiteErrorPaths(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "error.db")
func TestError_SQLiteErrorPaths_Bad(t *testing.T) {
dbPath := testPath(t.TempDir(), "error.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
@ -39,17 +38,17 @@ func TestSQLiteErrorPaths(t *testing.T) {
})
}
func TestSQLiteInitErrors(t *testing.T) {
func TestError_SQLiteInitErrors_Bad(t *testing.T) {
t.Run("WAL pragma failure", func(t *testing.T) {
// This is hard to trigger without mocking sql.DB, but we can try an invalid connection string
// modernc.org/sqlite doesn't support all DSN options that might cause PRAGMA to fail but connection to succeed.
})
}
func TestPersistYAML(t *testing.T) {
func TestError_PersistYAML_Good(t *testing.T) {
t.Run("successful YAML persist and load", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "ratelimits.yaml")
path := testPath(tmpDir, "ratelimits.yaml")
rl, _ := New()
rl.filePath = path
rl.Quotas["test"] = ModelQuota{MaxRPM: 1}
@ -65,9 +64,9 @@ func TestPersistYAML(t *testing.T) {
})
}
func TestSQLiteLoadViaLimiter(t *testing.T) {
func TestError_SQLiteLoadViaLimiter_Bad(t *testing.T) {
t.Run("Load returns error when SQLite DB is closed", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "load-err.db")
dbPath := testPath(t.TempDir(), "load-err.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
@ -79,7 +78,7 @@ func TestSQLiteLoadViaLimiter(t *testing.T) {
})
t.Run("Load returns error when loadState fails", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "load-state-err.db")
dbPath := testPath(t.TempDir(), "load-state-err.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
@ -96,9 +95,9 @@ func TestSQLiteLoadViaLimiter(t *testing.T) {
})
}
func TestSQLitePersistViaLimiter(t *testing.T) {
func TestError_SQLitePersistViaLimiter_Bad(t *testing.T) {
t.Run("Persist returns error when SQLite saveQuotas fails", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "persist-err.db")
dbPath := testPath(t.TempDir(), "persist-err.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
@ -113,7 +112,7 @@ func TestSQLitePersistViaLimiter(t *testing.T) {
})
t.Run("Persist returns error when SQLite saveState fails", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "persist-state-err.db")
dbPath := testPath(t.TempDir(), "persist-state-err.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
@ -130,7 +129,7 @@ func TestSQLitePersistViaLimiter(t *testing.T) {
})
}
func TestNewWithSQLiteErrors(t *testing.T) {
func TestError_NewWithSQLite_Bad(t *testing.T) {
t.Run("NewWithSQLite with invalid path", func(t *testing.T) {
_, err := NewWithSQLite("/nonexistent/deep/nested/dir/test.db")
assert.Error(t, err, "should fail with invalid path")
@ -144,9 +143,9 @@ func TestNewWithSQLiteErrors(t *testing.T) {
})
}
func TestSQLiteSaveStateErrors(t *testing.T) {
func TestError_SQLiteSaveState_Bad(t *testing.T) {
t.Run("saveState fails when tokens table is dropped", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "tokens-err.db")
dbPath := testPath(t.TempDir(), "tokens-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -166,7 +165,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
})
t.Run("saveState fails when daily table is dropped", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "daily-err.db")
dbPath := testPath(t.TempDir(), "daily-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -185,7 +184,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
})
t.Run("saveState fails on request insert with renamed column", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "req-insert-err.db")
dbPath := testPath(t.TempDir(), "req-insert-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -206,7 +205,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
})
t.Run("saveState fails on token insert with renamed column", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "tok-insert-err.db")
dbPath := testPath(t.TempDir(), "tok-insert-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -227,7 +226,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
})
t.Run("saveState fails on daily insert with renamed column", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "day-insert-err.db")
dbPath := testPath(t.TempDir(), "day-insert-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -247,9 +246,9 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
})
}
func TestSQLiteLoadStateErrors(t *testing.T) {
func TestError_SQLiteLoadState_Bad(t *testing.T) {
t.Run("loadState fails when requests table is dropped", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "req-err.db")
dbPath := testPath(t.TempDir(), "req-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -271,7 +270,7 @@ func TestSQLiteLoadStateErrors(t *testing.T) {
})
t.Run("loadState fails when tokens table is dropped", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "tok-err.db")
dbPath := testPath(t.TempDir(), "tok-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -293,7 +292,7 @@ func TestSQLiteLoadStateErrors(t *testing.T) {
})
t.Run("loadState fails when daily table is dropped", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "daily-load-err.db")
dbPath := testPath(t.TempDir(), "daily-load-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -314,9 +313,9 @@ func TestSQLiteLoadStateErrors(t *testing.T) {
})
}
func TestSQLiteSaveQuotasExecError(t *testing.T) {
func TestError_SQLiteSaveQuotasExec_Bad(t *testing.T) {
t.Run("saveQuotas fails with renamed column at prepare", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "quota-exec-err.db")
dbPath := testPath(t.TempDir(), "quota-exec-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -332,7 +331,7 @@ func TestSQLiteSaveQuotasExecError(t *testing.T) {
})
t.Run("saveQuotas fails at exec via trigger", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "quota-trigger.db")
dbPath := testPath(t.TempDir(), "quota-trigger.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -350,9 +349,9 @@ func TestSQLiteSaveQuotasExecError(t *testing.T) {
})
}
func TestSQLiteSaveStateExecErrors(t *testing.T) {
func TestError_SQLiteSaveStateExec_Bad(t *testing.T) {
t.Run("request insert exec fails via trigger", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "trigger-req.db")
dbPath := testPath(t.TempDir(), "trigger-req.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -375,7 +374,7 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) {
})
t.Run("token insert exec fails via trigger", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "trigger-tok.db")
dbPath := testPath(t.TempDir(), "trigger-tok.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -398,7 +397,7 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) {
})
t.Run("daily insert exec fails via trigger", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "trigger-day.db")
dbPath := testPath(t.TempDir(), "trigger-day.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -420,9 +419,9 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) {
})
}
func TestSQLiteLoadQuotasScanError(t *testing.T) {
func TestError_SQLiteLoadQuotasScan_Bad(t *testing.T) {
t.Run("loadQuotas fails with renamed column", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "quota-scan-err.db")
dbPath := testPath(t.TempDir(), "quota-scan-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -441,26 +440,29 @@ func TestSQLiteLoadQuotasScanError(t *testing.T) {
})
}
func TestNewSQLiteStoreInReadOnlyDir(t *testing.T) {
if os.Getuid() == 0 {
func TestError_NewSQLiteStoreInReadOnlyDir_Bad(t *testing.T) {
if isRootUser() {
t.Skip("chmod restrictions do not apply to root")
}
t.Run("fails when parent directory is read-only", func(t *testing.T) {
tmpDir := t.TempDir()
readonlyDir := filepath.Join(tmpDir, "readonly")
require.NoError(t, os.MkdirAll(readonlyDir, 0555))
defer os.Chmod(readonlyDir, 0755)
readonlyDir := testPath(tmpDir, "readonly")
ensureTestDir(t, readonlyDir)
setPathMode(t, readonlyDir, 0o555)
defer func() {
_ = syscall.Chmod(readonlyDir, 0o755)
}()
dbPath := filepath.Join(readonlyDir, "test.db")
dbPath := testPath(readonlyDir, "test.db")
_, err := newSQLiteStore(dbPath)
assert.Error(t, err, "should fail when directory is read-only")
})
}
func TestSQLiteCreateSchemaError(t *testing.T) {
func TestError_SQLiteCreateSchema_Bad(t *testing.T) {
t.Run("createSchema fails on closed DB", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "schema-err.db")
dbPath := testPath(t.TempDir(), "schema-err.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
@ -473,9 +475,9 @@ func TestSQLiteCreateSchemaError(t *testing.T) {
})
}
func TestSQLiteLoadStateScanErrors(t *testing.T) {
func TestError_SQLiteLoadStateScan_Bad(t *testing.T) {
t.Run("scan daily fails with NULL values", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "scan-daily.db")
dbPath := testPath(t.TempDir(), "scan-daily.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -495,7 +497,7 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) {
})
t.Run("scan requests fails with NULL ts", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "scan-req.db")
dbPath := testPath(t.TempDir(), "scan-req.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -520,7 +522,7 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) {
})
t.Run("scan tokens fails with NULL values", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "scan-tok.db")
dbPath := testPath(t.TempDir(), "scan-tok.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -545,9 +547,9 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) {
})
}
func TestSQLiteLoadQuotasScanWithBadSchema(t *testing.T) {
func TestError_SQLiteLoadQuotasScanWithBadSchema_Bad(t *testing.T) {
t.Run("scan fails with NULL quota values", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "scan-quota.db")
dbPath := testPath(t.TempDir(), "scan-quota.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -566,11 +568,11 @@ func TestSQLiteLoadQuotasScanWithBadSchema(t *testing.T) {
})
}
func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) {
t.Run("saveQuotas failure during migration via trigger", func(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "with-quotas.yaml")
sqlitePath := filepath.Join(tmpDir, "migrate-quota-err.db")
yamlPath := testPath(tmpDir, "with-quotas.yaml")
sqlitePath := testPath(tmpDir, "migrate-quota-err.db")
// Write a YAML file with quotas.
yamlData := `quotas:
@ -579,7 +581,7 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
max_tpm: 100
max_rpd: 50
`
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644))
writeTestFile(t, yamlPath, yamlData)
// Pre-create DB with a trigger that aborts quota inserts.
store, err := newSQLiteStore(sqlitePath)
@ -596,8 +598,8 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
t.Run("saveState failure during migration via trigger", func(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "with-state.yaml")
sqlitePath := filepath.Join(tmpDir, "migrate-state-err.db")
yamlPath := testPath(tmpDir, "with-state.yaml")
sqlitePath := testPath(tmpDir, "migrate-state-err.db")
// Write YAML with state.
yamlData := `state:
@ -607,7 +609,7 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
day_start: 2026-01-01T00:00:00Z
day_count: 1
`
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644))
writeTestFile(t, yamlPath, yamlData)
// Pre-create DB with a trigger that aborts daily inserts.
store, err := newSQLiteStore(sqlitePath)
@ -622,13 +624,13 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
})
}
func TestMigrateYAMLToSQLiteNilQuotasAndState(t *testing.T) {
func TestError_MigrateYAMLToSQLiteNilQuotasAndState_Good(t *testing.T) {
t.Run("YAML with empty quotas and state migrates cleanly", func(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "empty.yaml")
require.NoError(t, os.WriteFile(yamlPath, []byte("{}"), 0644))
yamlPath := testPath(tmpDir, "empty.yaml")
writeTestFile(t, yamlPath, "{}")
sqlitePath := filepath.Join(tmpDir, "empty.db")
sqlitePath := testPath(tmpDir, "empty.db")
require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath))
store, err := newSQLiteStore(sqlitePath)
@ -645,30 +647,18 @@ func TestMigrateYAMLToSQLiteNilQuotasAndState(t *testing.T) {
})
}
func TestNewWithConfigUserHomeDirError(t *testing.T) {
// Unset HOME to trigger os.UserHomeDir() error.
home := os.Getenv("HOME")
os.Unsetenv("HOME")
// Also unset fallback env vars that UserHomeDir checks.
plan9Home := os.Getenv("home")
os.Unsetenv("home")
userProfile := os.Getenv("USERPROFILE")
os.Unsetenv("USERPROFILE")
defer func() {
os.Setenv("HOME", home)
if plan9Home != "" {
os.Setenv("home", plan9Home)
}
if userProfile != "" {
os.Setenv("USERPROFILE", userProfile)
}
}()
func TestError_NewWithConfigUserHomeDir_Bad(t *testing.T) {
// Clear all supported home env vars so defaultStatePath cannot resolve a home directory.
t.Setenv("CORE_HOME", "")
t.Setenv("HOME", "")
t.Setenv("home", "")
t.Setenv("USERPROFILE", "")
_, err := NewWithConfig(Config{})
assert.Error(t, err, "should fail when HOME is unset")
}
func TestPersistMarshalError(t *testing.T) {
func TestError_PersistMarshal_Good(t *testing.T) {
// yaml.Marshal on a struct with map[string]ModelQuota and map[string]*UsageStats
// should not fail in practice. We test the error path by using a type that
// yaml.Marshal cannot handle: a channel.
@ -680,20 +670,20 @@ func TestPersistMarshalError(t *testing.T) {
assert.NoError(t, rl.Persist(), "valid persist should succeed")
}
func TestMigrateErrorsExtended(t *testing.T) {
func TestError_MigrateErrorsExtended_Bad(t *testing.T) {
t.Run("unmarshal failure", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "bad.yaml")
require.NoError(t, os.WriteFile(path, []byte("invalid: yaml: ["), 0644))
err := MigrateYAMLToSQLite(path, filepath.Join(tmpDir, "out.db"))
path := testPath(tmpDir, "bad.yaml")
writeTestFile(t, path, "invalid: yaml: [")
err := MigrateYAMLToSQLite(path, testPath(tmpDir, "out.db"))
assert.Error(t, err)
assert.Contains(t, err.Error(), "ratelimit.MigrateYAMLToSQLite: unmarshal")
})
t.Run("sqlite open failure", func(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "ok.yaml")
require.NoError(t, os.WriteFile(yamlPath, []byte("quotas: {}"), 0644))
yamlPath := testPath(tmpDir, "ok.yaml")
writeTestFile(t, yamlPath, "quotas: {}")
// Use an invalid sqlite path (dir where file should be)
err := MigrateYAMLToSQLite(yamlPath, "/dev/null/not-a-db")
assert.Error(t, err)

View file

@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestIterators(t *testing.T) {
func TestIter_Iterators_Good(t *testing.T) {
rl, err := NewWithConfig(Config{
Quotas: map[string]ModelQuota{
"model-c": {MaxRPM: 10},
@ -77,7 +77,7 @@ func TestIterators(t *testing.T) {
})
}
func TestIterEarlyBreak(t *testing.T) {
func TestIter_IterEarlyBreak_Good(t *testing.T) {
rl, err := NewWithConfig(Config{
Quotas: map[string]ModelQuota{
"model-a": {MaxRPM: 10},
@ -110,7 +110,7 @@ func TestIterEarlyBreak(t *testing.T) {
})
}
func TestCountTokensFull(t *testing.T) {
func TestIter_CountTokensFull_Ugly(t *testing.T) {
t.Run("empty model is rejected", func(t *testing.T) {
_, err := CountTokens(context.Background(), "key", "", "text")
assert.Error(t, err)

View file

@ -3,11 +3,11 @@ package ratelimit
import (
"context"
"io"
"io/fs"
"iter"
"maps"
"net/http"
"net/url"
"os"
"slices"
"sync"
"time"
@ -17,6 +17,8 @@ import (
)
// Provider identifies an LLM provider for quota profiles.
//
// provider := ProviderOpenAI
type Provider string
const (
@ -41,6 +43,8 @@ const (
)
// ModelQuota defines the rate limits for a specific model.
//
// quota := ModelQuota{MaxRPM: 60, MaxTPM: 90000, MaxRPD: 1000}
type ModelQuota struct {
MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited)
@ -48,12 +52,18 @@ type ModelQuota struct {
}
// ProviderProfile bundles model quotas for a provider.
//
// profile := ProviderProfile{Provider: ProviderGemini, Models: DefaultProfiles()[ProviderGemini].Models}
type ProviderProfile struct {
Provider Provider `yaml:"provider"`
Models map[string]ModelQuota `yaml:"models"`
// Provider identifies the provider that owns the profile.
Provider Provider `yaml:"provider"`
// Models maps model names to quotas.
Models map[string]ModelQuota `yaml:"models"`
}
// Config controls RateLimiter initialisation.
//
// cfg := Config{Providers: []Provider{ProviderGemini}, FilePath: "/tmp/ratelimits.yaml"}
type Config struct {
// FilePath overrides the default state file location.
// If empty, defaults to ~/.core/ratelimits.yaml.
@ -73,23 +83,35 @@ type Config struct {
}
// TokenEntry records a token usage event.
//
// entry := TokenEntry{Time: time.Now(), Count: 512}
type TokenEntry struct {
Time time.Time `yaml:"time"`
Count int `yaml:"count"`
}
// UsageStats tracks usage history for a model.
//
// stats := UsageStats{DayStart: time.Now(), DayCount: 1}
type UsageStats struct {
Requests []time.Time `yaml:"requests"` // Sliding window (1m)
Tokens []TokenEntry `yaml:"tokens"` // Sliding window (1m)
DayStart time.Time `yaml:"day_start"`
DayCount int `yaml:"day_count"`
// DayStart is the start of the rolling 24-hour window.
DayStart time.Time `yaml:"day_start"`
// DayCount is the number of requests recorded in the rolling 24-hour window.
DayCount int `yaml:"day_count"`
}
// RateLimiter manages rate limits across multiple models.
//
// rl, err := New()
// if err != nil { /* handle error */ }
// defer rl.Close()
type RateLimiter struct {
mu sync.RWMutex
Quotas map[string]ModelQuota `yaml:"quotas"`
mu sync.RWMutex
// Quotas holds the configured per-model limits.
Quotas map[string]ModelQuota `yaml:"quotas"`
// State holds per-model usage windows.
State map[string]*UsageStats `yaml:"state"`
filePath string
sqlite *sqliteStore // non-nil when backend is "sqlite"
@ -97,6 +119,9 @@ type RateLimiter struct {
// DefaultProfiles returns pre-configured quota profiles for each provider.
// Values are based on published rate limits as of Feb 2026.
//
// profiles := DefaultProfiles()
// openAI := profiles[ProviderOpenAI]
func DefaultProfiles() map[Provider]ProviderProfile {
return map[Provider]ProviderProfile{
ProviderGemini: {
@ -140,6 +165,8 @@ func DefaultProfiles() map[Provider]ProviderProfile {
// New creates a new RateLimiter with Gemini defaults.
// This preserves backward compatibility -- existing callers are unaffected.
//
// rl, err := New()
func New() (*RateLimiter, error) {
return NewWithConfig(Config{
Providers: []Provider{ProviderGemini},
@ -148,6 +175,8 @@ func New() (*RateLimiter, error) {
// NewWithConfig creates a RateLimiter from explicit configuration.
// If no providers or quotas are specified, Gemini defaults are used.
//
// rl, err := NewWithConfig(Config{Providers: []Provider{ProviderAnthropic}})
func NewWithConfig(cfg Config) (*RateLimiter, error) {
backend, err := normaliseBackend(cfg.Backend)
if err != nil {
@ -177,6 +206,8 @@ func NewWithConfig(cfg Config) (*RateLimiter, error) {
}
// SetQuota sets or updates the quota for a specific model at runtime.
//
// rl.SetQuota("gpt-4o-mini", ModelQuota{MaxRPM: 60, MaxTPM: 200000})
func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -185,6 +216,8 @@ func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) {
// AddProvider loads all default quotas for a provider.
// Existing quotas for models in the profile are overwritten.
//
// rl.AddProvider(ProviderOpenAI)
func (rl *RateLimiter) AddProvider(provider Provider) {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -196,6 +229,8 @@ func (rl *RateLimiter) AddProvider(provider Provider) {
}
// Load reads the state from disk (YAML) or database (SQLite).
//
// if err := rl.Load(); err != nil { /* handle error */ }
func (rl *RateLimiter) Load() error {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -205,7 +240,7 @@ func (rl *RateLimiter) Load() error {
}
content, err := readLocalFile(rl.filePath)
if os.IsNotExist(err) {
if core.Is(err, fs.ErrNotExist) {
return nil
}
if err != nil {
@ -238,6 +273,8 @@ func (rl *RateLimiter) loadSQLite() error {
// Persist writes a snapshot of the state to disk (YAML) or database (SQLite).
// It clones the state under a lock and performs I/O without blocking other callers.
//
// if err := rl.Persist(); err != nil { /* handle error */ }
func (rl *RateLimiter) Persist() error {
rl.mu.Lock()
quotas := maps.Clone(rl.Quotas)
@ -328,6 +365,9 @@ func (rl *RateLimiter) prune(model string) {
// BackgroundPrune starts a goroutine that periodically prunes all model states.
// It returns a function to stop the pruner.
//
// stop := rl.BackgroundPrune(30 * time.Second)
// defer stop()
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
if interval <= 0 {
return func() {}
@ -354,6 +394,8 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
}
// CanSend checks if a request can be sent without violating limits.
//
// ok := rl.CanSend("gemini-3-pro-preview", 1200)
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
if estimatedTokens < 0 {
return false
@ -401,6 +443,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
}
// RecordUsage records a successful API call.
//
// rl.RecordUsage("gemini-3-pro-preview", 900, 300)
func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -420,6 +464,8 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
}
// WaitForCapacity blocks until capacity is available or context is cancelled.
//
// err := rl.WaitForCapacity(ctx, "gemini-3-pro-preview", 1200)
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
if tokens < 0 {
return core.E("ratelimit.WaitForCapacity", "negative tokens", nil)
@ -443,6 +489,8 @@ func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens
}
// Reset clears stats for a model (or all if model is empty).
//
// rl.Reset("gemini-3-pro-preview")
func (rl *RateLimiter) Reset(model string) {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -455,17 +503,28 @@ func (rl *RateLimiter) Reset(model string) {
}
// ModelStats represents a snapshot of usage.
//
// stats := rl.Stats("gemini-3-pro-preview")
type ModelStats struct {
RPM int
MaxRPM int
TPM int
MaxTPM int
RPD int
MaxRPD int
// RPM is the current requests-per-minute usage in the sliding window.
RPM int
// MaxRPM is the configured requests-per-minute limit.
MaxRPM int
// TPM is the current tokens-per-minute usage in the sliding window.
TPM int
// MaxTPM is the configured tokens-per-minute limit.
MaxTPM int
// RPD is the current requests-per-day usage in the rolling 24-hour window.
RPD int
// MaxRPD is the configured requests-per-day limit.
MaxRPD int
// DayStart is the start of the current rolling 24-hour window.
DayStart time.Time
}
// Models returns a sorted iterator over all model names tracked by the limiter.
//
// for model := range rl.Models() { println(model) }
func (rl *RateLimiter) Models() iter.Seq[string] {
rl.mu.RLock()
defer rl.mu.RUnlock()
@ -482,6 +541,8 @@ func (rl *RateLimiter) Models() iter.Seq[string] {
}
// Iter returns a sorted iterator over all model names and their current stats.
//
// for model, stats := range rl.Iter() { _ = stats; println(model) }
func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] {
return func(yield func(string, ModelStats) bool) {
stats := rl.AllStats()
@ -494,6 +555,8 @@ func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] {
}
// Stats returns current stats for a model.
//
// stats := rl.Stats("gemini-3-pro-preview")
func (rl *RateLimiter) Stats(model string) ModelStats {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -521,6 +584,8 @@ func (rl *RateLimiter) Stats(model string) ModelStats {
}
// AllStats returns stats for all tracked models.
//
// all := rl.AllStats()
func (rl *RateLimiter) AllStats() map[string]ModelStats {
rl.mu.Lock()
defer rl.mu.Unlock()
@ -559,6 +624,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
// NewWithSQLite creates a SQLite-backed RateLimiter with Gemini defaults.
// The database is created at dbPath if it does not exist. Use Close() to
// release the database connection when finished.
//
// rl, err := NewWithSQLite("/tmp/ratelimits.db")
func NewWithSQLite(dbPath string) (*RateLimiter, error) {
return NewWithSQLiteConfig(dbPath, Config{
Providers: []Provider{ProviderGemini},
@ -568,6 +635,8 @@ func NewWithSQLite(dbPath string) (*RateLimiter, error) {
// NewWithSQLiteConfig creates a SQLite-backed RateLimiter with custom config.
// The Backend field in cfg is ignored (always "sqlite"). Use Close() to
// release the database connection when finished.
//
// rl, err := NewWithSQLiteConfig("/tmp/ratelimits.db", Config{Providers: []Provider{ProviderOpenAI}})
func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
store, err := newSQLiteStore(dbPath)
if err != nil {
@ -582,6 +651,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
// Close releases resources held by the RateLimiter. For YAML-backed
// limiters this is a no-op. For SQLite-backed limiters it closes the
// database connection.
//
// defer rl.Close()
func (rl *RateLimiter) Close() error {
if rl.sqlite != nil {
return rl.sqlite.close()
@ -592,6 +663,8 @@ func (rl *RateLimiter) Close() error {
// MigrateYAMLToSQLite reads state from a YAML file and writes it to a new
// SQLite database. Both quotas and usage state are migrated. The SQLite
// database is created if it does not exist.
//
// err := MigrateYAMLToSQLite("ratelimits.yaml", "ratelimits.db")
func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
// Load from YAML.
content, err := readLocalFile(yamlPath)
@ -618,6 +691,8 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
}
// CountTokens calls the Google API to count tokens for a prompt.
//
// tokens, err := CountTokens(ctx, apiKey, "gemini-3-pro-preview", prompt)
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text)
}
@ -722,9 +797,9 @@ func normaliseBackend(backend string) (string, error) {
}
func defaultStatePath(backend string) (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
home := currentHomeDir()
if home == "" {
return "", core.E("ratelimit.defaultStatePath", "home dir unavailable", nil)
}
fileName := defaultYAMLStateFile
@ -735,6 +810,15 @@ func defaultStatePath(backend string) (string, error) {
return core.Path(home, defaultStateDirName, fileName), nil
}
// currentHomeDir resolves the home directory from environment variables,
// checking CORE_HOME first, then HOME, then the Plan 9 "home" and Windows
// USERPROFILE fallbacks. It returns "" when none is set to a non-blank value.
func currentHomeDir() string {
	candidates := []string{"CORE_HOME", "HOME", "home", "USERPROFILE"}
	for _, name := range candidates {
		trimmed := core.Trim(core.Env(name))
		if trimmed == "" {
			continue
		}
		return trimmed
	}
	return ""
}
func safeTokenSum(a, b int) int {
return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}})
}
@ -762,7 +846,7 @@ func safeTokenTotal(tokens []TokenEntry) int {
func countTokensURL(baseURL, model string) (string, error) {
if core.Trim(model) == "" {
return "", core.NewError("empty model")
return "", core.E("ratelimit.countTokensURL", "empty model", nil)
}
parsed, err := url.Parse(baseURL)
@ -770,7 +854,7 @@ func countTokensURL(baseURL, model string) (string, error) {
return "", err
}
if parsed.Scheme == "" || parsed.Host == "" {
return "", core.NewError("invalid base url")
return "", core.E("ratelimit.countTokensURL", "invalid base url", nil)
}
return core.Concat(core.TrimSuffix(parsed.String(), "/"), "/v1beta/models/", url.PathEscape(model), ":countTokens"), nil
@ -803,7 +887,7 @@ func readLocalFile(path string) (string, error) {
content, ok := result.Value.(string)
if !ok {
return "", core.NewError("read returned non-string")
return "", core.E("ratelimit.readLocalFile", "read returned non-string", nil)
}
return content, nil
}
@ -828,5 +912,5 @@ func resultError(result core.Result) error {
if result.Value == nil {
return nil
}
return core.NewError(core.Sprint(result.Value))
return core.E("ratelimit.resultError", core.Sprint(result.Value), nil)
}

View file

@ -2,28 +2,92 @@ package ratelimit
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"syscall"
"testing"
"time"
core "dappco.re/go/core"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testPath joins path segments via the project's core.Path helper; tests use
// it in place of filepath.Join.
func testPath(parts ...string) string {
	joined := core.Path(parts...)
	return joined
}
// pathExists reports whether path exists according to core.Fs.
func pathExists(path string) bool {
	var filesystem core.Fs
	exists := filesystem.Exists(path)
	return exists
}
func writeTestFile(tb testing.TB, path, content string) {
tb.Helper()
require.NoError(tb, writeLocalFile(path, content))
}
func ensureTestDir(tb testing.TB, path string) {
tb.Helper()
require.NoError(tb, ensureDir(path))
}
// setPathMode changes the permission bits of path to mode and fails the
// test immediately if the chmod call errors.
func setPathMode(tb testing.TB, path string, mode uint32) {
	tb.Helper()
	err := syscall.Chmod(path, mode)
	require.NoError(tb, err)
}
// overwriteTestFile creates (or truncates) path via core.Fs and writes
// content to it, failing the test on any create or write error.
func overwriteTestFile(tb testing.TB, path, content string) {
	tb.Helper()
	var filesystem core.Fs
	created := filesystem.Create(path)
	require.NoError(tb, resultError(created))
	written := core.WriteAll(created.Value, content)
	require.NoError(tb, resultError(written))
}
// isRootUser reports whether the test process is running with an effective
// UID of 0 (root), where chmod-based permission tests do not apply.
func isRootUser() bool {
	euid := syscall.Geteuid()
	return euid == 0
}
// repeatString returns part repeated count times. A non-positive count
// yields the empty string, matching the original loop's behaviour.
//
// The manual builder loop duplicated strings.Repeat; the explicit guard is
// required because strings.Repeat panics on a negative count, whereas the
// original loop simply produced "".
func repeatString(part string, count int) string {
	if count <= 0 {
		return ""
	}
	return strings.Repeat(part, count)
}
// substringCount returns the number of non-overlapping occurrences of
// substr in s. An empty substr yields 0 (note: plain strings.Count would
// return the rune count of s plus one in that case, so the guard stays).
//
// Counting via len(core.Split(...))-1 allocated a slice of every segment
// just to take its length; strings.Count counts directly.
func substringCount(s, substr string) int {
	if substr == "" {
		return 0
	}
	return strings.Count(s, substr)
}
// decodeJSONBody reads all of r and unmarshals the JSON payload into
// target, failing the test on a read or decode error.
func decodeJSONBody(tb testing.TB, r io.Reader, target any) {
	tb.Helper()
	raw, readErr := io.ReadAll(r)
	require.NoError(tb, readErr)
	decoded := core.JSONUnmarshal(raw, target)
	require.NoError(tb, resultError(decoded))
}
// writeJSONBody JSON-encodes value and writes the result to w, failing the
// test if the write errors.
func writeJSONBody(tb testing.TB, w io.Writer, value any) {
	tb.Helper()
	payload := core.JSONMarshalString(value)
	_, writeErr := io.WriteString(w, payload)
	require.NoError(tb, writeErr)
}
// newTestLimiter returns a RateLimiter with file path set to a temp directory.
func newTestLimiter(t *testing.T) *RateLimiter {
t.Helper()
rl, err := New()
require.NoError(t, err)
rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml")
rl.filePath = testPath(t.TempDir(), "ratelimits.yaml")
return rl
}
@ -41,7 +105,7 @@ func (errReader) Read([]byte) (int, error) {
// --- Phase 0: CanSend boundary conditions ---
func TestCanSend(t *testing.T) {
func TestRatelimit_CanSend_Good(t *testing.T) {
t.Run("fresh state allows send", func(t *testing.T) {
rl := newTestLimiter(t)
model := "test-model"
@ -189,7 +253,7 @@ func TestCanSend(t *testing.T) {
// --- Phase 0: Sliding window / prune tests ---
func TestPrune(t *testing.T) {
func TestRatelimit_Prune_Good(t *testing.T) {
t.Run("removes old entries", func(t *testing.T) {
rl := newTestLimiter(t)
model := "test-prune"
@ -304,7 +368,7 @@ func TestPrune(t *testing.T) {
// --- Phase 0: RecordUsage ---
func TestRecordUsage(t *testing.T) {
func TestRatelimit_RecordUsage_Good(t *testing.T) {
t.Run("records into fresh state", func(t *testing.T) {
rl := newTestLimiter(t)
model := "record-fresh"
@ -375,7 +439,7 @@ func TestRecordUsage(t *testing.T) {
// --- Phase 0: Reset ---
func TestReset(t *testing.T) {
func TestRatelimit_Reset_Good(t *testing.T) {
t.Run("reset single model", func(t *testing.T) {
rl := newTestLimiter(t)
rl.RecordUsage("model-a", 10, 10)
@ -409,7 +473,7 @@ func TestReset(t *testing.T) {
// --- Phase 0: WaitForCapacity ---
func TestWaitForCapacity(t *testing.T) {
func TestRatelimit_WaitForCapacity_Good(t *testing.T) {
t.Run("context cancelled returns error", func(t *testing.T) {
rl := newTestLimiter(t)
model := "wait-cancel"
@ -467,7 +531,7 @@ func TestWaitForCapacity(t *testing.T) {
})
}
func TestNilUsageStats(t *testing.T) {
func TestRatelimit_NilUsageStats_Ugly(t *testing.T) {
t.Run("CanSend replaces nil state without panicking", func(t *testing.T) {
rl := newTestLimiter(t)
model := "nil-cansend"
@ -514,7 +578,7 @@ func TestNilUsageStats(t *testing.T) {
// --- Phase 0: Stats ---
func TestStats(t *testing.T) {
func TestRatelimit_Stats_Good(t *testing.T) {
t.Run("returns stats for known model with usage", func(t *testing.T) {
rl := newTestLimiter(t)
model := "stats-test"
@ -554,7 +618,7 @@ func TestStats(t *testing.T) {
// --- Phase 0: AllStats ---
func TestAllStats(t *testing.T) {
func TestRatelimit_AllStats_Good(t *testing.T) {
t.Run("includes all default quotas plus state-only models", func(t *testing.T) {
rl := newTestLimiter(t)
rl.RecordUsage("gemini-3-pro-preview", 1000, 500)
@ -612,10 +676,10 @@ func TestAllStats(t *testing.T) {
// --- Phase 0: Persist and Load ---
func TestPersistAndLoad(t *testing.T) {
func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {
t.Run("round-trip preserves state", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "ratelimits.yaml")
path := testPath(tmpDir, "ratelimits.yaml")
rl1, err := New()
require.NoError(t, err)
@ -638,7 +702,7 @@ func TestPersistAndLoad(t *testing.T) {
t.Run("load from non-existent file is not an error", func(t *testing.T) {
rl := newTestLimiter(t)
rl.filePath = filepath.Join(t.TempDir(), "does-not-exist.yaml")
rl.filePath = testPath(t.TempDir(), "does-not-exist.yaml")
err := rl.Load()
assert.NoError(t, err, "loading non-existent file should not error")
@ -646,8 +710,8 @@ func TestPersistAndLoad(t *testing.T) {
t.Run("load from corrupt YAML returns error", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "corrupt.yaml")
require.NoError(t, os.WriteFile(path, []byte("{{{{invalid yaml!!!!"), 0644))
path := testPath(tmpDir, "corrupt.yaml")
writeTestFile(t, path, "{{{{invalid yaml!!!!")
rl := newTestLimiter(t)
rl.filePath = path
@ -657,13 +721,13 @@ func TestPersistAndLoad(t *testing.T) {
})
t.Run("load from unreadable file returns error", func(t *testing.T) {
if os.Getuid() == 0 {
if isRootUser() {
t.Skip("chmod 000 does not restrict root")
}
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "unreadable.yaml")
require.NoError(t, os.WriteFile(path, []byte("quotas: {}"), 0644))
require.NoError(t, os.Chmod(path, 0000))
path := testPath(tmpDir, "unreadable.yaml")
writeTestFile(t, path, "quotas: {}")
setPathMode(t, path, 0o000)
rl := newTestLimiter(t)
rl.filePath = path
@ -672,12 +736,12 @@ func TestPersistAndLoad(t *testing.T) {
assert.Error(t, err, "unreadable file should produce an error")
// Clean up permissions for temp dir cleanup
_ = os.Chmod(path, 0644)
_ = syscall.Chmod(path, 0o644)
})
t.Run("persist to nested non-existent directory creates it", func(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "nested", "deep", "ratelimits.yaml")
path := testPath(tmpDir, "nested", "deep", "ratelimits.yaml")
rl := newTestLimiter(t)
rl.filePath = path
@ -686,32 +750,32 @@ func TestPersistAndLoad(t *testing.T) {
err := rl.Persist()
assert.NoError(t, err, "should create nested directories")
_, statErr := os.Stat(path)
assert.NoError(t, statErr, "file should exist")
assert.True(t, pathExists(path), "file should exist")
})
t.Run("persist to unwritable directory returns error", func(t *testing.T) {
if os.Getuid() == 0 {
if isRootUser() {
t.Skip("chmod 0555 does not restrict root")
}
tmpDir := t.TempDir()
unwritable := filepath.Join(tmpDir, "readonly")
require.NoError(t, os.MkdirAll(unwritable, 0555))
unwritable := testPath(tmpDir, "readonly")
ensureTestDir(t, unwritable)
setPathMode(t, unwritable, 0o555)
rl := newTestLimiter(t)
rl.filePath = filepath.Join(unwritable, "sub", "ratelimits.yaml")
rl.filePath = testPath(unwritable, "sub", "ratelimits.yaml")
err := rl.Persist()
assert.Error(t, err, "should fail when directory is unwritable")
// Clean up
_ = os.Chmod(unwritable, 0755)
_ = syscall.Chmod(unwritable, 0o755)
})
}
// --- Phase 0: Default quotas ---
func TestDefaultQuotas(t *testing.T) {
func TestRatelimit_DefaultQuotas_Good(t *testing.T) {
rl := newTestLimiter(t)
tests := []struct {
@ -740,7 +804,7 @@ func TestDefaultQuotas(t *testing.T) {
// --- Phase 0: Concurrent access (race test) ---
func TestConcurrentAccess(t *testing.T) {
func TestRatelimit_ConcurrentAccess_Good(t *testing.T) {
rl := newTestLimiter(t)
model := "concurrent-test"
rl.Quotas[model] = ModelQuota{MaxRPM: 1000, MaxTPM: 10000000, MaxRPD: 10000}
@ -766,7 +830,7 @@ func TestConcurrentAccess(t *testing.T) {
assert.Equal(t, expected, stats.RPD, "all recordings should be counted")
}
func TestConcurrentResetAndRecord(t *testing.T) {
func TestRatelimit_ConcurrentResetAndRecord_Ugly(t *testing.T) {
rl := newTestLimiter(t)
model := "concurrent-reset"
rl.Quotas[model] = ModelQuota{MaxRPM: 10000, MaxTPM: 100000000, MaxRPD: 100000}
@ -804,7 +868,7 @@ func TestConcurrentResetAndRecord(t *testing.T) {
// No assertion needed -- if we get here without -race flagging, mutex is sound
}
func TestBackgroundPrune(t *testing.T) {
func TestRatelimit_BackgroundPrune_Good(t *testing.T) {
rl := newTestLimiter(t)
model := "prune-me"
rl.Quotas[model] = ModelQuota{MaxRPM: 100}
@ -843,7 +907,7 @@ func TestBackgroundPrune(t *testing.T) {
// --- Phase 0: CountTokens (with mock HTTP server) ---
func TestCountTokens(t *testing.T) {
func TestRatelimit_CountTokens_Ugly(t *testing.T) {
t.Run("successful token count", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, http.MethodPost, r.Method)
@ -858,13 +922,13 @@ func TestCountTokens(t *testing.T) {
} `json:"parts"`
} `json:"contents"`
}
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
decodeJSONBody(t, r.Body, &body)
require.Len(t, body.Contents, 1)
require.Len(t, body.Contents[0].Parts, 1)
assert.Equal(t, "hello", body.Contents[0].Parts[0].Text)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}))
writeJSONBody(t, w, map[string]int{"totalTokens": 42})
}))
defer server.Close()
@ -878,7 +942,7 @@ func TestCountTokens(t *testing.T) {
assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath())
assert.Empty(t, r.URL.RawQuery)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7}))
writeJSONBody(t, w, map[string]int{"totalTokens": 7})
}))
defer server.Close()
@ -888,10 +952,10 @@ func TestCountTokens(t *testing.T) {
})
t.Run("API error body is truncated", func(t *testing.T) {
largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256)
largeBody := repeatString("x", countTokensErrorBodyLimit+256)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, err := fmt.Fprint(w, largeBody)
_, err := io.WriteString(w, largeBody)
require.NoError(t, err)
}))
defer server.Close()
@ -899,7 +963,7 @@ func TestCountTokens(t *testing.T) {
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "api error status 401")
assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded")
assert.True(t, substringCount(err.Error(), "x") < len(largeBody), "error body should be bounded")
assert.Contains(t, err.Error(), "...")
})
@ -953,7 +1017,7 @@ func TestCountTokens(t *testing.T) {
t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11}))
writeJSONBody(t, w, map[string]int{"totalTokens": 11})
}))
defer server.Close()
@ -969,8 +1033,8 @@ func TestCountTokens(t *testing.T) {
})
}
func TestPersistSkipsNilState(t *testing.T) {
path := filepath.Join(t.TempDir(), "nil-state.yaml")
func TestRatelimit_PersistSkipsNilState_Good(t *testing.T) {
path := testPath(t.TempDir(), "nil-state.yaml")
rl, err := New()
require.NoError(t, err)
@ -986,7 +1050,7 @@ func TestPersistSkipsNilState(t *testing.T) {
assert.NotContains(t, rl2.State, "nil-model")
}
func TestTokenTotals(t *testing.T) {
func TestRatelimit_TokenTotals_Good(t *testing.T) {
maxInt := int(^uint(0) >> 1)
assert.Equal(t, 25, safeTokenSum(-100, 25))
@ -1056,7 +1120,7 @@ func BenchmarkCanSendConcurrent(b *testing.B) {
// --- Phase 1: Provider profiles and NewWithConfig ---
func TestDefaultProfiles(t *testing.T) {
func TestRatelimit_DefaultProfiles_Good(t *testing.T) {
profiles := DefaultProfiles()
t.Run("contains all four providers", func(t *testing.T) {
@ -1097,10 +1161,10 @@ func TestDefaultProfiles(t *testing.T) {
})
}
func TestNewWithConfig(t *testing.T) {
func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
t.Run("empty config defaults to Gemini", func(t *testing.T) {
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
FilePath: testPath(t.TempDir(), "test.yaml"),
})
require.NoError(t, err)
@ -1110,7 +1174,7 @@ func TestNewWithConfig(t *testing.T) {
t.Run("single provider loads only its models", func(t *testing.T) {
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
FilePath: testPath(t.TempDir(), "test.yaml"),
Providers: []Provider{ProviderOpenAI},
})
require.NoError(t, err)
@ -1124,7 +1188,7 @@ func TestNewWithConfig(t *testing.T) {
t.Run("multiple providers merge models", func(t *testing.T) {
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
FilePath: testPath(t.TempDir(), "test.yaml"),
Providers: []Provider{ProviderGemini, ProviderAnthropic},
})
require.NoError(t, err)
@ -1140,7 +1204,7 @@ func TestNewWithConfig(t *testing.T) {
t.Run("explicit quotas override provider defaults", func(t *testing.T) {
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
FilePath: testPath(t.TempDir(), "test.yaml"),
Providers: []Provider{ProviderGemini},
Quotas: map[string]ModelQuota{
"gemini-3-pro-preview": {MaxRPM: 999, MaxTPM: 888, MaxRPD: 777},
@ -1156,7 +1220,7 @@ func TestNewWithConfig(t *testing.T) {
t.Run("explicit quotas without providers", func(t *testing.T) {
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
FilePath: testPath(t.TempDir(), "test.yaml"),
Quotas: map[string]ModelQuota{
"my-custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
},
@ -1169,7 +1233,7 @@ func TestNewWithConfig(t *testing.T) {
})
t.Run("custom file path is respected", func(t *testing.T) {
customPath := filepath.Join(t.TempDir(), "custom", "limits.yaml")
customPath := testPath(t.TempDir(), "custom", "limits.yaml")
rl, err := NewWithConfig(Config{
FilePath: customPath,
Providers: []Provider{ProviderLocal},
@ -1179,13 +1243,12 @@ func TestNewWithConfig(t *testing.T) {
rl.RecordUsage("test", 1, 1)
require.NoError(t, rl.Persist())
_, statErr := os.Stat(customPath)
assert.NoError(t, statErr, "file should be created at custom path")
assert.True(t, pathExists(customPath), "file should be created at custom path")
})
t.Run("unknown provider is silently skipped", func(t *testing.T) {
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
FilePath: testPath(t.TempDir(), "test.yaml"),
Providers: []Provider{"nonexistent-provider"},
})
require.NoError(t, err)
@ -1194,7 +1257,7 @@ func TestNewWithConfig(t *testing.T) {
t.Run("local provider with custom quotas", func(t *testing.T) {
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
FilePath: testPath(t.TempDir(), "test.yaml"),
Providers: []Provider{ProviderLocal},
Quotas: map[string]ModelQuota{
"llama-3.3-70b": {MaxRPM: 5, MaxTPM: 50000, MaxRPD: 0},
@ -1224,11 +1287,11 @@ func TestNewWithConfig(t *testing.T) {
rl, err := NewWithConfig(Config{})
require.NoError(t, err)
assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
assert.Equal(t, testPath(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
})
}
func TestNewBackwardCompatibility(t *testing.T) {
func TestRatelimit_NewBackwardCompatibility_Good(t *testing.T) {
// New() should produce the exact same result as before Phase 1
rl, err := New()
require.NoError(t, err)
@ -1251,7 +1314,7 @@ func TestNewBackwardCompatibility(t *testing.T) {
}
}
func TestSetQuota(t *testing.T) {
func TestRatelimit_SetQuota_Good(t *testing.T) {
t.Run("adds new model quota", func(t *testing.T) {
rl := newTestLimiter(t)
rl.SetQuota("custom-model", ModelQuota{MaxRPM: 42, MaxTPM: 9999, MaxRPD: 100})
@ -1279,7 +1342,7 @@ func TestSetQuota(t *testing.T) {
wg.Add(1)
go func(n int) {
defer wg.Done()
model := fmt.Sprintf("model-%d", n)
model := core.Sprintf("model-%d", n)
rl.SetQuota(model, ModelQuota{MaxRPM: n, MaxTPM: n * 100, MaxRPD: n * 10})
}(i)
}
@ -1289,7 +1352,7 @@ func TestSetQuota(t *testing.T) {
})
}
func TestAddProvider(t *testing.T) {
func TestRatelimit_AddProvider_Good(t *testing.T) {
t.Run("adds OpenAI models to existing limiter", func(t *testing.T) {
rl := newTestLimiter(t) // starts with Gemini defaults
geminiCount := len(rl.Quotas)
@ -1351,7 +1414,7 @@ func TestAddProvider(t *testing.T) {
})
}
func TestProviderConstants(t *testing.T) {
func TestRatelimit_ProviderConstants_Good(t *testing.T) {
// Verify the string values are stable (they may be used in YAML configs)
assert.Equal(t, Provider("gemini"), ProviderGemini)
assert.Equal(t, Provider("openai"), ProviderOpenAI)
@ -1361,7 +1424,7 @@ func TestProviderConstants(t *testing.T) {
// --- Phase 0 addendum: Additional concurrent and multi-model race tests ---
func TestConcurrentMultipleModels(t *testing.T) {
func TestRatelimit_ConcurrentMultipleModels_Good(t *testing.T) {
rl := newTestLimiter(t)
models := []string{"model-a", "model-b", "model-c", "model-d", "model-e"}
for _, m := range models {
@ -1391,9 +1454,9 @@ func TestConcurrentMultipleModels(t *testing.T) {
}
}
func TestConcurrentPersistAndLoad(t *testing.T) {
func TestRatelimit_ConcurrentPersistAndLoad_Ugly(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "concurrent.yaml")
path := testPath(tmpDir, "concurrent.yaml")
rl := newTestLimiter(t)
rl.filePath = path
@ -1425,7 +1488,7 @@ func TestConcurrentPersistAndLoad(t *testing.T) {
// No panics or data races = pass
}
func TestConcurrentAllStatsAndRecordUsage(t *testing.T) {
func TestRatelimit_ConcurrentAllStatsAndRecordUsage_Good(t *testing.T) {
rl := newTestLimiter(t)
models := []string{"stats-a", "stats-b", "stats-c"}
for _, m := range models {
@ -1456,7 +1519,7 @@ func TestConcurrentAllStatsAndRecordUsage(t *testing.T) {
wg.Wait()
}
func TestConcurrentWaitForCapacityAndRecordUsage(t *testing.T) {
func TestRatelimit_ConcurrentWaitForCapacityAndRecordUsage_Good(t *testing.T) {
rl := newTestLimiter(t)
model := "race-wait"
rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 10000000, MaxRPD: 10000}
@ -1553,7 +1616,7 @@ func BenchmarkAllStats(b *testing.B) {
func BenchmarkPersist(b *testing.B) {
tmpDir := b.TempDir()
path := filepath.Join(tmpDir, "bench.yaml")
path := testPath(tmpDir, "bench.yaml")
rl, _ := New()
rl.filePath = path
@ -1574,10 +1637,10 @@ func BenchmarkPersist(b *testing.B) {
}
}
func TestEndToEndMultiProvider(t *testing.T) {
func TestRatelimit_EndToEndMultiProvider_Good(t *testing.T) {
// Simulate a real-world scenario: limiter for both Gemini and Anthropic
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "multi.yaml"),
FilePath: testPath(t.TempDir(), "multi.yaml"),
Providers: []Provider{ProviderGemini, ProviderAnthropic},
})
require.NoError(t, err)

View file

@ -1,8 +1,6 @@
package ratelimit
import (
"os"
"path/filepath"
"sync"
"testing"
"time"
@ -14,18 +12,17 @@ import (
// --- Phase 2: SQLite basic tests ---
func TestNewSQLiteStore_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "test.db")
func TestSQLite_NewSQLiteStore_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "test.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
// Verify the database file was created.
_, statErr := os.Stat(dbPath)
assert.NoError(t, statErr, "database file should exist")
assert.True(t, pathExists(dbPath), "database file should exist")
}
func TestNewSQLiteStore_Bad(t *testing.T) {
func TestSQLite_NewSQLiteStore_Bad(t *testing.T) {
t.Run("invalid path returns error", func(t *testing.T) {
// Path inside a non-existent directory with no parent.
_, err := newSQLiteStore("/nonexistent/deep/nested/dir/test.db")
@ -33,8 +30,8 @@ func TestNewSQLiteStore_Bad(t *testing.T) {
})
}
func TestSQLiteQuotasRoundTrip_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "quotas.db")
func TestSQLite_QuotasRoundTrip_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "quotas.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -60,8 +57,8 @@ func TestSQLiteQuotasRoundTrip_Good(t *testing.T) {
}
}
func TestSQLiteQuotasUpsert_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "upsert.db")
func TestSQLite_QuotasUpsert_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "upsert.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -85,8 +82,8 @@ func TestSQLiteQuotasUpsert_Good(t *testing.T) {
assert.Equal(t, 777, q.MaxRPD, "should have updated RPD")
}
func TestSQLiteStateRoundTrip_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "state.db")
func TestSQLite_StateRoundTrip_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "state.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -144,8 +141,8 @@ func TestSQLiteStateRoundTrip_Good(t *testing.T) {
}
}
func TestSQLiteStateOverwrite_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "overwrite.db")
func TestSQLite_StateOverwrite_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "overwrite.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -182,8 +179,8 @@ func TestSQLiteStateOverwrite_Good(t *testing.T) {
assert.Len(t, b.Requests, 1)
}
func TestSQLiteEmptyState_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "empty.db")
func TestSQLite_EmptyState_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "empty.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -198,8 +195,8 @@ func TestSQLiteEmptyState_Good(t *testing.T) {
assert.Empty(t, state, "should return empty state from fresh DB")
}
func TestSQLiteClose_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "close.db")
func TestSQLite_Close_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "close.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
@ -208,8 +205,8 @@ func TestSQLiteClose_Good(t *testing.T) {
// --- Phase 2: SQLite integration tests ---
func TestNewWithSQLite_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "limiter.db")
func TestSQLite_NewWithSQLite_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "limiter.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
defer rl.Close()
@ -222,8 +219,8 @@ func TestNewWithSQLite_Good(t *testing.T) {
assert.NotNil(t, rl.sqlite, "SQLite store should be initialised")
}
func TestNewWithSQLiteConfig_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "config.db")
func TestSQLite_NewWithSQLiteConfig_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "config.db")
rl, err := NewWithSQLiteConfig(dbPath, Config{
Providers: []Provider{ProviderAnthropic},
Quotas: map[string]ModelQuota{
@ -243,8 +240,8 @@ func TestNewWithSQLiteConfig_Good(t *testing.T) {
assert.False(t, hasGemini, "should not have Gemini models")
}
func TestSQLitePersistAndLoad_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "persist.db")
func TestSQLite_PersistAndLoad_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "persist.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
@ -272,8 +269,8 @@ func TestSQLitePersistAndLoad_Good(t *testing.T) {
assert.Equal(t, 500, stats.MaxRPD)
}
func TestSQLitePersistMultipleModels_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "multi.db")
func TestSQLite_PersistMultipleModels_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "multi.db")
rl, err := NewWithSQLiteConfig(dbPath, Config{
Providers: []Provider{ProviderGemini, ProviderAnthropic},
})
@ -302,8 +299,8 @@ func TestSQLitePersistMultipleModels_Good(t *testing.T) {
assert.Equal(t, 400, claude.TPM)
}
func TestSQLiteRecordUsageThenPersistReload_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "record.db")
func TestSQLite_RecordUsageThenPersistReload_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "record.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
@ -340,7 +337,7 @@ func TestSQLiteRecordUsageThenPersistReload_Good(t *testing.T) {
assert.Equal(t, 1000, stats2.TPM, "TPM should survive reload")
}
func TestSQLiteClose_Good_NoOp(t *testing.T) {
func TestSQLite_CloseNoOp_Good(t *testing.T) {
// Close on YAML-backed limiter is a no-op.
rl := newTestLimiter(t)
assert.NoError(t, rl.Close(), "Close on YAML limiter should be no-op")
@ -348,8 +345,8 @@ func TestSQLiteClose_Good_NoOp(t *testing.T) {
// --- Phase 2: Concurrent SQLite ---
func TestSQLiteConcurrent_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "concurrent.db")
func TestSQLite_Concurrent_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "concurrent.db")
rl, err := NewWithSQLite(dbPath)
require.NoError(t, err)
defer rl.Close()
@ -398,10 +395,10 @@ func TestSQLiteConcurrent_Good(t *testing.T) {
// --- Phase 2: YAML backward compatibility ---
func TestYAMLBackwardCompat_Good(t *testing.T) {
func TestSQLite_YAMLBackwardCompat_Good(t *testing.T) {
// Verify that the default YAML backend still works after SQLite additions.
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "compat.yaml")
path := testPath(tmpDir, "compat.yaml")
rl1, err := New()
require.NoError(t, err)
@ -425,18 +422,18 @@ func TestYAMLBackwardCompat_Good(t *testing.T) {
assert.Equal(t, 200, stats.TPM)
}
func TestConfigBackendDefault_Good(t *testing.T) {
func TestSQLite_ConfigBackendDefault_Good(t *testing.T) {
// Empty Backend string should default to YAML behaviour.
rl, err := NewWithConfig(Config{
FilePath: filepath.Join(t.TempDir(), "default.yaml"),
FilePath: testPath(t.TempDir(), "default.yaml"),
})
require.NoError(t, err)
assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
}
func TestConfigBackendSQLite_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "config-backend.db")
func TestSQLite_ConfigBackendSQLite_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "config-backend.db")
rl, err := NewWithConfig(Config{
Backend: backendSQLite,
FilePath: dbPath,
@ -451,11 +448,10 @@ func TestConfigBackendSQLite_Good(t *testing.T) {
rl.RecordUsage("backend-model", 10, 10)
require.NoError(t, rl.Persist())
_, statErr := os.Stat(dbPath)
assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path")
assert.True(t, pathExists(dbPath), "sqlite backend should persist to the configured DB path")
}
func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
func TestSQLite_ConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", "")
@ -470,16 +466,15 @@ func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
require.NotNil(t, rl.sqlite)
require.NoError(t, rl.Persist())
_, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile))
assert.NoError(t, statErr, "sqlite backend should use the default home DB path")
assert.True(t, pathExists(testPath(home, defaultStateDirName, defaultSQLiteStateFile)), "sqlite backend should use the default home DB path")
}
// --- Phase 2: MigrateYAMLToSQLite ---
func TestMigrateYAMLToSQLite_Good(t *testing.T) {
func TestSQLite_MigrateYAMLToSQLite_Good(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "state.yaml")
sqlitePath := filepath.Join(tmpDir, "migrated.db")
yamlPath := testPath(tmpDir, "state.yaml")
sqlitePath := testPath(tmpDir, "migrated.db")
// Create a YAML-backed limiter with state.
rl, err := New()
@ -515,26 +510,26 @@ func TestMigrateYAMLToSQLite_Good(t *testing.T) {
assert.Equal(t, 2, stats.RPD, "should have 2 daily requests")
}
func TestMigrateYAMLToSQLite_Bad(t *testing.T) {
func TestSQLite_MigrateYAMLToSQLite_Bad(t *testing.T) {
t.Run("non-existent YAML file", func(t *testing.T) {
err := MigrateYAMLToSQLite("/nonexistent/state.yaml", filepath.Join(t.TempDir(), "out.db"))
err := MigrateYAMLToSQLite("/nonexistent/state.yaml", testPath(t.TempDir(), "out.db"))
assert.Error(t, err, "should fail with non-existent YAML file")
})
t.Run("corrupt YAML file", func(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "corrupt.yaml")
require.NoError(t, os.WriteFile(yamlPath, []byte("{{{{not yaml!"), 0644))
yamlPath := testPath(tmpDir, "corrupt.yaml")
writeTestFile(t, yamlPath, "{{{{not yaml!")
err := MigrateYAMLToSQLite(yamlPath, filepath.Join(tmpDir, "out.db"))
err := MigrateYAMLToSQLite(yamlPath, testPath(tmpDir, "out.db"))
assert.Error(t, err, "should fail with corrupt YAML")
})
}
func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
func TestSQLite_MigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "atomic.yaml")
sqlitePath := filepath.Join(tmpDir, "atomic.db")
yamlPath := testPath(tmpDir, "atomic.yaml")
sqlitePath := testPath(tmpDir, "atomic.db")
now := time.Now().UTC()
store, err := newSQLiteStore(sqlitePath)
@ -573,7 +568,7 @@ func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
}
data, err := yaml.Marshal(migrated)
require.NoError(t, err)
require.NoError(t, os.WriteFile(yamlPath, data, 0o644))
writeTestFile(t, yamlPath, string(data))
err = MigrateYAMLToSQLite(yamlPath, sqlitePath)
require.Error(t, err)
@ -594,10 +589,10 @@ func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
assert.NotContains(t, state, "new-model")
}
func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
func TestSQLite_MigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "full.yaml")
sqlitePath := filepath.Join(tmpDir, "full.db")
yamlPath := testPath(tmpDir, "full.yaml")
sqlitePath := testPath(tmpDir, "full.db")
// Create a full YAML state with all Gemini models.
rl, err := New()
@ -626,12 +621,12 @@ func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
// --- Phase 2: Corrupt DB recovery ---
func TestSQLiteCorruptDB_Ugly(t *testing.T) {
func TestSQLite_CorruptDB_Ugly(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "corrupt.db")
dbPath := testPath(tmpDir, "corrupt.db")
// Write garbage to the DB file.
require.NoError(t, os.WriteFile(dbPath, []byte("THIS IS NOT A SQLITE DATABASE"), 0644))
writeTestFile(t, dbPath, "THIS IS NOT A SQLITE DATABASE")
// Opening a corrupt DB may succeed (sqlite is lazy about validation),
// but operations on it should fail gracefully.
@ -648,9 +643,9 @@ func TestSQLiteCorruptDB_Ugly(t *testing.T) {
assert.Error(t, err, "loading from corrupt DB should return an error")
}
func TestSQLiteTruncatedDB_Ugly(t *testing.T) {
func TestSQLite_TruncatedDB_Ugly(t *testing.T) {
tmpDir := t.TempDir()
dbPath := filepath.Join(tmpDir, "truncated.db")
dbPath := testPath(tmpDir, "truncated.db")
// Create a valid DB first.
store, err := newSQLiteStore(dbPath)
@ -661,11 +656,7 @@ func TestSQLiteTruncatedDB_Ugly(t *testing.T) {
require.NoError(t, store.close())
// Truncate the file to simulate corruption.
f, err := os.OpenFile(dbPath, os.O_WRONLY|os.O_TRUNC, 0644)
require.NoError(t, err)
_, err = f.Write([]byte("TRUNC"))
require.NoError(t, err)
require.NoError(t, f.Close())
overwriteTestFile(t, dbPath, "TRUNC")
// Opening should either fail or operations should fail.
store2, err := newSQLiteStore(dbPath)
@ -679,9 +670,9 @@ func TestSQLiteTruncatedDB_Ugly(t *testing.T) {
assert.Error(t, err, "loading from truncated DB should return an error")
}
func TestSQLiteEmptyModelState_Good(t *testing.T) {
func TestSQLite_EmptyModelState_Good(t *testing.T) {
// State with no requests or tokens but with a daily counter.
dbPath := filepath.Join(t.TempDir(), "empty-state.db")
dbPath := testPath(t.TempDir(), "empty-state.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
@ -708,8 +699,8 @@ func TestSQLiteEmptyModelState_Good(t *testing.T) {
// --- Phase 2: End-to-end with persist cycle ---
func TestSQLiteEndToEnd_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "e2e.db")
func TestSQLite_EndToEnd_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "e2e.db")
// Session 1: Create limiter, record usage, persist.
rl1, err := NewWithSQLiteConfig(dbPath, Config{
@ -752,8 +743,8 @@ func TestSQLiteEndToEnd_Good(t *testing.T) {
assert.Equal(t, 5, custom.MaxRPM)
}
func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "replace.db")
func TestSQLite_LoadReplacesPersistedSnapshot_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "replace.db")
rl, err := NewWithSQLiteConfig(dbPath, Config{
Quotas: map[string]ModelQuota{
"model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
@ -788,8 +779,8 @@ func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
assert.Equal(t, 1, rl2.Stats("model-b").RPD)
}
func TestSQLitePersistAtomic_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "persist-atomic.db")
func TestSQLite_PersistAtomic_Good(t *testing.T) {
dbPath := testPath(t.TempDir(), "persist-atomic.db")
rl, err := NewWithSQLiteConfig(dbPath, Config{
Quotas: map[string]ModelQuota{
"old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
@ -827,7 +818,7 @@ func TestSQLitePersistAtomic_Good(t *testing.T) {
// --- Phase 2: Benchmark ---
func BenchmarkSQLitePersist(b *testing.B) {
dbPath := filepath.Join(b.TempDir(), "bench.db")
dbPath := testPath(b.TempDir(), "bench.db")
rl, err := NewWithSQLite(dbPath)
if err != nil {
b.Fatal(err)
@ -852,7 +843,7 @@ func BenchmarkSQLitePersist(b *testing.B) {
}
func BenchmarkSQLiteLoad(b *testing.B) {
dbPath := filepath.Join(b.TempDir(), "bench-load.db")
dbPath := testPath(b.TempDir(), "bench-load.db")
rl, err := NewWithSQLite(dbPath)
if err != nil {
b.Fatal(err)
@ -883,10 +874,10 @@ func BenchmarkSQLiteLoad(b *testing.B) {
// TestMigrateYAMLToSQLiteWithFullState tests migration of a realistic YAML
// file that contains the full serialised RateLimiter struct.
func TestMigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
func TestSQLite_MigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
tmpDir := t.TempDir()
yamlPath := filepath.Join(tmpDir, "realistic.yaml")
sqlitePath := filepath.Join(tmpDir, "realistic.db")
yamlPath := testPath(tmpDir, "realistic.yaml")
sqlitePath := testPath(tmpDir, "realistic.db")
now := time.Now()
@ -919,7 +910,7 @@ func TestMigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
data, err := yaml.Marshal(rl)
require.NoError(t, err)
require.NoError(t, os.WriteFile(yamlPath, data, 0644))
writeTestFile(t, yamlPath, string(data))
// Migrate.
require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath))