From 5df6ce127bb32b8591188e24fcf4826b700131a9 Mon Sep 17 00:00:00 2001 From: Virgil Date: Mon, 23 Mar 2026 09:30:15 +0000 Subject: [PATCH] fix(sqlite): preserve defaults before first persist Co-Authored-By: Virgil --- error_test.go | 35 +++++++++++++++ ratelimit.go | 8 ++++ sqlite.go | 47 ++++++++++++++++++++ sqlite_test.go | 115 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+) diff --git a/error_test.go b/error_test.go index 05166b2..c302925 100644 --- a/error_test.go +++ b/error_test.go @@ -65,6 +65,26 @@ func TestPersistYAML(t *testing.T) { }) } +func TestYAMLErrorPaths(t *testing.T) { + t.Run("Load returns error when YAML path is a directory", func(t *testing.T) { + rl := newTestLimiter(t) + rl.filePath = t.TempDir() + + err := rl.Load() + assert.Error(t, err, "Load should fail when the YAML path is a directory") + }) + + t.Run("Persist returns error when YAML path is a directory", func(t *testing.T) { + rl := newTestLimiter(t) + rl.filePath = t.TempDir() + rl.Quotas["test"] = ModelQuota{MaxRPM: 1} + rl.RecordUsage("test", 1, 1) + + err := rl.Persist() + assert.Error(t, err, "Persist should fail when the YAML path is a directory") + }) +} + func TestSQLiteLoadViaLimiter(t *testing.T) { t.Run("Load returns error when SQLite DB is closed", func(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "load-err.db") @@ -142,6 +162,21 @@ func TestNewWithSQLiteErrors(t *testing.T) { }) assert.Error(t, err, "should fail with invalid path") }) + + t.Run("NewWithConfig with sqlite backend fails when default directory cannot be created", func(t *testing.T) { + home := t.TempDir() + blockedHome := filepath.Join(home, "blocked-home") + require.NoError(t, os.WriteFile(blockedHome, []byte("x"), 0o644)) + + t.Setenv("HOME", blockedHome) + t.Setenv("USERPROFILE", "") + t.Setenv("home", "") + + _, err := NewWithConfig(Config{ + Backend: backendSQLite, + }) + assert.Error(t, err, "should fail when the default sqlite directory cannot be created") + }) } func TestSQLiteSaveStateErrors(t *testing.T) { diff --git a/ratelimit.go b/ratelimit.go index 639036c..faf9e53 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -224,6 +224,14 @@ func (rl *RateLimiter) Load() error { // loadSQLite reads quotas and state from the SQLite backend. // Caller must hold the lock. func (rl *RateLimiter) loadSQLite() error { + hasSnapshot, err := rl.sqlite.hasSnapshot() + if err != nil { + return err + } + if !hasSnapshot { + return nil + } + quotas, err := rl.sqlite.loadQuotas() if err != nil { return err diff --git a/sqlite.go b/sqlite.go index 863ede4..275d1bd 100644 --- a/sqlite.go +++ b/sqlite.go @@ -46,6 +46,10 @@ func newSQLiteStore(dbPath string) (*sqliteStore, error) { // createSchema creates the tables and indices if they do not already exist. func createSchema(db *sql.DB) error { stmts := []string{ + `CREATE TABLE IF NOT EXISTS snapshot_meta ( + id INTEGER PRIMARY KEY CHECK (id = 1), + has_snapshot INTEGER NOT NULL DEFAULT 0 + )`, `CREATE TABLE IF NOT EXISTS quotas ( model TEXT PRIMARY KEY, max_rpm INTEGER NOT NULL DEFAULT 0, @@ -75,6 +79,31 @@ func createSchema(db *sql.DB) error { return coreerr.E("ratelimit.createSchema", "exec", err) } } + + if err := initialiseSnapshotMeta(db); err != nil { + return err + } + return nil +} + +func initialiseSnapshotMeta(db *sql.DB) error { + if _, err := db.Exec("INSERT OR IGNORE INTO snapshot_meta (id, has_snapshot) VALUES (1, 0)"); err != nil { + return coreerr.E("ratelimit.createSchema", "init snapshot meta", err) + } + + // Older databases do not have snapshot metadata. If any snapshot table + // already contains rows, treat it as an existing persisted snapshot. + if _, err := db.Exec(`UPDATE snapshot_meta + SET has_snapshot = 1 + WHERE id = 1 AND has_snapshot = 0 AND ( + EXISTS (SELECT 1 FROM quotas) OR + EXISTS (SELECT 1 FROM requests) OR + EXISTS (SELECT 1 FROM tokens) OR + EXISTS (SELECT 1 FROM daily) + )`); err != nil { + return coreerr.E("ratelimit.createSchema", "backfill snapshot meta", err) + } + return nil } @@ -138,6 +167,9 @@ func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[strin if err := insertState(tx, state); err != nil { return err } + if err := markSnapshotPersisted(tx); err != nil { + return err + } return commitTx(tx, "ratelimit.saveSnapshot") } @@ -236,6 +268,13 @@ func insertState(tx *sql.Tx, state map[string]*UsageStats) error { return nil } +func markSnapshotPersisted(tx *sql.Tx) error { + if _, err := tx.Exec("INSERT OR REPLACE INTO snapshot_meta (id, has_snapshot) VALUES (1, 1)"); err != nil { + return coreerr.E("ratelimit.saveSnapshot", "mark snapshot", err) + } + return nil +} + func commitTx(tx *sql.Tx, scope string) error { if err := tx.Commit(); err != nil { return coreerr.E(scope, "commit", err) @@ -243,6 +282,14 @@ func commitTx(tx *sql.Tx, scope string) error { return nil } +func (s *sqliteStore) hasSnapshot() (bool, error) { + var hasSnapshot int + if err := s.db.QueryRow("SELECT has_snapshot FROM snapshot_meta WHERE id = 1").Scan(&hasSnapshot); err != nil { + return false, coreerr.E("ratelimit.hasSnapshot", "query", err) + } + return hasSnapshot != 0, nil +} + // loadState reconstructs the UsageStats map from SQLite tables. func (s *sqliteStore) loadState() (map[string]*UsageStats, error) { result := make(map[string]*UsageStats) diff --git a/sqlite_test.go b/sqlite_test.go index 0122d28..07289ed 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1,6 +1,7 @@ package ratelimit import ( + "database/sql" "os" "path/filepath" "sync" @@ -198,6 +199,35 @@ func TestSQLiteEmptyState_Good(t *testing.T) { assert.Empty(t, state, "should return empty state from fresh DB") } +func TestSQLiteLoadStateSparseRows_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "sparse.db") + store, err := newSQLiteStore(dbPath) + require.NoError(t, err) + defer store.close() + + now := time.Now() + _, err = store.db.Exec("INSERT INTO requests (model, ts) VALUES (?, ?)", "request-only", now.UnixNano()) + require.NoError(t, err) + _, err = store.db.Exec("INSERT INTO tokens (model, ts, count) VALUES (?, ?, ?)", "token-only", now.Add(time.Second).UnixNano(), 123) + require.NoError(t, err) + + state, err := store.loadState() + require.NoError(t, err) + + require.Contains(t, state, "request-only") + assert.Len(t, state["request-only"].Requests, 1) + assert.Empty(t, state["request-only"].Tokens) + assert.True(t, state["request-only"].DayStart.IsZero()) + assert.Equal(t, 0, state["request-only"].DayCount) + + require.Contains(t, state, "token-only") + assert.Empty(t, state["token-only"].Requests) + assert.Len(t, state["token-only"].Tokens, 1) + assert.Equal(t, 123, state["token-only"].Tokens[0].Count) + assert.True(t, state["token-only"].DayStart.IsZero()) + assert.Equal(t, 0, state["token-only"].DayCount) +} + func TestSQLiteClose_Good(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "close.db") store, err := newSQLiteStore(dbPath) @@ -706,6 +736,34 @@ func TestSQLiteEmptyModelState_Good(t *testing.T) { assert.Empty(t, s.Tokens, "should have no tokens") } +func TestSQLiteSaveSnapshotSkipsNilStateEntries_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "nil-state.db") + store, err := newSQLiteStore(dbPath) + require.NoError(t, err) + defer store.close() + + require.NoError(t, store.saveSnapshot( + map[string]ModelQuota{ + "quota-only": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3}, + }, + map[string]*UsageStats{ + "nil-model": nil, + }, + )) + + hasSnapshot, err := store.hasSnapshot() + require.NoError(t, err) + assert.True(t, hasSnapshot) + + state, err := store.loadState() + require.NoError(t, err) + assert.NotContains(t, state, "nil-model") + + quotas, err := store.loadQuotas() + require.NoError(t, err) + assert.Contains(t, quotas, "quota-only") +} + // --- Phase 2: End-to-end with persist cycle --- func TestSQLiteEndToEnd_Good(t *testing.T) { @@ -788,6 +846,63 @@ func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) { assert.Equal(t, 1, rl2.Stats("model-b").RPD) } +func TestSQLiteLoadBeforeFirstPersistPreservesConfiguredQuotas_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "fresh-load.db") + rl, err := NewWithSQLiteConfig(dbPath, Config{ + Providers: []Provider{ProviderAnthropic}, + Quotas: map[string]ModelQuota{ + "custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, + }, + }) + require.NoError(t, err) + defer rl.Close() + + require.Contains(t, rl.Quotas, "claude-opus-4") + require.Contains(t, rl.Quotas, "custom-model") + require.NotContains(t, rl.Quotas, "gemini-3-pro-preview") + + require.NoError(t, rl.Load()) + + assert.Contains(t, rl.Quotas, "claude-opus-4", "fresh DB should not wipe configured defaults") + assert.Equal(t, ModelQuota{MaxRPM: 50, MaxTPM: 40000, MaxRPD: 0}, rl.Quotas["claude-opus-4"]) + assert.Contains(t, rl.Quotas, "custom-model", "fresh DB should preserve configured custom quotas") + assert.Equal(t, ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, rl.Quotas["custom-model"]) + assert.Empty(t, rl.State) +} + +func TestSQLiteLoadBackfillsSnapshotMetaForLegacyDB_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + rl, err := NewWithSQLiteConfig(dbPath, Config{ + Quotas: map[string]ModelQuota{ + "legacy-model": {MaxRPM: 7, MaxTPM: 700, MaxRPD: 70}, + }, + }) + require.NoError(t, err) + + rl.RecordUsage("legacy-model", 30, 40) + require.NoError(t, rl.Persist()) + require.NoError(t, rl.Close()) + + db, err := sql.Open("sqlite", dbPath) + require.NoError(t, err) + _, err = db.Exec("DROP TABLE snapshot_meta") + require.NoError(t, err) + require.NoError(t, db.Close()) + + rl2, err := NewWithSQLiteConfig(dbPath, Config{ + Providers: []Provider{ProviderGemini}, + }) + require.NoError(t, err) + defer rl2.Close() + + require.NoError(t, rl2.Load()) + + assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview") + require.Contains(t, rl2.Quotas, "legacy-model") + assert.Equal(t, ModelQuota{MaxRPM: 7, MaxTPM: 700, MaxRPD: 70}, rl2.Quotas["legacy-model"]) + assert.Equal(t, 1, rl2.Stats("legacy-model").RPD) +} + func TestSQLiteLoadReplacesQuotasWithEmptyPersistedSnapshot_Good(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "empty-quotas.db") rl, err := NewWithSQLiteConfig(dbPath, Config{