fix(sqlite): preserve defaults before first persist
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
9e715c2cb0
commit
5df6ce127b
4 changed files with 205 additions and 0 deletions
|
|
@ -65,6 +65,26 @@ func TestPersistYAML(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestYAMLErrorPaths(t *testing.T) {
|
||||
t.Run("Load returns error when YAML path is a directory", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = t.TempDir()
|
||||
|
||||
err := rl.Load()
|
||||
assert.Error(t, err, "Load should fail when the YAML path is a directory")
|
||||
})
|
||||
|
||||
t.Run("Persist returns error when YAML path is a directory", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = t.TempDir()
|
||||
rl.Quotas["test"] = ModelQuota{MaxRPM: 1}
|
||||
rl.RecordUsage("test", 1, 1)
|
||||
|
||||
err := rl.Persist()
|
||||
assert.Error(t, err, "Persist should fail when the YAML path is a directory")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSQLiteLoadViaLimiter(t *testing.T) {
|
||||
t.Run("Load returns error when SQLite DB is closed", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "load-err.db")
|
||||
|
|
@ -142,6 +162,21 @@ func TestNewWithSQLiteErrors(t *testing.T) {
|
|||
})
|
||||
assert.Error(t, err, "should fail with invalid path")
|
||||
})
|
||||
|
||||
t.Run("NewWithConfig with sqlite backend fails when default directory cannot be created", func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
blockedHome := filepath.Join(home, "blocked-home")
|
||||
require.NoError(t, os.WriteFile(blockedHome, []byte("x"), 0o644))
|
||||
|
||||
t.Setenv("HOME", blockedHome)
|
||||
t.Setenv("USERPROFILE", "")
|
||||
t.Setenv("home", "")
|
||||
|
||||
_, err := NewWithConfig(Config{
|
||||
Backend: backendSQLite,
|
||||
})
|
||||
assert.Error(t, err, "should fail when the default sqlite directory cannot be created")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSQLiteSaveStateErrors(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -224,6 +224,14 @@ func (rl *RateLimiter) Load() error {
|
|||
// loadSQLite reads quotas and state from the SQLite backend.
|
||||
// Caller must hold the lock.
|
||||
func (rl *RateLimiter) loadSQLite() error {
|
||||
hasSnapshot, err := rl.sqlite.hasSnapshot()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasSnapshot {
|
||||
return nil
|
||||
}
|
||||
|
||||
quotas, err := rl.sqlite.loadQuotas()
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
47
sqlite.go
47
sqlite.go
|
|
@ -46,6 +46,10 @@ func newSQLiteStore(dbPath string) (*sqliteStore, error) {
|
|||
// createSchema creates the tables and indices if they do not already exist.
|
||||
func createSchema(db *sql.DB) error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS snapshot_meta (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
has_snapshot INTEGER NOT NULL DEFAULT 0
|
||||
)`,
|
||||
`CREATE TABLE IF NOT EXISTS quotas (
|
||||
model TEXT PRIMARY KEY,
|
||||
max_rpm INTEGER NOT NULL DEFAULT 0,
|
||||
|
|
@ -75,6 +79,31 @@ func createSchema(db *sql.DB) error {
|
|||
return coreerr.E("ratelimit.createSchema", "exec", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := initialiseSnapshotMeta(db); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initialiseSnapshotMeta(db *sql.DB) error {
|
||||
if _, err := db.Exec("INSERT OR IGNORE INTO snapshot_meta (id, has_snapshot) VALUES (1, 0)"); err != nil {
|
||||
return coreerr.E("ratelimit.createSchema", "init snapshot meta", err)
|
||||
}
|
||||
|
||||
// Older databases do not have snapshot metadata. If any snapshot table
|
||||
// already contains rows, treat it as an existing persisted snapshot.
|
||||
if _, err := db.Exec(`UPDATE snapshot_meta
|
||||
SET has_snapshot = 1
|
||||
WHERE id = 1 AND has_snapshot = 0 AND (
|
||||
EXISTS (SELECT 1 FROM quotas) OR
|
||||
EXISTS (SELECT 1 FROM requests) OR
|
||||
EXISTS (SELECT 1 FROM tokens) OR
|
||||
EXISTS (SELECT 1 FROM daily)
|
||||
)`); err != nil {
|
||||
return coreerr.E("ratelimit.createSchema", "backfill snapshot meta", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -138,6 +167,9 @@ func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[strin
|
|||
if err := insertState(tx, state); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := markSnapshotPersisted(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return commitTx(tx, "ratelimit.saveSnapshot")
|
||||
}
|
||||
|
|
@ -236,6 +268,13 @@ func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func markSnapshotPersisted(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec("INSERT OR REPLACE INTO snapshot_meta (id, has_snapshot) VALUES (1, 1)"); err != nil {
|
||||
return coreerr.E("ratelimit.saveSnapshot", "mark snapshot", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func commitTx(tx *sql.Tx, scope string) error {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return coreerr.E(scope, "commit", err)
|
||||
|
|
@ -243,6 +282,14 @@ func commitTx(tx *sql.Tx, scope string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *sqliteStore) hasSnapshot() (bool, error) {
|
||||
var hasSnapshot int
|
||||
if err := s.db.QueryRow("SELECT has_snapshot FROM snapshot_meta WHERE id = 1").Scan(&hasSnapshot); err != nil {
|
||||
return false, coreerr.E("ratelimit.hasSnapshot", "query", err)
|
||||
}
|
||||
return hasSnapshot != 0, nil
|
||||
}
|
||||
|
||||
// loadState reconstructs the UsageStats map from SQLite tables.
|
||||
func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
|
||||
result := make(map[string]*UsageStats)
|
||||
|
|
|
|||
115
sqlite_test.go
115
sqlite_test.go
|
|
@ -1,6 +1,7 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
|
@ -198,6 +199,35 @@ func TestSQLiteEmptyState_Good(t *testing.T) {
|
|||
assert.Empty(t, state, "should return empty state from fresh DB")
|
||||
}
|
||||
|
||||
func TestSQLiteLoadStateSparseRows_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "sparse.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
||||
now := time.Now()
|
||||
_, err = store.db.Exec("INSERT INTO requests (model, ts) VALUES (?, ?)", "request-only", now.UnixNano())
|
||||
require.NoError(t, err)
|
||||
_, err = store.db.Exec("INSERT INTO tokens (model, ts, count) VALUES (?, ?, ?)", "token-only", now.Add(time.Second).UnixNano(), 123)
|
||||
require.NoError(t, err)
|
||||
|
||||
state, err := store.loadState()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, state, "request-only")
|
||||
assert.Len(t, state["request-only"].Requests, 1)
|
||||
assert.Empty(t, state["request-only"].Tokens)
|
||||
assert.True(t, state["request-only"].DayStart.IsZero())
|
||||
assert.Equal(t, 0, state["request-only"].DayCount)
|
||||
|
||||
require.Contains(t, state, "token-only")
|
||||
assert.Empty(t, state["token-only"].Requests)
|
||||
assert.Len(t, state["token-only"].Tokens, 1)
|
||||
assert.Equal(t, 123, state["token-only"].Tokens[0].Count)
|
||||
assert.True(t, state["token-only"].DayStart.IsZero())
|
||||
assert.Equal(t, 0, state["token-only"].DayCount)
|
||||
}
|
||||
|
||||
func TestSQLiteClose_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "close.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
|
|
@ -706,6 +736,34 @@ func TestSQLiteEmptyModelState_Good(t *testing.T) {
|
|||
assert.Empty(t, s.Tokens, "should have no tokens")
|
||||
}
|
||||
|
||||
func TestSQLiteSaveSnapshotSkipsNilStateEntries_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "nil-state.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
||||
require.NoError(t, store.saveSnapshot(
|
||||
map[string]ModelQuota{
|
||||
"quota-only": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3},
|
||||
},
|
||||
map[string]*UsageStats{
|
||||
"nil-model": nil,
|
||||
},
|
||||
))
|
||||
|
||||
hasSnapshot, err := store.hasSnapshot()
|
||||
require.NoError(t, err)
|
||||
assert.True(t, hasSnapshot)
|
||||
|
||||
state, err := store.loadState()
|
||||
require.NoError(t, err)
|
||||
assert.NotContains(t, state, "nil-model")
|
||||
|
||||
quotas, err := store.loadQuotas()
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, quotas, "quota-only")
|
||||
}
|
||||
|
||||
// --- Phase 2: End-to-end with persist cycle ---
|
||||
|
||||
func TestSQLiteEndToEnd_Good(t *testing.T) {
|
||||
|
|
@ -788,6 +846,63 @@ func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
|
|||
assert.Equal(t, 1, rl2.Stats("model-b").RPD)
|
||||
}
|
||||
|
||||
func TestSQLiteLoadBeforeFirstPersistPreservesConfiguredQuotas_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "fresh-load.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Providers: []Provider{ProviderAnthropic},
|
||||
Quotas: map[string]ModelQuota{
|
||||
"custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
|
||||
require.Contains(t, rl.Quotas, "claude-opus-4")
|
||||
require.Contains(t, rl.Quotas, "custom-model")
|
||||
require.NotContains(t, rl.Quotas, "gemini-3-pro-preview")
|
||||
|
||||
require.NoError(t, rl.Load())
|
||||
|
||||
assert.Contains(t, rl.Quotas, "claude-opus-4", "fresh DB should not wipe configured defaults")
|
||||
assert.Equal(t, ModelQuota{MaxRPM: 50, MaxTPM: 40000, MaxRPD: 0}, rl.Quotas["claude-opus-4"])
|
||||
assert.Contains(t, rl.Quotas, "custom-model", "fresh DB should preserve configured custom quotas")
|
||||
assert.Equal(t, ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, rl.Quotas["custom-model"])
|
||||
assert.Empty(t, rl.State)
|
||||
}
|
||||
|
||||
func TestSQLiteLoadBackfillsSnapshotMetaForLegacyDB_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "legacy.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"legacy-model": {MaxRPM: 7, MaxTPM: 700, MaxRPD: 70},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rl.RecordUsage("legacy-model", 30, 40)
|
||||
require.NoError(t, rl.Persist())
|
||||
require.NoError(t, rl.Close())
|
||||
|
||||
db, err := sql.Open("sqlite", dbPath)
|
||||
require.NoError(t, err)
|
||||
_, err = db.Exec("DROP TABLE snapshot_meta")
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
|
||||
rl2, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Providers: []Provider{ProviderGemini},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer rl2.Close()
|
||||
|
||||
require.NoError(t, rl2.Load())
|
||||
|
||||
assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview")
|
||||
require.Contains(t, rl2.Quotas, "legacy-model")
|
||||
assert.Equal(t, ModelQuota{MaxRPM: 7, MaxTPM: 700, MaxRPD: 70}, rl2.Quotas["legacy-model"])
|
||||
assert.Equal(t, 1, rl2.Stats("legacy-model").RPD)
|
||||
}
|
||||
|
||||
func TestSQLiteLoadReplacesQuotasWithEmptyPersistedSnapshot_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "empty-quotas.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue