From ed1cdc11b2187187986b39e8586bc1c41291bc57 Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 26 Mar 2026 18:51:54 +0000 Subject: [PATCH] refactor(ratelimit): finish ax v0.8.0 polish Co-Authored-By: Virgil --- error_test.go | 156 ++++++++++++++++------------------ iter_test.go | 6 +- ratelimit.go | 126 +++++++++++++++++++++++----- ratelimit_test.go | 209 ++++++++++++++++++++++++++++++---------------- sqlite_test.go | 155 ++++++++++++++++------------------ 5 files changed, 390 insertions(+), 262 deletions(-) diff --git a/error_test.go b/error_test.go index 05166b2..96350e8 100644 --- a/error_test.go +++ b/error_test.go @@ -1,8 +1,7 @@ package ratelimit import ( - "os" - "path/filepath" + "syscall" "testing" "time" @@ -10,8 +9,8 @@ import ( "github.com/stretchr/testify/require" ) -func TestSQLiteErrorPaths(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "error.db") +func TestError_SQLiteErrorPaths_Bad(t *testing.T) { + dbPath := testPath(t.TempDir(), "error.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) @@ -39,17 +38,17 @@ func TestSQLiteErrorPaths(t *testing.T) { }) } -func TestSQLiteInitErrors(t *testing.T) { +func TestError_SQLiteInitErrors_Bad(t *testing.T) { t.Run("WAL pragma failure", func(t *testing.T) { // This is hard to trigger without mocking sql.DB, but we can try an invalid connection string // modernc.org/sqlite doesn't support all DSN options that might cause PRAGMA to fail but connection to succeed. }) } -func TestPersistYAML(t *testing.T) { +func TestError_PersistYAML_Good(t *testing.T) { t.Run("successful YAML persist and load", func(t *testing.T) { tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "ratelimits.yaml") + path := testPath(tmpDir, "ratelimits.yaml") rl, _ := New() rl.filePath = path rl.Quotas["test"] = ModelQuota{MaxRPM: 1} @@ -65,9 +64,9 @@ func TestPersistYAML(t *testing.T) { }) } -func TestSQLiteLoadViaLimiter(t *testing.T) { +func TestError_SQLiteLoadViaLimiter_Bad(t *testing.T) { t.Run("Load returns error when SQLite DB is closed", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "load-err.db") + dbPath := testPath(t.TempDir(), "load-err.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) @@ -79,7 +78,7 @@ func TestSQLiteLoadViaLimiter(t *testing.T) { }) t.Run("Load returns error when loadState fails", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "load-state-err.db") + dbPath := testPath(t.TempDir(), "load-state-err.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) @@ -96,9 +95,9 @@ func TestSQLiteLoadViaLimiter(t *testing.T) { }) } -func TestSQLitePersistViaLimiter(t *testing.T) { +func TestError_SQLitePersistViaLimiter_Bad(t *testing.T) { t.Run("Persist returns error when SQLite saveQuotas fails", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "persist-err.db") + dbPath := testPath(t.TempDir(), "persist-err.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) @@ -113,7 +112,7 @@ func TestSQLitePersistViaLimiter(t *testing.T) { }) t.Run("Persist returns error when SQLite saveState fails", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "persist-state-err.db") + dbPath := testPath(t.TempDir(), "persist-state-err.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) @@ -130,7 +129,7 @@ func TestSQLitePersistViaLimiter(t *testing.T) { }) } -func TestNewWithSQLiteErrors(t *testing.T) { +func TestError_NewWithSQLite_Bad(t *testing.T) { t.Run("NewWithSQLite with invalid path", func(t *testing.T) { _, err := NewWithSQLite("/nonexistent/deep/nested/dir/test.db") assert.Error(t, err, "should fail with invalid path") @@ -144,9 +143,9 @@ func TestNewWithSQLiteErrors(t *testing.T) { }) } -func TestSQLiteSaveStateErrors(t *testing.T) { +func TestError_SQLiteSaveState_Bad(t *testing.T) { t.Run("saveState fails when tokens table is dropped", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "tokens-err.db") + dbPath := testPath(t.TempDir(), "tokens-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -166,7 +165,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) { }) t.Run("saveState fails when daily table is dropped", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "daily-err.db") + dbPath := testPath(t.TempDir(), "daily-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -185,7 +184,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) { }) t.Run("saveState fails on request insert with renamed column", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "req-insert-err.db") + dbPath := testPath(t.TempDir(), "req-insert-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -206,7 +205,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) { }) t.Run("saveState fails on token insert with renamed column", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "tok-insert-err.db") + dbPath := testPath(t.TempDir(), "tok-insert-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -227,7 +226,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) { }) t.Run("saveState fails on daily insert with renamed column", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "day-insert-err.db") + dbPath := testPath(t.TempDir(), "day-insert-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -247,9 +246,9 @@ func TestSQLiteSaveStateErrors(t *testing.T) { }) } -func TestSQLiteLoadStateErrors(t *testing.T) { +func TestError_SQLiteLoadState_Bad(t *testing.T) { t.Run("loadState fails when requests table is dropped", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "req-err.db") + dbPath := testPath(t.TempDir(), "req-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -271,7 +270,7 @@ func TestSQLiteLoadStateErrors(t *testing.T) { }) t.Run("loadState fails when tokens table is dropped", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "tok-err.db") + dbPath := testPath(t.TempDir(), "tok-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -293,7 +292,7 @@ func TestSQLiteLoadStateErrors(t *testing.T) { }) t.Run("loadState fails when daily table is dropped", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "daily-load-err.db") + dbPath := testPath(t.TempDir(), "daily-load-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -314,9 +313,9 @@ func TestSQLiteLoadStateErrors(t *testing.T) { }) } -func TestSQLiteSaveQuotasExecError(t *testing.T) { +func TestError_SQLiteSaveQuotasExec_Bad(t *testing.T) { t.Run("saveQuotas fails with renamed column at prepare", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "quota-exec-err.db") + dbPath := testPath(t.TempDir(), "quota-exec-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -332,7 +331,7 @@ func TestSQLiteSaveQuotasExecError(t *testing.T) { }) t.Run("saveQuotas fails at exec via trigger", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "quota-trigger.db") + dbPath := testPath(t.TempDir(), "quota-trigger.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -350,9 +349,9 @@ func TestSQLiteSaveQuotasExecError(t *testing.T) { }) } -func TestSQLiteSaveStateExecErrors(t *testing.T) { +func TestError_SQLiteSaveStateExec_Bad(t *testing.T) { t.Run("request insert exec fails via trigger", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "trigger-req.db") + dbPath := testPath(t.TempDir(), "trigger-req.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -375,7 +374,7 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) { }) t.Run("token insert exec fails via trigger", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "trigger-tok.db") + dbPath := testPath(t.TempDir(), "trigger-tok.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -398,7 +397,7 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) { }) t.Run("daily insert exec fails via trigger", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "trigger-day.db") + dbPath := testPath(t.TempDir(), "trigger-day.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -420,9 +419,9 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) { }) } -func TestSQLiteLoadQuotasScanError(t *testing.T) { +func TestError_SQLiteLoadQuotasScan_Bad(t *testing.T) { t.Run("loadQuotas fails with renamed column", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "quota-scan-err.db") + dbPath := testPath(t.TempDir(), "quota-scan-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -441,26 +440,29 @@ func TestSQLiteLoadQuotasScanError(t *testing.T) { }) } -func TestNewSQLiteStoreInReadOnlyDir(t *testing.T) { - if os.Getuid() == 0 { +func TestError_NewSQLiteStoreInReadOnlyDir_Bad(t *testing.T) { + if isRootUser() { t.Skip("chmod restrictions do not apply to root") } t.Run("fails when parent directory is read-only", func(t *testing.T) { tmpDir := t.TempDir() - readonlyDir := filepath.Join(tmpDir, "readonly") - require.NoError(t, os.MkdirAll(readonlyDir, 0555)) - defer os.Chmod(readonlyDir, 0755) + readonlyDir := testPath(tmpDir, "readonly") + ensureTestDir(t, readonlyDir) + setPathMode(t, readonlyDir, 0o555) + defer func() { + _ = syscall.Chmod(readonlyDir, 0o755) + }() - dbPath := filepath.Join(readonlyDir, "test.db") + dbPath := testPath(readonlyDir, "test.db") _, err := newSQLiteStore(dbPath) assert.Error(t, err, "should fail when directory is read-only") }) } -func TestSQLiteCreateSchemaError(t *testing.T) { +func TestError_SQLiteCreateSchema_Bad(t *testing.T) { t.Run("createSchema fails on closed DB", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "schema-err.db") + dbPath := testPath(t.TempDir(), "schema-err.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) @@ -473,9 +475,9 @@ func TestSQLiteCreateSchemaError(t *testing.T) { }) } -func TestSQLiteLoadStateScanErrors(t *testing.T) { +func TestError_SQLiteLoadStateScan_Bad(t *testing.T) { t.Run("scan daily fails with NULL values", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "scan-daily.db") + dbPath := testPath(t.TempDir(), "scan-daily.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -495,7 +497,7 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) { }) t.Run("scan requests fails with NULL ts", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "scan-req.db") + dbPath := testPath(t.TempDir(), "scan-req.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -520,7 +522,7 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) { }) t.Run("scan tokens fails with NULL values", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "scan-tok.db") + dbPath := testPath(t.TempDir(), "scan-tok.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -545,9 +547,9 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) { }) } -func TestSQLiteLoadQuotasScanWithBadSchema(t *testing.T) { +func TestError_SQLiteLoadQuotasScanWithBadSchema_Bad(t *testing.T) { t.Run("scan fails with NULL quota values", func(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "scan-quota.db") + dbPath := testPath(t.TempDir(), "scan-quota.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -566,11 +568,11 @@ func TestSQLiteLoadQuotasScanWithBadSchema(t *testing.T) { }) } -func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) { +func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) { t.Run("saveQuotas failure during migration via trigger", func(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "with-quotas.yaml") - sqlitePath := filepath.Join(tmpDir, "migrate-quota-err.db") + yamlPath := testPath(tmpDir, "with-quotas.yaml") + sqlitePath := testPath(tmpDir, "migrate-quota-err.db") // Write a YAML file with quotas. yamlData := `quotas: @@ -579,7 +581,7 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) { max_tpm: 100 max_rpd: 50 ` - require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644)) + writeTestFile(t, yamlPath, yamlData) // Pre-create DB with a trigger that aborts quota inserts. store, err := newSQLiteStore(sqlitePath) @@ -596,8 +598,8 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) { t.Run("saveState failure during migration via trigger", func(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "with-state.yaml") - sqlitePath := filepath.Join(tmpDir, "migrate-state-err.db") + yamlPath := testPath(tmpDir, "with-state.yaml") + sqlitePath := testPath(tmpDir, "migrate-state-err.db") // Write YAML with state. yamlData := `state: @@ -607,7 +609,7 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) { day_start: 2026-01-01T00:00:00Z day_count: 1 ` - require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644)) + writeTestFile(t, yamlPath, yamlData) // Pre-create DB with a trigger that aborts daily inserts. store, err := newSQLiteStore(sqlitePath) @@ -622,13 +624,13 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) { }) } -func TestMigrateYAMLToSQLiteNilQuotasAndState(t *testing.T) { +func TestError_MigrateYAMLToSQLiteNilQuotasAndState_Good(t *testing.T) { t.Run("YAML with empty quotas and state migrates cleanly", func(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "empty.yaml") - require.NoError(t, os.WriteFile(yamlPath, []byte("{}"), 0644)) + yamlPath := testPath(tmpDir, "empty.yaml") + writeTestFile(t, yamlPath, "{}") - sqlitePath := filepath.Join(tmpDir, "empty.db") + sqlitePath := testPath(tmpDir, "empty.db") require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath)) store, err := newSQLiteStore(sqlitePath) @@ -645,30 +647,18 @@ func TestMigrateYAMLToSQLiteNilQuotasAndState(t *testing.T) { }) } -func TestNewWithConfigUserHomeDirError(t *testing.T) { - // Unset HOME to trigger os.UserHomeDir() error. - home := os.Getenv("HOME") - os.Unsetenv("HOME") - // Also unset fallback env vars that UserHomeDir checks. - plan9Home := os.Getenv("home") - os.Unsetenv("home") - userProfile := os.Getenv("USERPROFILE") - os.Unsetenv("USERPROFILE") - defer func() { - os.Setenv("HOME", home) - if plan9Home != "" { - os.Setenv("home", plan9Home) - } - if userProfile != "" { - os.Setenv("USERPROFILE", userProfile) - } - }() +func TestError_NewWithConfigUserHomeDir_Bad(t *testing.T) { + // Clear all supported home env vars so defaultStatePath cannot resolve a home directory. + t.Setenv("CORE_HOME", "") + t.Setenv("HOME", "") + t.Setenv("home", "") + t.Setenv("USERPROFILE", "") _, err := NewWithConfig(Config{}) assert.Error(t, err, "should fail when HOME is unset") } -func TestPersistMarshalError(t *testing.T) { +func TestError_PersistMarshal_Good(t *testing.T) { // yaml.Marshal on a struct with map[string]ModelQuota and map[string]*UsageStats // should not fail in practice. We test the error path by using a type that // yaml.Marshal cannot handle: a channel. @@ -680,20 +670,20 @@ func TestPersistMarshalError(t *testing.T) { assert.NoError(t, rl.Persist(), "valid persist should succeed") } -func TestMigrateErrorsExtended(t *testing.T) { +func TestError_MigrateErrorsExtended_Bad(t *testing.T) { t.Run("unmarshal failure", func(t *testing.T) { tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "bad.yaml") - require.NoError(t, os.WriteFile(path, []byte("invalid: yaml: ["), 0644)) - err := MigrateYAMLToSQLite(path, filepath.Join(tmpDir, "out.db")) + path := testPath(tmpDir, "bad.yaml") + writeTestFile(t, path, "invalid: yaml: [") + err := MigrateYAMLToSQLite(path, testPath(tmpDir, "out.db")) assert.Error(t, err) assert.Contains(t, err.Error(), "ratelimit.MigrateYAMLToSQLite: unmarshal") }) t.Run("sqlite open failure", func(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "ok.yaml") - require.NoError(t, os.WriteFile(yamlPath, []byte("quotas: {}"), 0644)) + yamlPath := testPath(tmpDir, "ok.yaml") + writeTestFile(t, yamlPath, "quotas: {}") // Use an invalid sqlite path (dir where file should be) err := MigrateYAMLToSQLite(yamlPath, "/dev/null/not-a-db") assert.Error(t, err) diff --git a/iter_test.go b/iter_test.go index f353e16..cec378d 100644 --- a/iter_test.go +++ b/iter_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestIterators(t *testing.T) { +func TestIter_Iterators_Good(t *testing.T) { rl, err := NewWithConfig(Config{ Quotas: map[string]ModelQuota{ "model-c": {MaxRPM: 10}, @@ -77,7 +77,7 @@ func TestIterators(t *testing.T) { }) } -func TestIterEarlyBreak(t *testing.T) { +func TestIter_IterEarlyBreak_Good(t *testing.T) { rl, err := NewWithConfig(Config{ Quotas: map[string]ModelQuota{ "model-a": {MaxRPM: 10}, @@ -110,7 +110,7 @@ func TestIterEarlyBreak(t *testing.T) { }) } -func TestCountTokensFull(t *testing.T) { +func TestIter_CountTokensFull_Ugly(t *testing.T) { t.Run("empty model is rejected", func(t *testing.T) { _, err := CountTokens(context.Background(), "key", "", "text") assert.Error(t, err) diff --git a/ratelimit.go b/ratelimit.go index 35d4c3b..cc593cf 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -3,11 +3,11 @@ package ratelimit import ( "context" "io" + "io/fs" "iter" "maps" "net/http" "net/url" - "os" "slices" "sync" "time" @@ -17,6 +17,8 @@ import ( ) // Provider identifies an LLM provider for quota profiles. +// +// provider := ProviderOpenAI type Provider string const ( @@ -41,6 +43,8 @@ const ( ) // ModelQuota defines the rate limits for a specific model. +// +// quota := ModelQuota{MaxRPM: 60, MaxTPM: 90000, MaxRPD: 1000} type ModelQuota struct { MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited) MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited) @@ -48,12 +52,18 @@ type ModelQuota struct { } // ProviderProfile bundles model quotas for a provider. +// +// profile := ProviderProfile{Provider: ProviderGemini, Models: DefaultProfiles()[ProviderGemini].Models} type ProviderProfile struct { - Provider Provider `yaml:"provider"` - Models map[string]ModelQuota `yaml:"models"` + // Provider identifies the provider that owns the profile. + Provider Provider `yaml:"provider"` + // Models maps model names to quotas. + Models map[string]ModelQuota `yaml:"models"` } // Config controls RateLimiter initialisation. +// +// cfg := Config{Providers: []Provider{ProviderGemini}, FilePath: "/tmp/ratelimits.yaml"} type Config struct { // FilePath overrides the default state file location. // If empty, defaults to ~/.core/ratelimits.yaml. @@ -73,23 +83,35 @@ type Config struct { } // TokenEntry records a token usage event. +// +// entry := TokenEntry{Time: time.Now(), Count: 512} type TokenEntry struct { Time time.Time `yaml:"time"` Count int `yaml:"count"` } // UsageStats tracks usage history for a model. +// +// stats := UsageStats{DayStart: time.Now(), DayCount: 1} type UsageStats struct { Requests []time.Time `yaml:"requests"` // Sliding window (1m) Tokens []TokenEntry `yaml:"tokens"` // Sliding window (1m) - DayStart time.Time `yaml:"day_start"` - DayCount int `yaml:"day_count"` + // DayStart is the start of the rolling 24-hour window. + DayStart time.Time `yaml:"day_start"` + // DayCount is the number of requests recorded in the rolling 24-hour window. + DayCount int `yaml:"day_count"` } // RateLimiter manages rate limits across multiple models. +// +// rl, err := New() +// if err != nil { /* handle error */ } +// defer rl.Close() type RateLimiter struct { - mu sync.RWMutex - Quotas map[string]ModelQuota `yaml:"quotas"` + mu sync.RWMutex + // Quotas holds the configured per-model limits. + Quotas map[string]ModelQuota `yaml:"quotas"` + // State holds per-model usage windows. State map[string]*UsageStats `yaml:"state"` filePath string sqlite *sqliteStore // non-nil when backend is "sqlite" @@ -97,6 +119,9 @@ type RateLimiter struct { // DefaultProfiles returns pre-configured quota profiles for each provider. // Values are based on published rate limits as of Feb 2026. +// +// profiles := DefaultProfiles() +// openAI := profiles[ProviderOpenAI] func DefaultProfiles() map[Provider]ProviderProfile { return map[Provider]ProviderProfile{ ProviderGemini: { @@ -140,6 +165,8 @@ func DefaultProfiles() map[Provider]ProviderProfile { // New creates a new RateLimiter with Gemini defaults. // This preserves backward compatibility -- existing callers are unaffected. +// +// rl, err := New() func New() (*RateLimiter, error) { return NewWithConfig(Config{ Providers: []Provider{ProviderGemini}, @@ -148,6 +175,8 @@ func New() (*RateLimiter, error) { // NewWithConfig creates a RateLimiter from explicit configuration. // If no providers or quotas are specified, Gemini defaults are used. +// +// rl, err := NewWithConfig(Config{Providers: []Provider{ProviderAnthropic}}) func NewWithConfig(cfg Config) (*RateLimiter, error) { backend, err := normaliseBackend(cfg.Backend) if err != nil { @@ -177,6 +206,8 @@ func NewWithConfig(cfg Config) (*RateLimiter, error) { } // SetQuota sets or updates the quota for a specific model at runtime. +// +// rl.SetQuota("gpt-4o-mini", ModelQuota{MaxRPM: 60, MaxTPM: 200000}) func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) { rl.mu.Lock() defer rl.mu.Unlock() @@ -185,6 +216,8 @@ func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) { // AddProvider loads all default quotas for a provider. // Existing quotas for models in the profile are overwritten. +// +// rl.AddProvider(ProviderOpenAI) func (rl *RateLimiter) AddProvider(provider Provider) { rl.mu.Lock() defer rl.mu.Unlock() @@ -196,6 +229,8 @@ func (rl *RateLimiter) AddProvider(provider Provider) { } // Load reads the state from disk (YAML) or database (SQLite). +// +// if err := rl.Load(); err != nil { /* handle error */ } func (rl *RateLimiter) Load() error { rl.mu.Lock() defer rl.mu.Unlock() @@ -205,7 +240,7 @@ func (rl *RateLimiter) Load() error { } content, err := readLocalFile(rl.filePath) - if os.IsNotExist(err) { + if core.Is(err, fs.ErrNotExist) { return nil } if err != nil { @@ -238,6 +273,8 @@ func (rl *RateLimiter) loadSQLite() error { // Persist writes a snapshot of the state to disk (YAML) or database (SQLite). // It clones the state under a lock and performs I/O without blocking other callers. +// +// if err := rl.Persist(); err != nil { /* handle error */ } func (rl *RateLimiter) Persist() error { rl.mu.Lock() quotas := maps.Clone(rl.Quotas) @@ -328,6 +365,9 @@ func (rl *RateLimiter) prune(model string) { // BackgroundPrune starts a goroutine that periodically prunes all model states. // It returns a function to stop the pruner. +// +// stop := rl.BackgroundPrune(30 * time.Second) +// defer stop() func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { if interval <= 0 { return func() {} @@ -354,6 +394,8 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { } // CanSend checks if a request can be sent without violating limits. +// +// ok := rl.CanSend("gemini-3-pro-preview", 1200) func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { if estimatedTokens < 0 { return false @@ -401,6 +443,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { } // RecordUsage records a successful API call. +// +// rl.RecordUsage("gemini-3-pro-preview", 900, 300) func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) { rl.mu.Lock() defer rl.mu.Unlock() @@ -420,6 +464,8 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) } // WaitForCapacity blocks until capacity is available or context is cancelled. +// +// err := rl.WaitForCapacity(ctx, "gemini-3-pro-preview", 1200) func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error { if tokens < 0 { return core.E("ratelimit.WaitForCapacity", "negative tokens", nil) @@ -443,6 +489,8 @@ func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens } // Reset clears stats for a model (or all if model is empty). +// +// rl.Reset("gemini-3-pro-preview") func (rl *RateLimiter) Reset(model string) { rl.mu.Lock() defer rl.mu.Unlock() @@ -455,17 +503,28 @@ func (rl *RateLimiter) Reset(model string) { } // ModelStats represents a snapshot of usage. +// +// stats := rl.Stats("gemini-3-pro-preview") type ModelStats struct { - RPM int - MaxRPM int - TPM int - MaxTPM int - RPD int - MaxRPD int + // RPM is the current requests-per-minute usage in the sliding window. + RPM int + // MaxRPM is the configured requests-per-minute limit. + MaxRPM int + // TPM is the current tokens-per-minute usage in the sliding window. + TPM int + // MaxTPM is the configured tokens-per-minute limit. + MaxTPM int + // RPD is the current requests-per-day usage in the rolling 24-hour window. + RPD int + // MaxRPD is the configured requests-per-day limit. + MaxRPD int + // DayStart is the start of the current rolling 24-hour window. DayStart time.Time } // Models returns a sorted iterator over all model names tracked by the limiter. +// +// for model := range rl.Models() { println(model) } func (rl *RateLimiter) Models() iter.Seq[string] { rl.mu.RLock() defer rl.mu.RUnlock() @@ -482,6 +541,8 @@ func (rl *RateLimiter) Models() iter.Seq[string] { } // Iter returns a sorted iterator over all model names and their current stats. +// +// for model, stats := range rl.Iter() { _ = stats; println(model) } func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] { return func(yield func(string, ModelStats) bool) { stats := rl.AllStats() @@ -494,6 +555,8 @@ func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] { } // Stats returns current stats for a model. +// +// stats := rl.Stats("gemini-3-pro-preview") func (rl *RateLimiter) Stats(model string) ModelStats { rl.mu.Lock() defer rl.mu.Unlock() @@ -521,6 +584,8 @@ func (rl *RateLimiter) Stats(model string) ModelStats { } // AllStats returns stats for all tracked models. +// +// all := rl.AllStats() func (rl *RateLimiter) AllStats() map[string]ModelStats { rl.mu.Lock() defer rl.mu.Unlock() @@ -559,6 +624,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats { // NewWithSQLite creates a SQLite-backed RateLimiter with Gemini defaults. // The database is created at dbPath if it does not exist. Use Close() to // release the database connection when finished. +// +// rl, err := NewWithSQLite("/tmp/ratelimits.db") func NewWithSQLite(dbPath string) (*RateLimiter, error) { return NewWithSQLiteConfig(dbPath, Config{ Providers: []Provider{ProviderGemini}, @@ -568,6 +635,8 @@ func NewWithSQLite(dbPath string) (*RateLimiter, error) { // NewWithSQLiteConfig creates a SQLite-backed RateLimiter with custom config. // The Backend field in cfg is ignored (always "sqlite"). Use Close() to // release the database connection when finished. +// +// rl, err := NewWithSQLiteConfig("/tmp/ratelimits.db", Config{Providers: []Provider{ProviderOpenAI}}) func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) { store, err := newSQLiteStore(dbPath) if err != nil { @@ -582,6 +651,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) { // Close releases resources held by the RateLimiter. For YAML-backed // limiters this is a no-op. For SQLite-backed limiters it closes the // database connection. +// +// defer rl.Close() func (rl *RateLimiter) Close() error { if rl.sqlite != nil { return rl.sqlite.close() @@ -592,6 +663,8 @@ func (rl *RateLimiter) Close() error { // MigrateYAMLToSQLite reads state from a YAML file and writes it to a new // SQLite database. Both quotas and usage state are migrated. The SQLite // database is created if it does not exist. +// +// err := MigrateYAMLToSQLite("ratelimits.yaml", "ratelimits.db") func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error { // Load from YAML. content, err := readLocalFile(yamlPath) @@ -618,6 +691,8 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error { } // CountTokens calls the Google API to count tokens for a prompt. +// +// tokens, err := CountTokens(ctx, apiKey, "gemini-3-pro-preview", prompt) func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) { return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text) } @@ -722,9 +797,9 @@ func normaliseBackend(backend string) (string, error) { } func defaultStatePath(backend string) (string, error) { - home, err := os.UserHomeDir() - if err != nil { - return "", err + home := currentHomeDir() + if home == "" { + return "", core.E("ratelimit.defaultStatePath", "home dir unavailable", nil) } fileName := defaultYAMLStateFile @@ -735,6 +810,15 @@ func defaultStatePath(backend string) (string, error) { return core.Path(home, defaultStateDirName, fileName), nil } +func currentHomeDir() string { + for _, key := range []string{"CORE_HOME", "HOME", "home", "USERPROFILE"} { + if value := core.Trim(core.Env(key)); value != "" { + return value + } + } + return "" +} + func safeTokenSum(a, b int) int { return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}}) } @@ -762,7 +846,7 @@ func safeTokenTotal(tokens []TokenEntry) int { func countTokensURL(baseURL, model string) (string, error) { if core.Trim(model) == "" { - return "", core.NewError("empty model") + return "", core.E("ratelimit.countTokensURL", "empty model", nil) } parsed, err := url.Parse(baseURL) @@ -770,7 +854,7 @@ func countTokensURL(baseURL, model string) (string, error) { return "", err } if parsed.Scheme == "" || parsed.Host == "" { - return "", core.NewError("invalid base url") + return "", core.E("ratelimit.countTokensURL", "invalid base url", nil) } return core.Concat(core.TrimSuffix(parsed.String(), "/"), "/v1beta/models/", url.PathEscape(model), ":countTokens"), nil @@ -803,7 +887,7 @@ func readLocalFile(path string) (string, error) { content, ok := result.Value.(string) if !ok { - return "", core.NewError("read returned non-string") + return "", core.E("ratelimit.readLocalFile", "read returned non-string", nil) } return content, nil } @@ -828,5 +912,5 @@ func resultError(result core.Result) error { if result.Value == nil { return nil } - return core.NewError(core.Sprint(result.Value)) + return core.E("ratelimit.resultError", core.Sprint(result.Value), nil) } diff --git a/ratelimit_test.go b/ratelimit_test.go index 73aac70..7c2471b 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -2,28 +2,92 @@ package ratelimit import ( "context" - "encoding/json" - "fmt" "io" "net/http" "net/http/httptest" - "os" - "path/filepath" - "strings" "sync" + "syscall" "testing" "time" + core "dappco.re/go/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func testPath(parts ...string) string { + return core.Path(parts...) +} + +func pathExists(path string) bool { + var fs core.Fs + return fs.Exists(path) +} + +func writeTestFile(tb testing.TB, path, content string) { + tb.Helper() + require.NoError(tb, writeLocalFile(path, content)) +} + +func ensureTestDir(tb testing.TB, path string) { + tb.Helper() + require.NoError(tb, ensureDir(path)) +} + +func setPathMode(tb testing.TB, path string, mode uint32) { + tb.Helper() + require.NoError(tb, syscall.Chmod(path, mode)) +} + +func overwriteTestFile(tb testing.TB, path, content string) { + tb.Helper() + + var fs core.Fs + writer := fs.Create(path) + require.NoError(tb, resultError(writer)) + require.NoError(tb, resultError(core.WriteAll(writer.Value, content))) +} + +func isRootUser() bool { + return syscall.Geteuid() == 0 +} + +func repeatString(part string, count int) string { + builder := core.NewBuilder() + for i := 0; i < count; i++ { + builder.WriteString(part) + } + return builder.String() +} + +func substringCount(s, substr string) int { + if substr == "" { + return 0 + } + return len(core.Split(s, substr)) - 1 +} + +func decodeJSONBody(tb testing.TB, r io.Reader, target any) { + tb.Helper() + + data, err := io.ReadAll(r) + require.NoError(tb, err) + require.NoError(tb, resultError(core.JSONUnmarshal(data, target))) +} + +func writeJSONBody(tb testing.TB, w io.Writer, value any) { + tb.Helper() + + _, err := io.WriteString(w, core.JSONMarshalString(value)) + require.NoError(tb, err) +} + // newTestLimiter returns a RateLimiter with file path set to a temp directory. func newTestLimiter(t *testing.T) *RateLimiter { t.Helper() rl, err := New() require.NoError(t, err) - rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml") + rl.filePath = testPath(t.TempDir(), "ratelimits.yaml") return rl } @@ -41,7 +105,7 @@ func (errReader) Read([]byte) (int, error) { // --- Phase 0: CanSend boundary conditions --- -func TestCanSend(t *testing.T) { +func TestRatelimit_CanSend_Good(t *testing.T) { t.Run("fresh state allows send", func(t *testing.T) { rl := newTestLimiter(t) model := "test-model" @@ -189,7 +253,7 @@ func TestCanSend(t *testing.T) { // --- Phase 0: Sliding window / prune tests --- -func TestPrune(t *testing.T) { +func TestRatelimit_Prune_Good(t *testing.T) { t.Run("removes old entries", func(t *testing.T) { rl := newTestLimiter(t) model := "test-prune" @@ -304,7 +368,7 @@ func TestPrune(t *testing.T) { // --- Phase 0: RecordUsage --- -func TestRecordUsage(t *testing.T) { +func TestRatelimit_RecordUsage_Good(t *testing.T) { t.Run("records into fresh state", func(t *testing.T) { rl := newTestLimiter(t) model := "record-fresh" @@ -375,7 +439,7 @@ func TestRecordUsage(t *testing.T) { // --- Phase 0: Reset --- -func TestReset(t *testing.T) { +func TestRatelimit_Reset_Good(t *testing.T) { t.Run("reset single model", func(t *testing.T) { rl := newTestLimiter(t) rl.RecordUsage("model-a", 10, 10) @@ -409,7 +473,7 @@ func TestReset(t *testing.T) { // --- Phase 0: WaitForCapacity --- -func TestWaitForCapacity(t *testing.T) { +func TestRatelimit_WaitForCapacity_Good(t *testing.T) { t.Run("context cancelled returns error", func(t *testing.T) { rl := newTestLimiter(t) model := "wait-cancel" @@ -467,7 +531,7 @@ func TestWaitForCapacity(t *testing.T) { }) } -func TestNilUsageStats(t *testing.T) { +func TestRatelimit_NilUsageStats_Ugly(t *testing.T) { t.Run("CanSend replaces nil state without panicking", func(t *testing.T) { rl := newTestLimiter(t) model := "nil-cansend" @@ -514,7 +578,7 @@ func TestNilUsageStats(t *testing.T) { // --- Phase 0: Stats --- -func TestStats(t *testing.T) { +func TestRatelimit_Stats_Good(t *testing.T) { t.Run("returns stats for known model with usage", func(t *testing.T) { rl := newTestLimiter(t) model := "stats-test" @@ -554,7 +618,7 @@ func TestStats(t *testing.T) { // --- Phase 0: AllStats --- -func TestAllStats(t *testing.T) { +func TestRatelimit_AllStats_Good(t *testing.T) { t.Run("includes all default quotas plus state-only models", func(t *testing.T) { rl := newTestLimiter(t) rl.RecordUsage("gemini-3-pro-preview", 1000, 500) @@ -612,10 +676,10 @@ func TestAllStats(t *testing.T) { // --- Phase 0: Persist and Load --- -func TestPersistAndLoad(t *testing.T) { +func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) { t.Run("round-trip preserves state", func(t *testing.T) { tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "ratelimits.yaml") + path := testPath(tmpDir, "ratelimits.yaml") rl1, err := New() require.NoError(t, err) @@ -638,7 +702,7 @@ func TestPersistAndLoad(t *testing.T) { t.Run("load from non-existent file is not an error", func(t *testing.T) { rl := newTestLimiter(t) - rl.filePath = filepath.Join(t.TempDir(), "does-not-exist.yaml") + rl.filePath = testPath(t.TempDir(), "does-not-exist.yaml") err := rl.Load() assert.NoError(t, err, "loading non-existent file should not error") @@ -646,8 +710,8 @@ func TestPersistAndLoad(t *testing.T) { t.Run("load from corrupt YAML returns error", func(t *testing.T) { tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "corrupt.yaml") - require.NoError(t, os.WriteFile(path, []byte("{{{{invalid yaml!!!!"), 0644)) + path := testPath(tmpDir, "corrupt.yaml") + writeTestFile(t, path, "{{{{invalid yaml!!!!") rl := newTestLimiter(t) rl.filePath = path @@ -657,13 +721,13 @@ func TestPersistAndLoad(t *testing.T) { }) t.Run("load from unreadable file returns error", func(t *testing.T) { - if os.Getuid() == 0 { + if isRootUser() { t.Skip("chmod 000 does not restrict root") } tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "unreadable.yaml") - require.NoError(t, os.WriteFile(path, []byte("quotas: {}"), 0644)) - require.NoError(t, os.Chmod(path, 0000)) + path := testPath(tmpDir, "unreadable.yaml") + writeTestFile(t, path, "quotas: {}") + setPathMode(t, path, 0o000) rl := newTestLimiter(t) rl.filePath = path @@ -672,12 +736,12 @@ func TestPersistAndLoad(t *testing.T) { assert.Error(t, err, "unreadable file should produce an error") // Clean up permissions for temp dir cleanup - _ = os.Chmod(path, 0644) + _ = syscall.Chmod(path, 0o644) }) t.Run("persist to nested non-existent directory creates it", func(t *testing.T) { tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "nested", "deep", "ratelimits.yaml") + path := testPath(tmpDir, "nested", "deep", "ratelimits.yaml") rl := newTestLimiter(t) rl.filePath = path @@ -686,32 +750,32 @@ func TestPersistAndLoad(t *testing.T) { err := rl.Persist() assert.NoError(t, err, "should create nested directories") - _, statErr := os.Stat(path) - assert.NoError(t, statErr, "file should exist") + assert.True(t, pathExists(path), "file should exist") }) t.Run("persist to unwritable directory returns error", func(t *testing.T) { - if os.Getuid() == 0 { + if isRootUser() { t.Skip("chmod 0555 does not restrict root") } tmpDir := t.TempDir() - unwritable := filepath.Join(tmpDir, "readonly") - require.NoError(t, os.MkdirAll(unwritable, 0555)) + unwritable := testPath(tmpDir, "readonly") + ensureTestDir(t, unwritable) + setPathMode(t, unwritable, 0o555) rl := newTestLimiter(t) - rl.filePath = filepath.Join(unwritable, "sub", "ratelimits.yaml") + rl.filePath = testPath(unwritable, "sub", "ratelimits.yaml") err := rl.Persist() assert.Error(t, err, "should fail when directory is unwritable") // Clean up - _ = os.Chmod(unwritable, 0755) + _ = syscall.Chmod(unwritable, 0o755) }) } // --- Phase 0: Default quotas --- -func TestDefaultQuotas(t *testing.T) { +func TestRatelimit_DefaultQuotas_Good(t *testing.T) { rl := newTestLimiter(t) tests := []struct { @@ -740,7 +804,7 @@ func TestDefaultQuotas(t *testing.T) { // --- Phase 0: Concurrent access (race test) --- -func TestConcurrentAccess(t *testing.T) { +func TestRatelimit_ConcurrentAccess_Good(t *testing.T) { rl := newTestLimiter(t) model := "concurrent-test" rl.Quotas[model] = ModelQuota{MaxRPM: 1000, MaxTPM: 10000000, MaxRPD: 10000} @@ -766,7 +830,7 @@ func TestConcurrentAccess(t *testing.T) { assert.Equal(t, expected, stats.RPD, "all recordings should be counted") } -func TestConcurrentResetAndRecord(t *testing.T) { +func TestRatelimit_ConcurrentResetAndRecord_Ugly(t *testing.T) { rl := newTestLimiter(t) model := "concurrent-reset" rl.Quotas[model] = ModelQuota{MaxRPM: 10000, MaxTPM: 100000000, MaxRPD: 100000} @@ -804,7 +868,7 @@ func TestConcurrentResetAndRecord(t *testing.T) { // No assertion needed -- if we get here without -race flagging, mutex is sound } -func TestBackgroundPrune(t *testing.T) { +func TestRatelimit_BackgroundPrune_Good(t *testing.T) { rl := newTestLimiter(t) model := "prune-me" rl.Quotas[model] = ModelQuota{MaxRPM: 100} @@ -843,7 +907,7 @@ func TestBackgroundPrune(t *testing.T) { // --- Phase 0: CountTokens (with mock HTTP server) --- -func TestCountTokens(t *testing.T) { +func TestRatelimit_CountTokens_Ugly(t *testing.T) { t.Run("successful token count", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, http.MethodPost, r.Method) @@ -858,13 +922,13 @@ func TestCountTokens(t *testing.T) { } `json:"parts"` } `json:"contents"` } - require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + decodeJSONBody(t, r.Body, &body) require.Len(t, body.Contents, 1) require.Len(t, body.Contents[0].Parts, 1) assert.Equal(t, "hello", body.Contents[0].Parts[0].Text) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42})) + writeJSONBody(t, w, map[string]int{"totalTokens": 42}) })) defer server.Close() @@ -878,7 +942,7 @@ func TestCountTokens(t *testing.T) { assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath()) assert.Empty(t, r.URL.RawQuery) w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7})) + writeJSONBody(t, w, map[string]int{"totalTokens": 7}) })) defer server.Close() @@ -888,10 +952,10 @@ func TestCountTokens(t *testing.T) { }) t.Run("API error body is truncated", func(t *testing.T) { - largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256) + largeBody := repeatString("x", countTokensErrorBodyLimit+256) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) - _, err := fmt.Fprint(w, largeBody) + _, err := io.WriteString(w, largeBody) require.NoError(t, err) })) defer server.Close() @@ -899,7 +963,7 @@ func TestCountTokens(t *testing.T) { _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello") require.Error(t, err) assert.Contains(t, err.Error(), "api error status 401") - assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded") + assert.True(t, substringCount(err.Error(), "x") < len(largeBody), "error body should be bounded") assert.Contains(t, err.Error(), "...") }) @@ -953,7 +1017,7 @@ func TestCountTokens(t *testing.T) { t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11})) + writeJSONBody(t, w, map[string]int{"totalTokens": 11}) })) defer server.Close() @@ -969,8 +1033,8 @@ func TestCountTokens(t *testing.T) { }) } -func TestPersistSkipsNilState(t *testing.T) { - path := filepath.Join(t.TempDir(), "nil-state.yaml") +func TestRatelimit_PersistSkipsNilState_Good(t *testing.T) { + path := testPath(t.TempDir(), "nil-state.yaml") rl, err := New() require.NoError(t, err) @@ -986,7 +1050,7 @@ func TestPersistSkipsNilState(t *testing.T) { assert.NotContains(t, rl2.State, "nil-model") } -func TestTokenTotals(t *testing.T) { +func TestRatelimit_TokenTotals_Good(t *testing.T) { maxInt := int(^uint(0) >> 1) assert.Equal(t, 25, safeTokenSum(-100, 25)) @@ -1056,7 +1120,7 @@ func BenchmarkCanSendConcurrent(b *testing.B) { // --- Phase 1: Provider profiles and NewWithConfig --- -func TestDefaultProfiles(t *testing.T) { +func TestRatelimit_DefaultProfiles_Good(t *testing.T) { profiles := DefaultProfiles() t.Run("contains all four providers", func(t *testing.T) { @@ -1097,10 +1161,10 @@ func TestDefaultProfiles(t *testing.T) { }) } -func TestNewWithConfig(t *testing.T) { +func TestRatelimit_NewWithConfig_Ugly(t *testing.T) { t.Run("empty config defaults to Gemini", func(t *testing.T) { rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "test.yaml"), + FilePath: testPath(t.TempDir(), "test.yaml"), }) require.NoError(t, err) @@ -1110,7 +1174,7 @@ func TestNewWithConfig(t *testing.T) { t.Run("single provider loads only its models", func(t *testing.T) { rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "test.yaml"), + FilePath: testPath(t.TempDir(), "test.yaml"), Providers: []Provider{ProviderOpenAI}, }) require.NoError(t, err) @@ -1124,7 +1188,7 @@ func TestNewWithConfig(t *testing.T) { t.Run("multiple providers merge models", func(t *testing.T) { rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "test.yaml"), + FilePath: testPath(t.TempDir(), "test.yaml"), Providers: []Provider{ProviderGemini, ProviderAnthropic}, }) require.NoError(t, err) @@ -1140,7 +1204,7 @@ func TestNewWithConfig(t *testing.T) { t.Run("explicit quotas override provider defaults", func(t *testing.T) { rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "test.yaml"), + FilePath: testPath(t.TempDir(), "test.yaml"), Providers: []Provider{ProviderGemini}, Quotas: map[string]ModelQuota{ "gemini-3-pro-preview": {MaxRPM: 999, MaxTPM: 888, MaxRPD: 777}, @@ -1156,7 +1220,7 @@ func TestNewWithConfig(t *testing.T) { t.Run("explicit quotas without providers", func(t *testing.T) { rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "test.yaml"), + FilePath: testPath(t.TempDir(), "test.yaml"), Quotas: map[string]ModelQuota{ "my-custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, }, @@ -1169,7 +1233,7 @@ func TestNewWithConfig(t *testing.T) { }) t.Run("custom file path is respected", func(t *testing.T) { - customPath := filepath.Join(t.TempDir(), "custom", "limits.yaml") + customPath := testPath(t.TempDir(), "custom", "limits.yaml") rl, err := NewWithConfig(Config{ FilePath: customPath, Providers: []Provider{ProviderLocal}, @@ -1179,13 +1243,12 @@ func TestNewWithConfig(t *testing.T) { rl.RecordUsage("test", 1, 1) require.NoError(t, rl.Persist()) - _, statErr := os.Stat(customPath) - assert.NoError(t, statErr, "file should be created at custom path") + assert.True(t, pathExists(customPath), "file should be created at custom path") }) t.Run("unknown provider is silently skipped", func(t *testing.T) { rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "test.yaml"), + FilePath: testPath(t.TempDir(), "test.yaml"), Providers: []Provider{"nonexistent-provider"}, }) require.NoError(t, err) @@ -1194,7 +1257,7 @@ func TestNewWithConfig(t *testing.T) { t.Run("local provider with custom quotas", func(t *testing.T) { rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "test.yaml"), + FilePath: testPath(t.TempDir(), "test.yaml"), Providers: []Provider{ProviderLocal}, Quotas: map[string]ModelQuota{ "llama-3.3-70b": {MaxRPM: 5, MaxTPM: 50000, MaxRPD: 0}, @@ -1224,11 +1287,11 @@ func TestNewWithConfig(t *testing.T) { rl, err := NewWithConfig(Config{}) require.NoError(t, err) - assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath) + assert.Equal(t, testPath(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath) }) } -func TestNewBackwardCompatibility(t *testing.T) { +func TestRatelimit_NewBackwardCompatibility_Good(t *testing.T) { // New() should produce the exact same result as before Phase 1 rl, err := New() require.NoError(t, err) @@ -1251,7 +1314,7 @@ func TestNewBackwardCompatibility(t *testing.T) { } } -func TestSetQuota(t *testing.T) { +func TestRatelimit_SetQuota_Good(t *testing.T) { t.Run("adds new model quota", func(t *testing.T) { rl := newTestLimiter(t) rl.SetQuota("custom-model", ModelQuota{MaxRPM: 42, MaxTPM: 9999, MaxRPD: 100}) @@ -1279,7 +1342,7 @@ func TestSetQuota(t *testing.T) { wg.Add(1) go func(n int) { defer wg.Done() - model := fmt.Sprintf("model-%d", n) + model := core.Sprintf("model-%d", n) rl.SetQuota(model, ModelQuota{MaxRPM: n, MaxTPM: n * 100, MaxRPD: n * 10}) }(i) } @@ -1289,7 +1352,7 @@ func TestSetQuota(t *testing.T) { }) } -func TestAddProvider(t *testing.T) { +func TestRatelimit_AddProvider_Good(t *testing.T) { t.Run("adds OpenAI models to existing limiter", func(t *testing.T) { rl := newTestLimiter(t) // starts with Gemini defaults geminiCount := len(rl.Quotas) @@ -1351,7 +1414,7 @@ func TestAddProvider(t *testing.T) { }) } -func TestProviderConstants(t *testing.T) { +func TestRatelimit_ProviderConstants_Good(t *testing.T) { // Verify the string values are stable (they may be used in YAML configs) assert.Equal(t, Provider("gemini"), ProviderGemini) assert.Equal(t, Provider("openai"), ProviderOpenAI) @@ -1361,7 +1424,7 @@ func TestProviderConstants(t *testing.T) { // --- Phase 0 addendum: Additional concurrent and multi-model race tests --- -func TestConcurrentMultipleModels(t *testing.T) { +func TestRatelimit_ConcurrentMultipleModels_Good(t *testing.T) { rl := newTestLimiter(t) models := []string{"model-a", "model-b", "model-c", "model-d", "model-e"} for _, m := range models { @@ -1391,9 +1454,9 @@ func TestConcurrentMultipleModels(t *testing.T) { } } -func TestConcurrentPersistAndLoad(t *testing.T) { +func TestRatelimit_ConcurrentPersistAndLoad_Ugly(t *testing.T) { tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "concurrent.yaml") + path := testPath(tmpDir, "concurrent.yaml") rl := newTestLimiter(t) rl.filePath = path @@ -1425,7 +1488,7 @@ func TestConcurrentPersistAndLoad(t *testing.T) { // No panics or data races = pass } -func TestConcurrentAllStatsAndRecordUsage(t *testing.T) { +func TestRatelimit_ConcurrentAllStatsAndRecordUsage_Good(t *testing.T) { rl := newTestLimiter(t) models := []string{"stats-a", "stats-b", "stats-c"} for _, m := range models { @@ -1456,7 +1519,7 @@ func TestConcurrentAllStatsAndRecordUsage(t *testing.T) { wg.Wait() } -func TestConcurrentWaitForCapacityAndRecordUsage(t *testing.T) { +func TestRatelimit_ConcurrentWaitForCapacityAndRecordUsage_Good(t *testing.T) { rl := newTestLimiter(t) model := "race-wait" rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 10000000, MaxRPD: 10000} @@ -1553,7 +1616,7 @@ func BenchmarkAllStats(b *testing.B) { func BenchmarkPersist(b *testing.B) { tmpDir := b.TempDir() - path := filepath.Join(tmpDir, "bench.yaml") + path := testPath(tmpDir, "bench.yaml") rl, _ := New() rl.filePath = path @@ -1574,10 +1637,10 @@ func BenchmarkPersist(b *testing.B) { } } -func TestEndToEndMultiProvider(t *testing.T) { +func TestRatelimit_EndToEndMultiProvider_Good(t *testing.T) { // Simulate a real-world scenario: limiter for both Gemini and Anthropic rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "multi.yaml"), + FilePath: testPath(t.TempDir(), "multi.yaml"), Providers: []Provider{ProviderGemini, ProviderAnthropic}, }) require.NoError(t, err) diff --git a/sqlite_test.go b/sqlite_test.go index 7026261..90dece9 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1,8 +1,6 @@ package ratelimit import ( - "os" - "path/filepath" "sync" "testing" "time" @@ -14,18 +12,17 @@ import ( // --- Phase 2: SQLite basic tests --- -func TestNewSQLiteStore_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "test.db") +func TestSQLite_NewSQLiteStore_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "test.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() // Verify the database file was created. - _, statErr := os.Stat(dbPath) - assert.NoError(t, statErr, "database file should exist") + assert.True(t, pathExists(dbPath), "database file should exist") } -func TestNewSQLiteStore_Bad(t *testing.T) { +func TestSQLite_NewSQLiteStore_Bad(t *testing.T) { t.Run("invalid path returns error", func(t *testing.T) { // Path inside a non-existent directory with no parent. _, err := newSQLiteStore("/nonexistent/deep/nested/dir/test.db") @@ -33,8 +30,8 @@ func TestNewSQLiteStore_Bad(t *testing.T) { }) } -func TestSQLiteQuotasRoundTrip_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "quotas.db") +func TestSQLite_QuotasRoundTrip_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "quotas.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -60,8 +57,8 @@ func TestSQLiteQuotasRoundTrip_Good(t *testing.T) { } } -func TestSQLiteQuotasUpsert_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "upsert.db") +func TestSQLite_QuotasUpsert_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "upsert.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -85,8 +82,8 @@ func TestSQLiteQuotasUpsert_Good(t *testing.T) { assert.Equal(t, 777, q.MaxRPD, "should have updated RPD") } -func TestSQLiteStateRoundTrip_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "state.db") +func TestSQLite_StateRoundTrip_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "state.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -144,8 +141,8 @@ func TestSQLiteStateRoundTrip_Good(t *testing.T) { } } -func TestSQLiteStateOverwrite_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "overwrite.db") +func TestSQLite_StateOverwrite_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "overwrite.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -182,8 +179,8 @@ func TestSQLiteStateOverwrite_Good(t *testing.T) { assert.Len(t, b.Requests, 1) } -func TestSQLiteEmptyState_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "empty.db") +func TestSQLite_EmptyState_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "empty.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -198,8 +195,8 @@ func TestSQLiteEmptyState_Good(t *testing.T) { assert.Empty(t, state, "should return empty state from fresh DB") } -func TestSQLiteClose_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "close.db") +func TestSQLite_Close_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "close.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) @@ -208,8 +205,8 @@ func TestSQLiteClose_Good(t *testing.T) { // --- Phase 2: SQLite integration tests --- -func TestNewWithSQLite_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "limiter.db") +func TestSQLite_NewWithSQLite_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "limiter.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) defer rl.Close() @@ -222,8 +219,8 @@ func TestNewWithSQLite_Good(t *testing.T) { assert.NotNil(t, rl.sqlite, "SQLite store should be initialised") } -func TestNewWithSQLiteConfig_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "config.db") +func TestSQLite_NewWithSQLiteConfig_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "config.db") rl, err := NewWithSQLiteConfig(dbPath, Config{ Providers: []Provider{ProviderAnthropic}, Quotas: map[string]ModelQuota{ @@ -243,8 +240,8 @@ func TestNewWithSQLiteConfig_Good(t *testing.T) { assert.False(t, hasGemini, "should not have Gemini models") } -func TestSQLitePersistAndLoad_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "persist.db") +func TestSQLite_PersistAndLoad_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "persist.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) @@ -272,8 +269,8 @@ func TestSQLitePersistAndLoad_Good(t *testing.T) { assert.Equal(t, 500, stats.MaxRPD) } -func TestSQLitePersistMultipleModels_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "multi.db") +func TestSQLite_PersistMultipleModels_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "multi.db") rl, err := NewWithSQLiteConfig(dbPath, Config{ Providers: []Provider{ProviderGemini, ProviderAnthropic}, }) @@ -302,8 +299,8 @@ func TestSQLitePersistMultipleModels_Good(t *testing.T) { assert.Equal(t, 400, claude.TPM) } -func TestSQLiteRecordUsageThenPersistReload_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "record.db") +func TestSQLite_RecordUsageThenPersistReload_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "record.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) @@ -340,7 +337,7 @@ func TestSQLiteRecordUsageThenPersistReload_Good(t *testing.T) { assert.Equal(t, 1000, stats2.TPM, "TPM should survive reload") } -func TestSQLiteClose_Good_NoOp(t *testing.T) { +func TestSQLite_CloseNoOp_Good(t *testing.T) { // Close on YAML-backed limiter is a no-op. rl := newTestLimiter(t) assert.NoError(t, rl.Close(), "Close on YAML limiter should be no-op") @@ -348,8 +345,8 @@ func TestSQLiteClose_Good_NoOp(t *testing.T) { // --- Phase 2: Concurrent SQLite --- -func TestSQLiteConcurrent_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "concurrent.db") +func TestSQLite_Concurrent_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "concurrent.db") rl, err := NewWithSQLite(dbPath) require.NoError(t, err) defer rl.Close() @@ -398,10 +395,10 @@ func TestSQLiteConcurrent_Good(t *testing.T) { // --- Phase 2: YAML backward compatibility --- -func TestYAMLBackwardCompat_Good(t *testing.T) { +func TestSQLite_YAMLBackwardCompat_Good(t *testing.T) { // Verify that the default YAML backend still works after SQLite additions. tmpDir := t.TempDir() - path := filepath.Join(tmpDir, "compat.yaml") + path := testPath(tmpDir, "compat.yaml") rl1, err := New() require.NoError(t, err) @@ -425,18 +422,18 @@ func TestYAMLBackwardCompat_Good(t *testing.T) { assert.Equal(t, 200, stats.TPM) } -func TestConfigBackendDefault_Good(t *testing.T) { +func TestSQLite_ConfigBackendDefault_Good(t *testing.T) { // Empty Backend string should default to YAML behaviour. rl, err := NewWithConfig(Config{ - FilePath: filepath.Join(t.TempDir(), "default.yaml"), + FilePath: testPath(t.TempDir(), "default.yaml"), }) require.NoError(t, err) assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)") } -func TestConfigBackendSQLite_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "config-backend.db") +func TestSQLite_ConfigBackendSQLite_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "config-backend.db") rl, err := NewWithConfig(Config{ Backend: backendSQLite, FilePath: dbPath, @@ -451,11 +448,10 @@ func TestConfigBackendSQLite_Good(t *testing.T) { rl.RecordUsage("backend-model", 10, 10) require.NoError(t, rl.Persist()) - _, statErr := os.Stat(dbPath) - assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path") + assert.True(t, pathExists(dbPath), "sqlite backend should persist to the configured DB path") } -func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) { +func TestSQLite_ConfigBackendSQLiteDefaultPath_Good(t *testing.T) { home := t.TempDir() t.Setenv("HOME", home) t.Setenv("USERPROFILE", "") @@ -470,16 +466,15 @@ func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) { require.NotNil(t, rl.sqlite) require.NoError(t, rl.Persist()) - _, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile)) - assert.NoError(t, statErr, "sqlite backend should use the default home DB path") + assert.True(t, pathExists(testPath(home, defaultStateDirName, defaultSQLiteStateFile)), "sqlite backend should use the default home DB path") } // --- Phase 2: MigrateYAMLToSQLite --- -func TestMigrateYAMLToSQLite_Good(t *testing.T) { +func TestSQLite_MigrateYAMLToSQLite_Good(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "state.yaml") - sqlitePath := filepath.Join(tmpDir, "migrated.db") + yamlPath := testPath(tmpDir, "state.yaml") + sqlitePath := testPath(tmpDir, "migrated.db") // Create a YAML-backed limiter with state. rl, err := New() @@ -515,26 +510,26 @@ func TestMigrateYAMLToSQLite_Good(t *testing.T) { assert.Equal(t, 2, stats.RPD, "should have 2 daily requests") } -func TestMigrateYAMLToSQLite_Bad(t *testing.T) { +func TestSQLite_MigrateYAMLToSQLite_Bad(t *testing.T) { t.Run("non-existent YAML file", func(t *testing.T) { - err := MigrateYAMLToSQLite("/nonexistent/state.yaml", filepath.Join(t.TempDir(), "out.db")) + err := MigrateYAMLToSQLite("/nonexistent/state.yaml", testPath(t.TempDir(), "out.db")) assert.Error(t, err, "should fail with non-existent YAML file") }) t.Run("corrupt YAML file", func(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "corrupt.yaml") - require.NoError(t, os.WriteFile(yamlPath, []byte("{{{{not yaml!"), 0644)) + yamlPath := testPath(tmpDir, "corrupt.yaml") + writeTestFile(t, yamlPath, "{{{{not yaml!") - err := MigrateYAMLToSQLite(yamlPath, filepath.Join(tmpDir, "out.db")) + err := MigrateYAMLToSQLite(yamlPath, testPath(tmpDir, "out.db")) assert.Error(t, err, "should fail with corrupt YAML") }) } -func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) { +func TestSQLite_MigrateYAMLToSQLiteAtomic_Good(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "atomic.yaml") - sqlitePath := filepath.Join(tmpDir, "atomic.db") + yamlPath := testPath(tmpDir, "atomic.yaml") + sqlitePath := testPath(tmpDir, "atomic.db") now := time.Now().UTC() store, err := newSQLiteStore(sqlitePath) @@ -573,7 +568,7 @@ func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) { } data, err := yaml.Marshal(migrated) require.NoError(t, err) - require.NoError(t, os.WriteFile(yamlPath, data, 0o644)) + writeTestFile(t, yamlPath, string(data)) err = MigrateYAMLToSQLite(yamlPath, sqlitePath) require.Error(t, err) @@ -594,10 +589,10 @@ func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) { assert.NotContains(t, state, "new-model") } -func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) { +func TestSQLite_MigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "full.yaml") - sqlitePath := filepath.Join(tmpDir, "full.db") + yamlPath := testPath(tmpDir, "full.yaml") + sqlitePath := testPath(tmpDir, "full.db") // Create a full YAML state with all Gemini models. rl, err := New() @@ -626,12 +621,12 @@ func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) { // --- Phase 2: Corrupt DB recovery --- -func TestSQLiteCorruptDB_Ugly(t *testing.T) { +func TestSQLite_CorruptDB_Ugly(t *testing.T) { tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "corrupt.db") + dbPath := testPath(tmpDir, "corrupt.db") // Write garbage to the DB file. - require.NoError(t, os.WriteFile(dbPath, []byte("THIS IS NOT A SQLITE DATABASE"), 0644)) + writeTestFile(t, dbPath, "THIS IS NOT A SQLITE DATABASE") // Opening a corrupt DB may succeed (sqlite is lazy about validation), // but operations on it should fail gracefully. @@ -648,9 +643,9 @@ func TestSQLiteCorruptDB_Ugly(t *testing.T) { assert.Error(t, err, "loading from corrupt DB should return an error") } -func TestSQLiteTruncatedDB_Ugly(t *testing.T) { +func TestSQLite_TruncatedDB_Ugly(t *testing.T) { tmpDir := t.TempDir() - dbPath := filepath.Join(tmpDir, "truncated.db") + dbPath := testPath(tmpDir, "truncated.db") // Create a valid DB first. store, err := newSQLiteStore(dbPath) @@ -661,11 +656,7 @@ func TestSQLiteTruncatedDB_Ugly(t *testing.T) { require.NoError(t, store.close()) // Truncate the file to simulate corruption. - f, err := os.OpenFile(dbPath, os.O_WRONLY|os.O_TRUNC, 0644) - require.NoError(t, err) - _, err = f.Write([]byte("TRUNC")) - require.NoError(t, err) - require.NoError(t, f.Close()) + overwriteTestFile(t, dbPath, "TRUNC") // Opening should either fail or operations should fail. store2, err := newSQLiteStore(dbPath) @@ -679,9 +670,9 @@ func TestSQLiteTruncatedDB_Ugly(t *testing.T) { assert.Error(t, err, "loading from truncated DB should return an error") } -func TestSQLiteEmptyModelState_Good(t *testing.T) { +func TestSQLite_EmptyModelState_Good(t *testing.T) { // State with no requests or tokens but with a daily counter. - dbPath := filepath.Join(t.TempDir(), "empty-state.db") + dbPath := testPath(t.TempDir(), "empty-state.db") store, err := newSQLiteStore(dbPath) require.NoError(t, err) defer store.close() @@ -708,8 +699,8 @@ func TestSQLiteEmptyModelState_Good(t *testing.T) { // --- Phase 2: End-to-end with persist cycle --- -func TestSQLiteEndToEnd_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "e2e.db") +func TestSQLite_EndToEnd_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "e2e.db") // Session 1: Create limiter, record usage, persist. rl1, err := NewWithSQLiteConfig(dbPath, Config{ @@ -752,8 +743,8 @@ func TestSQLiteEndToEnd_Good(t *testing.T) { assert.Equal(t, 5, custom.MaxRPM) } -func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "replace.db") +func TestSQLite_LoadReplacesPersistedSnapshot_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "replace.db") rl, err := NewWithSQLiteConfig(dbPath, Config{ Quotas: map[string]ModelQuota{ "model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10}, @@ -788,8 +779,8 @@ func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) { assert.Equal(t, 1, rl2.Stats("model-b").RPD) } -func TestSQLitePersistAtomic_Good(t *testing.T) { - dbPath := filepath.Join(t.TempDir(), "persist-atomic.db") +func TestSQLite_PersistAtomic_Good(t *testing.T) { + dbPath := testPath(t.TempDir(), "persist-atomic.db") rl, err := NewWithSQLiteConfig(dbPath, Config{ Quotas: map[string]ModelQuota{ "old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10}, @@ -827,7 +818,7 @@ func TestSQLitePersistAtomic_Good(t *testing.T) { // --- Phase 2: Benchmark --- func BenchmarkSQLitePersist(b *testing.B) { - dbPath := filepath.Join(b.TempDir(), "bench.db") + dbPath := testPath(b.TempDir(), "bench.db") rl, err := NewWithSQLite(dbPath) if err != nil { b.Fatal(err) @@ -852,7 +843,7 @@ func BenchmarkSQLitePersist(b *testing.B) { } func BenchmarkSQLiteLoad(b *testing.B) { - dbPath := filepath.Join(b.TempDir(), "bench-load.db") + dbPath := testPath(b.TempDir(), "bench-load.db") rl, err := NewWithSQLite(dbPath) if err != nil { b.Fatal(err) @@ -883,10 +874,10 @@ func BenchmarkSQLiteLoad(b *testing.B) { // TestMigrateYAMLToSQLiteWithFullState tests migration of a realistic YAML // file that contains the full serialised RateLimiter struct. -func TestMigrateYAMLToSQLiteWithFullState_Good(t *testing.T) { +func TestSQLite_MigrateYAMLToSQLiteWithFullState_Good(t *testing.T) { tmpDir := t.TempDir() - yamlPath := filepath.Join(tmpDir, "realistic.yaml") - sqlitePath := filepath.Join(tmpDir, "realistic.db") + yamlPath := testPath(tmpDir, "realistic.yaml") + sqlitePath := testPath(tmpDir, "realistic.db") now := time.Now() @@ -919,7 +910,7 @@ func TestMigrateYAMLToSQLiteWithFullState_Good(t *testing.T) { data, err := yaml.Marshal(rl) require.NoError(t, err) - require.NoError(t, os.WriteFile(yamlPath, data, 0644)) + writeTestFile(t, yamlPath, string(data)) // Migrate. require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath))