fix(sqlite): preserve defaults before first persist

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-23 09:30:15 +00:00
parent 9e715c2cb0
commit 5df6ce127b
4 changed files with 205 additions and 0 deletions

View file

@ -65,6 +65,26 @@ func TestPersistYAML(t *testing.T) {
})
}
func TestYAMLErrorPaths(t *testing.T) {
t.Run("Load returns error when YAML path is a directory", func(t *testing.T) {
rl := newTestLimiter(t)
rl.filePath = t.TempDir()
err := rl.Load()
assert.Error(t, err, "Load should fail when the YAML path is a directory")
})
t.Run("Persist returns error when YAML path is a directory", func(t *testing.T) {
rl := newTestLimiter(t)
rl.filePath = t.TempDir()
rl.Quotas["test"] = ModelQuota{MaxRPM: 1}
rl.RecordUsage("test", 1, 1)
err := rl.Persist()
assert.Error(t, err, "Persist should fail when the YAML path is a directory")
})
}
func TestSQLiteLoadViaLimiter(t *testing.T) {
t.Run("Load returns error when SQLite DB is closed", func(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "load-err.db")
@ -142,6 +162,21 @@ func TestNewWithSQLiteErrors(t *testing.T) {
})
assert.Error(t, err, "should fail with invalid path")
})
t.Run("NewWithConfig with sqlite backend fails when default directory cannot be created", func(t *testing.T) {
home := t.TempDir()
blockedHome := filepath.Join(home, "blocked-home")
require.NoError(t, os.WriteFile(blockedHome, []byte("x"), 0o644))
t.Setenv("HOME", blockedHome)
t.Setenv("USERPROFILE", "")
t.Setenv("home", "")
_, err := NewWithConfig(Config{
Backend: backendSQLite,
})
assert.Error(t, err, "should fail when the default sqlite directory cannot be created")
})
}
func TestSQLiteSaveStateErrors(t *testing.T) {

View file

@ -224,6 +224,14 @@ func (rl *RateLimiter) Load() error {
// loadSQLite reads quotas and state from the SQLite backend.
// Caller must hold the lock.
func (rl *RateLimiter) loadSQLite() error {
hasSnapshot, err := rl.sqlite.hasSnapshot()
if err != nil {
return err
}
if !hasSnapshot {
return nil
}
quotas, err := rl.sqlite.loadQuotas()
if err != nil {
return err

View file

@ -46,6 +46,10 @@ func newSQLiteStore(dbPath string) (*sqliteStore, error) {
// createSchema creates the tables and indices if they do not already exist.
func createSchema(db *sql.DB) error {
stmts := []string{
`CREATE TABLE IF NOT EXISTS snapshot_meta (
id INTEGER PRIMARY KEY CHECK (id = 1),
has_snapshot INTEGER NOT NULL DEFAULT 0
)`,
`CREATE TABLE IF NOT EXISTS quotas (
model TEXT PRIMARY KEY,
max_rpm INTEGER NOT NULL DEFAULT 0,
@ -75,6 +79,31 @@ func createSchema(db *sql.DB) error {
return coreerr.E("ratelimit.createSchema", "exec", err)
}
}
if err := initialiseSnapshotMeta(db); err != nil {
return err
}
return nil
}
func initialiseSnapshotMeta(db *sql.DB) error {
if _, err := db.Exec("INSERT OR IGNORE INTO snapshot_meta (id, has_snapshot) VALUES (1, 0)"); err != nil {
return coreerr.E("ratelimit.createSchema", "init snapshot meta", err)
}
// Older databases do not have snapshot metadata. If any snapshot table
// already contains rows, treat it as an existing persisted snapshot.
if _, err := db.Exec(`UPDATE snapshot_meta
SET has_snapshot = 1
WHERE id = 1 AND has_snapshot = 0 AND (
EXISTS (SELECT 1 FROM quotas) OR
EXISTS (SELECT 1 FROM requests) OR
EXISTS (SELECT 1 FROM tokens) OR
EXISTS (SELECT 1 FROM daily)
)`); err != nil {
return coreerr.E("ratelimit.createSchema", "backfill snapshot meta", err)
}
return nil
}
@ -138,6 +167,9 @@ func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[strin
if err := insertState(tx, state); err != nil {
return err
}
if err := markSnapshotPersisted(tx); err != nil {
return err
}
return commitTx(tx, "ratelimit.saveSnapshot")
}
@ -236,6 +268,13 @@ func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
return nil
}
func markSnapshotPersisted(tx *sql.Tx) error {
if _, err := tx.Exec("INSERT OR REPLACE INTO snapshot_meta (id, has_snapshot) VALUES (1, 1)"); err != nil {
return coreerr.E("ratelimit.saveSnapshot", "mark snapshot", err)
}
return nil
}
func commitTx(tx *sql.Tx, scope string) error {
if err := tx.Commit(); err != nil {
return coreerr.E(scope, "commit", err)
@ -243,6 +282,14 @@ func commitTx(tx *sql.Tx, scope string) error {
return nil
}
func (s *sqliteStore) hasSnapshot() (bool, error) {
var hasSnapshot int
if err := s.db.QueryRow("SELECT has_snapshot FROM snapshot_meta WHERE id = 1").Scan(&hasSnapshot); err != nil {
return false, coreerr.E("ratelimit.hasSnapshot", "query", err)
}
return hasSnapshot != 0, nil
}
// loadState reconstructs the UsageStats map from SQLite tables.
func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
result := make(map[string]*UsageStats)

View file

@ -1,6 +1,7 @@
package ratelimit
import (
"database/sql"
"os"
"path/filepath"
"sync"
@ -198,6 +199,35 @@ func TestSQLiteEmptyState_Good(t *testing.T) {
assert.Empty(t, state, "should return empty state from fresh DB")
}
func TestSQLiteLoadStateSparseRows_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "sparse.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
now := time.Now()
_, err = store.db.Exec("INSERT INTO requests (model, ts) VALUES (?, ?)", "request-only", now.UnixNano())
require.NoError(t, err)
_, err = store.db.Exec("INSERT INTO tokens (model, ts, count) VALUES (?, ?, ?)", "token-only", now.Add(time.Second).UnixNano(), 123)
require.NoError(t, err)
state, err := store.loadState()
require.NoError(t, err)
require.Contains(t, state, "request-only")
assert.Len(t, state["request-only"].Requests, 1)
assert.Empty(t, state["request-only"].Tokens)
assert.True(t, state["request-only"].DayStart.IsZero())
assert.Equal(t, 0, state["request-only"].DayCount)
require.Contains(t, state, "token-only")
assert.Empty(t, state["token-only"].Requests)
assert.Len(t, state["token-only"].Tokens, 1)
assert.Equal(t, 123, state["token-only"].Tokens[0].Count)
assert.True(t, state["token-only"].DayStart.IsZero())
assert.Equal(t, 0, state["token-only"].DayCount)
}
func TestSQLiteClose_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "close.db")
store, err := newSQLiteStore(dbPath)
@ -706,6 +736,34 @@ func TestSQLiteEmptyModelState_Good(t *testing.T) {
assert.Empty(t, s.Tokens, "should have no tokens")
}
func TestSQLiteSaveSnapshotSkipsNilStateEntries_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "nil-state.db")
store, err := newSQLiteStore(dbPath)
require.NoError(t, err)
defer store.close()
require.NoError(t, store.saveSnapshot(
map[string]ModelQuota{
"quota-only": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3},
},
map[string]*UsageStats{
"nil-model": nil,
},
))
hasSnapshot, err := store.hasSnapshot()
require.NoError(t, err)
assert.True(t, hasSnapshot)
state, err := store.loadState()
require.NoError(t, err)
assert.NotContains(t, state, "nil-model")
quotas, err := store.loadQuotas()
require.NoError(t, err)
assert.Contains(t, quotas, "quota-only")
}
// --- Phase 2: End-to-end with persist cycle ---
func TestSQLiteEndToEnd_Good(t *testing.T) {
@ -788,6 +846,63 @@ func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
assert.Equal(t, 1, rl2.Stats("model-b").RPD)
}
func TestSQLiteLoadBeforeFirstPersistPreservesConfiguredQuotas_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "fresh-load.db")
rl, err := NewWithSQLiteConfig(dbPath, Config{
Providers: []Provider{ProviderAnthropic},
Quotas: map[string]ModelQuota{
"custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
},
})
require.NoError(t, err)
defer rl.Close()
require.Contains(t, rl.Quotas, "claude-opus-4")
require.Contains(t, rl.Quotas, "custom-model")
require.NotContains(t, rl.Quotas, "gemini-3-pro-preview")
require.NoError(t, rl.Load())
assert.Contains(t, rl.Quotas, "claude-opus-4", "fresh DB should not wipe configured defaults")
assert.Equal(t, ModelQuota{MaxRPM: 50, MaxTPM: 40000, MaxRPD: 0}, rl.Quotas["claude-opus-4"])
assert.Contains(t, rl.Quotas, "custom-model", "fresh DB should preserve configured custom quotas")
assert.Equal(t, ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, rl.Quotas["custom-model"])
assert.Empty(t, rl.State)
}
func TestSQLiteLoadBackfillsSnapshotMetaForLegacyDB_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "legacy.db")
rl, err := NewWithSQLiteConfig(dbPath, Config{
Quotas: map[string]ModelQuota{
"legacy-model": {MaxRPM: 7, MaxTPM: 700, MaxRPD: 70},
},
})
require.NoError(t, err)
rl.RecordUsage("legacy-model", 30, 40)
require.NoError(t, rl.Persist())
require.NoError(t, rl.Close())
db, err := sql.Open("sqlite", dbPath)
require.NoError(t, err)
_, err = db.Exec("DROP TABLE snapshot_meta")
require.NoError(t, err)
require.NoError(t, db.Close())
rl2, err := NewWithSQLiteConfig(dbPath, Config{
Providers: []Provider{ProviderGemini},
})
require.NoError(t, err)
defer rl2.Close()
require.NoError(t, rl2.Load())
assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview")
require.Contains(t, rl2.Quotas, "legacy-model")
assert.Equal(t, ModelQuota{MaxRPM: 7, MaxTPM: 700, MaxRPD: 70}, rl2.Quotas["legacy-model"])
assert.Equal(t, 1, rl2.Stats("legacy-model").RPD)
}
func TestSQLiteLoadReplacesQuotasWithEmptyPersistedSnapshot_Good(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "empty-quotas.db")
rl, err := NewWithSQLiteConfig(dbPath, Config{