Merge pull request '[agent/codex] AX v0.8.0 polish pass. Fix ALL violations — banned imports...' (#14) from agent/upgrade-this-package-to-dappco-re-go-cor into dev
This commit is contained in:
commit
27723ce8e9
5 changed files with 390 additions and 262 deletions
156
error_test.go
156
error_test.go
|
|
@ -1,8 +1,7 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -10,8 +9,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSQLiteErrorPaths(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "error.db")
|
||||
func TestError_SQLiteErrorPaths_Bad(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "error.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -39,17 +38,17 @@ func TestSQLiteErrorPaths(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteInitErrors(t *testing.T) {
|
||||
func TestError_SQLiteInitErrors_Bad(t *testing.T) {
|
||||
t.Run("WAL pragma failure", func(t *testing.T) {
|
||||
// This is hard to trigger without mocking sql.DB, but we can try an invalid connection string
|
||||
// modernc.org/sqlite doesn't support all DSN options that might cause PRAGMA to fail but connection to succeed.
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersistYAML(t *testing.T) {
|
||||
func TestError_PersistYAML_Good(t *testing.T) {
|
||||
t.Run("successful YAML persist and load", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "ratelimits.yaml")
|
||||
path := testPath(tmpDir, "ratelimits.yaml")
|
||||
rl, _ := New()
|
||||
rl.filePath = path
|
||||
rl.Quotas["test"] = ModelQuota{MaxRPM: 1}
|
||||
|
|
@ -65,9 +64,9 @@ func TestPersistYAML(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteLoadViaLimiter(t *testing.T) {
|
||||
func TestError_SQLiteLoadViaLimiter_Bad(t *testing.T) {
|
||||
t.Run("Load returns error when SQLite DB is closed", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "load-err.db")
|
||||
dbPath := testPath(t.TempDir(), "load-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -79,7 +78,7 @@ func TestSQLiteLoadViaLimiter(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Load returns error when loadState fails", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "load-state-err.db")
|
||||
dbPath := testPath(t.TempDir(), "load-state-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -96,9 +95,9 @@ func TestSQLiteLoadViaLimiter(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLitePersistViaLimiter(t *testing.T) {
|
||||
func TestError_SQLitePersistViaLimiter_Bad(t *testing.T) {
|
||||
t.Run("Persist returns error when SQLite saveQuotas fails", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "persist-err.db")
|
||||
dbPath := testPath(t.TempDir(), "persist-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -113,7 +112,7 @@ func TestSQLitePersistViaLimiter(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("Persist returns error when SQLite saveState fails", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "persist-state-err.db")
|
||||
dbPath := testPath(t.TempDir(), "persist-state-err.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -130,7 +129,7 @@ func TestSQLitePersistViaLimiter(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNewWithSQLiteErrors(t *testing.T) {
|
||||
func TestError_NewWithSQLite_Bad(t *testing.T) {
|
||||
t.Run("NewWithSQLite with invalid path", func(t *testing.T) {
|
||||
_, err := NewWithSQLite("/nonexistent/deep/nested/dir/test.db")
|
||||
assert.Error(t, err, "should fail with invalid path")
|
||||
|
|
@ -144,9 +143,9 @@ func TestNewWithSQLiteErrors(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteSaveStateErrors(t *testing.T) {
|
||||
func TestError_SQLiteSaveState_Bad(t *testing.T) {
|
||||
t.Run("saveState fails when tokens table is dropped", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "tokens-err.db")
|
||||
dbPath := testPath(t.TempDir(), "tokens-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -166,7 +165,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("saveState fails when daily table is dropped", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "daily-err.db")
|
||||
dbPath := testPath(t.TempDir(), "daily-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -185,7 +184,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("saveState fails on request insert with renamed column", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "req-insert-err.db")
|
||||
dbPath := testPath(t.TempDir(), "req-insert-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -206,7 +205,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("saveState fails on token insert with renamed column", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "tok-insert-err.db")
|
||||
dbPath := testPath(t.TempDir(), "tok-insert-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -227,7 +226,7 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("saveState fails on daily insert with renamed column", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "day-insert-err.db")
|
||||
dbPath := testPath(t.TempDir(), "day-insert-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -247,9 +246,9 @@ func TestSQLiteSaveStateErrors(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteLoadStateErrors(t *testing.T) {
|
||||
func TestError_SQLiteLoadState_Bad(t *testing.T) {
|
||||
t.Run("loadState fails when requests table is dropped", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "req-err.db")
|
||||
dbPath := testPath(t.TempDir(), "req-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -271,7 +270,7 @@ func TestSQLiteLoadStateErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("loadState fails when tokens table is dropped", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "tok-err.db")
|
||||
dbPath := testPath(t.TempDir(), "tok-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -293,7 +292,7 @@ func TestSQLiteLoadStateErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("loadState fails when daily table is dropped", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "daily-load-err.db")
|
||||
dbPath := testPath(t.TempDir(), "daily-load-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -314,9 +313,9 @@ func TestSQLiteLoadStateErrors(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteSaveQuotasExecError(t *testing.T) {
|
||||
func TestError_SQLiteSaveQuotasExec_Bad(t *testing.T) {
|
||||
t.Run("saveQuotas fails with renamed column at prepare", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "quota-exec-err.db")
|
||||
dbPath := testPath(t.TempDir(), "quota-exec-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -332,7 +331,7 @@ func TestSQLiteSaveQuotasExecError(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("saveQuotas fails at exec via trigger", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "quota-trigger.db")
|
||||
dbPath := testPath(t.TempDir(), "quota-trigger.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -350,9 +349,9 @@ func TestSQLiteSaveQuotasExecError(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteSaveStateExecErrors(t *testing.T) {
|
||||
func TestError_SQLiteSaveStateExec_Bad(t *testing.T) {
|
||||
t.Run("request insert exec fails via trigger", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "trigger-req.db")
|
||||
dbPath := testPath(t.TempDir(), "trigger-req.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -375,7 +374,7 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("token insert exec fails via trigger", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "trigger-tok.db")
|
||||
dbPath := testPath(t.TempDir(), "trigger-tok.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -398,7 +397,7 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("daily insert exec fails via trigger", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "trigger-day.db")
|
||||
dbPath := testPath(t.TempDir(), "trigger-day.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -420,9 +419,9 @@ func TestSQLiteSaveStateExecErrors(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteLoadQuotasScanError(t *testing.T) {
|
||||
func TestError_SQLiteLoadQuotasScan_Bad(t *testing.T) {
|
||||
t.Run("loadQuotas fails with renamed column", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "quota-scan-err.db")
|
||||
dbPath := testPath(t.TempDir(), "quota-scan-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -441,26 +440,29 @@ func TestSQLiteLoadQuotasScanError(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNewSQLiteStoreInReadOnlyDir(t *testing.T) {
|
||||
if os.Getuid() == 0 {
|
||||
func TestError_NewSQLiteStoreInReadOnlyDir_Bad(t *testing.T) {
|
||||
if isRootUser() {
|
||||
t.Skip("chmod restrictions do not apply to root")
|
||||
}
|
||||
|
||||
t.Run("fails when parent directory is read-only", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
readonlyDir := filepath.Join(tmpDir, "readonly")
|
||||
require.NoError(t, os.MkdirAll(readonlyDir, 0555))
|
||||
defer os.Chmod(readonlyDir, 0755)
|
||||
readonlyDir := testPath(tmpDir, "readonly")
|
||||
ensureTestDir(t, readonlyDir)
|
||||
setPathMode(t, readonlyDir, 0o555)
|
||||
defer func() {
|
||||
_ = syscall.Chmod(readonlyDir, 0o755)
|
||||
}()
|
||||
|
||||
dbPath := filepath.Join(readonlyDir, "test.db")
|
||||
dbPath := testPath(readonlyDir, "test.db")
|
||||
_, err := newSQLiteStore(dbPath)
|
||||
assert.Error(t, err, "should fail when directory is read-only")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSQLiteCreateSchemaError(t *testing.T) {
|
||||
func TestError_SQLiteCreateSchema_Bad(t *testing.T) {
|
||||
t.Run("createSchema fails on closed DB", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "schema-err.db")
|
||||
dbPath := testPath(t.TempDir(), "schema-err.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -473,9 +475,9 @@ func TestSQLiteCreateSchemaError(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteLoadStateScanErrors(t *testing.T) {
|
||||
func TestError_SQLiteLoadStateScan_Bad(t *testing.T) {
|
||||
t.Run("scan daily fails with NULL values", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "scan-daily.db")
|
||||
dbPath := testPath(t.TempDir(), "scan-daily.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -495,7 +497,7 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("scan requests fails with NULL ts", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "scan-req.db")
|
||||
dbPath := testPath(t.TempDir(), "scan-req.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -520,7 +522,7 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("scan tokens fails with NULL values", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "scan-tok.db")
|
||||
dbPath := testPath(t.TempDir(), "scan-tok.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -545,9 +547,9 @@ func TestSQLiteLoadStateScanErrors(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteLoadQuotasScanWithBadSchema(t *testing.T) {
|
||||
func TestError_SQLiteLoadQuotasScanWithBadSchema_Bad(t *testing.T) {
|
||||
t.Run("scan fails with NULL quota values", func(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "scan-quota.db")
|
||||
dbPath := testPath(t.TempDir(), "scan-quota.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -566,11 +568,11 @@ func TestSQLiteLoadQuotasScanWithBadSchema(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
|
||||
func TestError_MigrateYAMLToSQLiteWithSaveErrors_Bad(t *testing.T) {
|
||||
t.Run("saveQuotas failure during migration via trigger", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "with-quotas.yaml")
|
||||
sqlitePath := filepath.Join(tmpDir, "migrate-quota-err.db")
|
||||
yamlPath := testPath(tmpDir, "with-quotas.yaml")
|
||||
sqlitePath := testPath(tmpDir, "migrate-quota-err.db")
|
||||
|
||||
// Write a YAML file with quotas.
|
||||
yamlData := `quotas:
|
||||
|
|
@ -579,7 +581,7 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
|
|||
max_tpm: 100
|
||||
max_rpd: 50
|
||||
`
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644))
|
||||
writeTestFile(t, yamlPath, yamlData)
|
||||
|
||||
// Pre-create DB with a trigger that aborts quota inserts.
|
||||
store, err := newSQLiteStore(sqlitePath)
|
||||
|
|
@ -596,8 +598,8 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
|
|||
|
||||
t.Run("saveState failure during migration via trigger", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "with-state.yaml")
|
||||
sqlitePath := filepath.Join(tmpDir, "migrate-state-err.db")
|
||||
yamlPath := testPath(tmpDir, "with-state.yaml")
|
||||
sqlitePath := testPath(tmpDir, "migrate-state-err.db")
|
||||
|
||||
// Write YAML with state.
|
||||
yamlData := `state:
|
||||
|
|
@ -607,7 +609,7 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
|
|||
day_start: 2026-01-01T00:00:00Z
|
||||
day_count: 1
|
||||
`
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte(yamlData), 0644))
|
||||
writeTestFile(t, yamlPath, yamlData)
|
||||
|
||||
// Pre-create DB with a trigger that aborts daily inserts.
|
||||
store, err := newSQLiteStore(sqlitePath)
|
||||
|
|
@ -622,13 +624,13 @@ func TestMigrateYAMLToSQLiteWithSaveErrors(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestMigrateYAMLToSQLiteNilQuotasAndState(t *testing.T) {
|
||||
func TestError_MigrateYAMLToSQLiteNilQuotasAndState_Good(t *testing.T) {
|
||||
t.Run("YAML with empty quotas and state migrates cleanly", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "empty.yaml")
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte("{}"), 0644))
|
||||
yamlPath := testPath(tmpDir, "empty.yaml")
|
||||
writeTestFile(t, yamlPath, "{}")
|
||||
|
||||
sqlitePath := filepath.Join(tmpDir, "empty.db")
|
||||
sqlitePath := testPath(tmpDir, "empty.db")
|
||||
require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath))
|
||||
|
||||
store, err := newSQLiteStore(sqlitePath)
|
||||
|
|
@ -645,30 +647,18 @@ func TestMigrateYAMLToSQLiteNilQuotasAndState(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNewWithConfigUserHomeDirError(t *testing.T) {
|
||||
// Unset HOME to trigger os.UserHomeDir() error.
|
||||
home := os.Getenv("HOME")
|
||||
os.Unsetenv("HOME")
|
||||
// Also unset fallback env vars that UserHomeDir checks.
|
||||
plan9Home := os.Getenv("home")
|
||||
os.Unsetenv("home")
|
||||
userProfile := os.Getenv("USERPROFILE")
|
||||
os.Unsetenv("USERPROFILE")
|
||||
defer func() {
|
||||
os.Setenv("HOME", home)
|
||||
if plan9Home != "" {
|
||||
os.Setenv("home", plan9Home)
|
||||
}
|
||||
if userProfile != "" {
|
||||
os.Setenv("USERPROFILE", userProfile)
|
||||
}
|
||||
}()
|
||||
func TestError_NewWithConfigUserHomeDir_Bad(t *testing.T) {
|
||||
// Clear all supported home env vars so defaultStatePath cannot resolve a home directory.
|
||||
t.Setenv("CORE_HOME", "")
|
||||
t.Setenv("HOME", "")
|
||||
t.Setenv("home", "")
|
||||
t.Setenv("USERPROFILE", "")
|
||||
|
||||
_, err := NewWithConfig(Config{})
|
||||
assert.Error(t, err, "should fail when HOME is unset")
|
||||
}
|
||||
|
||||
func TestPersistMarshalError(t *testing.T) {
|
||||
func TestError_PersistMarshal_Good(t *testing.T) {
|
||||
// yaml.Marshal on a struct with map[string]ModelQuota and map[string]*UsageStats
|
||||
// should not fail in practice. We test the error path by using a type that
|
||||
// yaml.Marshal cannot handle: a channel.
|
||||
|
|
@ -680,20 +670,20 @@ func TestPersistMarshalError(t *testing.T) {
|
|||
assert.NoError(t, rl.Persist(), "valid persist should succeed")
|
||||
}
|
||||
|
||||
func TestMigrateErrorsExtended(t *testing.T) {
|
||||
func TestError_MigrateErrorsExtended_Bad(t *testing.T) {
|
||||
t.Run("unmarshal failure", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "bad.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("invalid: yaml: ["), 0644))
|
||||
err := MigrateYAMLToSQLite(path, filepath.Join(tmpDir, "out.db"))
|
||||
path := testPath(tmpDir, "bad.yaml")
|
||||
writeTestFile(t, path, "invalid: yaml: [")
|
||||
err := MigrateYAMLToSQLite(path, testPath(tmpDir, "out.db"))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "ratelimit.MigrateYAMLToSQLite: unmarshal")
|
||||
})
|
||||
|
||||
t.Run("sqlite open failure", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "ok.yaml")
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte("quotas: {}"), 0644))
|
||||
yamlPath := testPath(tmpDir, "ok.yaml")
|
||||
writeTestFile(t, yamlPath, "quotas: {}")
|
||||
// Use an invalid sqlite path (dir where file should be)
|
||||
err := MigrateYAMLToSQLite(yamlPath, "/dev/null/not-a-db")
|
||||
assert.Error(t, err)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIterators(t *testing.T) {
|
||||
func TestIter_Iterators_Good(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"model-c": {MaxRPM: 10},
|
||||
|
|
@ -77,7 +77,7 @@ func TestIterators(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestIterEarlyBreak(t *testing.T) {
|
||||
func TestIter_IterEarlyBreak_Good(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"model-a": {MaxRPM: 10},
|
||||
|
|
@ -110,7 +110,7 @@ func TestIterEarlyBreak(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestCountTokensFull(t *testing.T) {
|
||||
func TestIter_CountTokensFull_Ugly(t *testing.T) {
|
||||
t.Run("empty model is rejected", func(t *testing.T) {
|
||||
_, err := CountTokens(context.Background(), "key", "", "text")
|
||||
assert.Error(t, err)
|
||||
|
|
|
|||
102
ratelimit.go
102
ratelimit.go
|
|
@ -3,11 +3,11 @@ package ratelimit
|
|||
import (
|
||||
"context"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
|
@ -17,6 +17,8 @@ import (
|
|||
)
|
||||
|
||||
// Provider identifies an LLM provider for quota profiles.
|
||||
//
|
||||
// provider := ProviderOpenAI
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
|
|
@ -41,6 +43,8 @@ const (
|
|||
)
|
||||
|
||||
// ModelQuota defines the rate limits for a specific model.
|
||||
//
|
||||
// quota := ModelQuota{MaxRPM: 60, MaxTPM: 90000, MaxRPD: 1000}
|
||||
type ModelQuota struct {
|
||||
MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
|
||||
MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited)
|
||||
|
|
@ -48,12 +52,18 @@ type ModelQuota struct {
|
|||
}
|
||||
|
||||
// ProviderProfile bundles model quotas for a provider.
|
||||
//
|
||||
// profile := ProviderProfile{Provider: ProviderGemini, Models: DefaultProfiles()[ProviderGemini].Models}
|
||||
type ProviderProfile struct {
|
||||
// Provider identifies the provider that owns the profile.
|
||||
Provider Provider `yaml:"provider"`
|
||||
// Models maps model names to quotas.
|
||||
Models map[string]ModelQuota `yaml:"models"`
|
||||
}
|
||||
|
||||
// Config controls RateLimiter initialisation.
|
||||
//
|
||||
// cfg := Config{Providers: []Provider{ProviderGemini}, FilePath: "/tmp/ratelimits.yaml"}
|
||||
type Config struct {
|
||||
// FilePath overrides the default state file location.
|
||||
// If empty, defaults to ~/.core/ratelimits.yaml.
|
||||
|
|
@ -73,23 +83,35 @@ type Config struct {
|
|||
}
|
||||
|
||||
// TokenEntry records a token usage event.
|
||||
//
|
||||
// entry := TokenEntry{Time: time.Now(), Count: 512}
|
||||
type TokenEntry struct {
|
||||
Time time.Time `yaml:"time"`
|
||||
Count int `yaml:"count"`
|
||||
}
|
||||
|
||||
// UsageStats tracks usage history for a model.
|
||||
//
|
||||
// stats := UsageStats{DayStart: time.Now(), DayCount: 1}
|
||||
type UsageStats struct {
|
||||
Requests []time.Time `yaml:"requests"` // Sliding window (1m)
|
||||
Tokens []TokenEntry `yaml:"tokens"` // Sliding window (1m)
|
||||
// DayStart is the start of the rolling 24-hour window.
|
||||
DayStart time.Time `yaml:"day_start"`
|
||||
// DayCount is the number of requests recorded in the rolling 24-hour window.
|
||||
DayCount int `yaml:"day_count"`
|
||||
}
|
||||
|
||||
// RateLimiter manages rate limits across multiple models.
|
||||
//
|
||||
// rl, err := New()
|
||||
// if err != nil { /* handle error */ }
|
||||
// defer rl.Close()
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
// Quotas holds the configured per-model limits.
|
||||
Quotas map[string]ModelQuota `yaml:"quotas"`
|
||||
// State holds per-model usage windows.
|
||||
State map[string]*UsageStats `yaml:"state"`
|
||||
filePath string
|
||||
sqlite *sqliteStore // non-nil when backend is "sqlite"
|
||||
|
|
@ -97,6 +119,9 @@ type RateLimiter struct {
|
|||
|
||||
// DefaultProfiles returns pre-configured quota profiles for each provider.
|
||||
// Values are based on published rate limits as of Feb 2026.
|
||||
//
|
||||
// profiles := DefaultProfiles()
|
||||
// openAI := profiles[ProviderOpenAI]
|
||||
func DefaultProfiles() map[Provider]ProviderProfile {
|
||||
return map[Provider]ProviderProfile{
|
||||
ProviderGemini: {
|
||||
|
|
@ -140,6 +165,8 @@ func DefaultProfiles() map[Provider]ProviderProfile {
|
|||
|
||||
// New creates a new RateLimiter with Gemini defaults.
|
||||
// This preserves backward compatibility -- existing callers are unaffected.
|
||||
//
|
||||
// rl, err := New()
|
||||
func New() (*RateLimiter, error) {
|
||||
return NewWithConfig(Config{
|
||||
Providers: []Provider{ProviderGemini},
|
||||
|
|
@ -148,6 +175,8 @@ func New() (*RateLimiter, error) {
|
|||
|
||||
// NewWithConfig creates a RateLimiter from explicit configuration.
|
||||
// If no providers or quotas are specified, Gemini defaults are used.
|
||||
//
|
||||
// rl, err := NewWithConfig(Config{Providers: []Provider{ProviderAnthropic}})
|
||||
func NewWithConfig(cfg Config) (*RateLimiter, error) {
|
||||
backend, err := normaliseBackend(cfg.Backend)
|
||||
if err != nil {
|
||||
|
|
@ -177,6 +206,8 @@ func NewWithConfig(cfg Config) (*RateLimiter, error) {
|
|||
}
|
||||
|
||||
// SetQuota sets or updates the quota for a specific model at runtime.
|
||||
//
|
||||
// rl.SetQuota("gpt-4o-mini", ModelQuota{MaxRPM: 60, MaxTPM: 200000})
|
||||
func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
@ -185,6 +216,8 @@ func (rl *RateLimiter) SetQuota(model string, quota ModelQuota) {
|
|||
|
||||
// AddProvider loads all default quotas for a provider.
|
||||
// Existing quotas for models in the profile are overwritten.
|
||||
//
|
||||
// rl.AddProvider(ProviderOpenAI)
|
||||
func (rl *RateLimiter) AddProvider(provider Provider) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
@ -196,6 +229,8 @@ func (rl *RateLimiter) AddProvider(provider Provider) {
|
|||
}
|
||||
|
||||
// Load reads the state from disk (YAML) or database (SQLite).
|
||||
//
|
||||
// if err := rl.Load(); err != nil { /* handle error */ }
|
||||
func (rl *RateLimiter) Load() error {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
@ -205,7 +240,7 @@ func (rl *RateLimiter) Load() error {
|
|||
}
|
||||
|
||||
content, err := readLocalFile(rl.filePath)
|
||||
if os.IsNotExist(err) {
|
||||
if core.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
|
|
@ -238,6 +273,8 @@ func (rl *RateLimiter) loadSQLite() error {
|
|||
|
||||
// Persist writes a snapshot of the state to disk (YAML) or database (SQLite).
|
||||
// It clones the state under a lock and performs I/O without blocking other callers.
|
||||
//
|
||||
// if err := rl.Persist(); err != nil { /* handle error */ }
|
||||
func (rl *RateLimiter) Persist() error {
|
||||
rl.mu.Lock()
|
||||
quotas := maps.Clone(rl.Quotas)
|
||||
|
|
@ -328,6 +365,9 @@ func (rl *RateLimiter) prune(model string) {
|
|||
|
||||
// BackgroundPrune starts a goroutine that periodically prunes all model states.
|
||||
// It returns a function to stop the pruner.
|
||||
//
|
||||
// stop := rl.BackgroundPrune(30 * time.Second)
|
||||
// defer stop()
|
||||
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
|
||||
if interval <= 0 {
|
||||
return func() {}
|
||||
|
|
@ -354,6 +394,8 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
|
|||
}
|
||||
|
||||
// CanSend checks if a request can be sent without violating limits.
|
||||
//
|
||||
// ok := rl.CanSend("gemini-3-pro-preview", 1200)
|
||||
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
||||
if estimatedTokens < 0 {
|
||||
return false
|
||||
|
|
@ -401,6 +443,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
|||
}
|
||||
|
||||
// RecordUsage records a successful API call.
|
||||
//
|
||||
// rl.RecordUsage("gemini-3-pro-preview", 900, 300)
|
||||
func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
@ -420,6 +464,8 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
|
|||
}
|
||||
|
||||
// WaitForCapacity blocks until capacity is available or context is cancelled.
|
||||
//
|
||||
// err := rl.WaitForCapacity(ctx, "gemini-3-pro-preview", 1200)
|
||||
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
|
||||
if tokens < 0 {
|
||||
return core.E("ratelimit.WaitForCapacity", "negative tokens", nil)
|
||||
|
|
@ -443,6 +489,8 @@ func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens
|
|||
}
|
||||
|
||||
// Reset clears stats for a model (or all if model is empty).
|
||||
//
|
||||
// rl.Reset("gemini-3-pro-preview")
|
||||
func (rl *RateLimiter) Reset(model string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
@ -455,17 +503,28 @@ func (rl *RateLimiter) Reset(model string) {
|
|||
}
|
||||
|
||||
// ModelStats represents a snapshot of usage.
|
||||
//
|
||||
// stats := rl.Stats("gemini-3-pro-preview")
|
||||
type ModelStats struct {
|
||||
// RPM is the current requests-per-minute usage in the sliding window.
|
||||
RPM int
|
||||
// MaxRPM is the configured requests-per-minute limit.
|
||||
MaxRPM int
|
||||
// TPM is the current tokens-per-minute usage in the sliding window.
|
||||
TPM int
|
||||
// MaxTPM is the configured tokens-per-minute limit.
|
||||
MaxTPM int
|
||||
// RPD is the current requests-per-day usage in the rolling 24-hour window.
|
||||
RPD int
|
||||
// MaxRPD is the configured requests-per-day limit.
|
||||
MaxRPD int
|
||||
// DayStart is the start of the current rolling 24-hour window.
|
||||
DayStart time.Time
|
||||
}
|
||||
|
||||
// Models returns a sorted iterator over all model names tracked by the limiter.
|
||||
//
|
||||
// for model := range rl.Models() { println(model) }
|
||||
func (rl *RateLimiter) Models() iter.Seq[string] {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
|
@ -482,6 +541,8 @@ func (rl *RateLimiter) Models() iter.Seq[string] {
|
|||
}
|
||||
|
||||
// Iter returns a sorted iterator over all model names and their current stats.
|
||||
//
|
||||
// for model, stats := range rl.Iter() { _ = stats; println(model) }
|
||||
func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] {
|
||||
return func(yield func(string, ModelStats) bool) {
|
||||
stats := rl.AllStats()
|
||||
|
|
@ -494,6 +555,8 @@ func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] {
|
|||
}
|
||||
|
||||
// Stats returns current stats for a model.
|
||||
//
|
||||
// stats := rl.Stats("gemini-3-pro-preview")
|
||||
func (rl *RateLimiter) Stats(model string) ModelStats {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
@ -521,6 +584,8 @@ func (rl *RateLimiter) Stats(model string) ModelStats {
|
|||
}
|
||||
|
||||
// AllStats returns stats for all tracked models.
|
||||
//
|
||||
// all := rl.AllStats()
|
||||
func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
|
@ -559,6 +624,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
|||
// NewWithSQLite creates a SQLite-backed RateLimiter with Gemini defaults.
|
||||
// The database is created at dbPath if it does not exist. Use Close() to
|
||||
// release the database connection when finished.
|
||||
//
|
||||
// rl, err := NewWithSQLite("/tmp/ratelimits.db")
|
||||
func NewWithSQLite(dbPath string) (*RateLimiter, error) {
|
||||
return NewWithSQLiteConfig(dbPath, Config{
|
||||
Providers: []Provider{ProviderGemini},
|
||||
|
|
@ -568,6 +635,8 @@ func NewWithSQLite(dbPath string) (*RateLimiter, error) {
|
|||
// NewWithSQLiteConfig creates a SQLite-backed RateLimiter with custom config.
|
||||
// The Backend field in cfg is ignored (always "sqlite"). Use Close() to
|
||||
// release the database connection when finished.
|
||||
//
|
||||
// rl, err := NewWithSQLiteConfig("/tmp/ratelimits.db", Config{Providers: []Provider{ProviderOpenAI}})
|
||||
func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
if err != nil {
|
||||
|
|
@ -582,6 +651,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
|
|||
// Close releases resources held by the RateLimiter. For YAML-backed
|
||||
// limiters this is a no-op. For SQLite-backed limiters it closes the
|
||||
// database connection.
|
||||
//
|
||||
// defer rl.Close()
|
||||
func (rl *RateLimiter) Close() error {
|
||||
if rl.sqlite != nil {
|
||||
return rl.sqlite.close()
|
||||
|
|
@ -592,6 +663,8 @@ func (rl *RateLimiter) Close() error {
|
|||
// MigrateYAMLToSQLite reads state from a YAML file and writes it to a new
|
||||
// SQLite database. Both quotas and usage state are migrated. The SQLite
|
||||
// database is created if it does not exist.
|
||||
//
|
||||
// err := MigrateYAMLToSQLite("ratelimits.yaml", "ratelimits.db")
|
||||
func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
|
||||
// Load from YAML.
|
||||
content, err := readLocalFile(yamlPath)
|
||||
|
|
@ -618,6 +691,8 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
|
|||
}
|
||||
|
||||
// CountTokens calls the Google API to count tokens for a prompt.
|
||||
//
|
||||
// tokens, err := CountTokens(ctx, apiKey, "gemini-3-pro-preview", prompt)
|
||||
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
|
||||
return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text)
|
||||
}
|
||||
|
|
@ -722,9 +797,9 @@ func normaliseBackend(backend string) (string, error) {
|
|||
}
|
||||
|
||||
func defaultStatePath(backend string) (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
home := currentHomeDir()
|
||||
if home == "" {
|
||||
return "", core.E("ratelimit.defaultStatePath", "home dir unavailable", nil)
|
||||
}
|
||||
|
||||
fileName := defaultYAMLStateFile
|
||||
|
|
@ -735,6 +810,15 @@ func defaultStatePath(backend string) (string, error) {
|
|||
return core.Path(home, defaultStateDirName, fileName), nil
|
||||
}
|
||||
|
||||
func currentHomeDir() string {
|
||||
for _, key := range []string{"CORE_HOME", "HOME", "home", "USERPROFILE"} {
|
||||
if value := core.Trim(core.Env(key)); value != "" {
|
||||
return value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func safeTokenSum(a, b int) int {
|
||||
return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}})
|
||||
}
|
||||
|
|
@ -762,7 +846,7 @@ func safeTokenTotal(tokens []TokenEntry) int {
|
|||
|
||||
func countTokensURL(baseURL, model string) (string, error) {
|
||||
if core.Trim(model) == "" {
|
||||
return "", core.NewError("empty model")
|
||||
return "", core.E("ratelimit.countTokensURL", "empty model", nil)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(baseURL)
|
||||
|
|
@ -770,7 +854,7 @@ func countTokensURL(baseURL, model string) (string, error) {
|
|||
return "", err
|
||||
}
|
||||
if parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "", core.NewError("invalid base url")
|
||||
return "", core.E("ratelimit.countTokensURL", "invalid base url", nil)
|
||||
}
|
||||
|
||||
return core.Concat(core.TrimSuffix(parsed.String(), "/"), "/v1beta/models/", url.PathEscape(model), ":countTokens"), nil
|
||||
|
|
@ -803,7 +887,7 @@ func readLocalFile(path string) (string, error) {
|
|||
|
||||
content, ok := result.Value.(string)
|
||||
if !ok {
|
||||
return "", core.NewError("read returned non-string")
|
||||
return "", core.E("ratelimit.readLocalFile", "read returned non-string", nil)
|
||||
}
|
||||
return content, nil
|
||||
}
|
||||
|
|
@ -828,5 +912,5 @@ func resultError(result core.Result) error {
|
|||
if result.Value == nil {
|
||||
return nil
|
||||
}
|
||||
return core.NewError(core.Sprint(result.Value))
|
||||
return core.E("ratelimit.resultError", core.Sprint(result.Value), nil)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,28 +2,92 @@ package ratelimit
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
core "dappco.re/go/core"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func testPath(parts ...string) string {
|
||||
return core.Path(parts...)
|
||||
}
|
||||
|
||||
func pathExists(path string) bool {
|
||||
var fs core.Fs
|
||||
return fs.Exists(path)
|
||||
}
|
||||
|
||||
func writeTestFile(tb testing.TB, path, content string) {
|
||||
tb.Helper()
|
||||
require.NoError(tb, writeLocalFile(path, content))
|
||||
}
|
||||
|
||||
func ensureTestDir(tb testing.TB, path string) {
|
||||
tb.Helper()
|
||||
require.NoError(tb, ensureDir(path))
|
||||
}
|
||||
|
||||
func setPathMode(tb testing.TB, path string, mode uint32) {
|
||||
tb.Helper()
|
||||
require.NoError(tb, syscall.Chmod(path, mode))
|
||||
}
|
||||
|
||||
func overwriteTestFile(tb testing.TB, path, content string) {
|
||||
tb.Helper()
|
||||
|
||||
var fs core.Fs
|
||||
writer := fs.Create(path)
|
||||
require.NoError(tb, resultError(writer))
|
||||
require.NoError(tb, resultError(core.WriteAll(writer.Value, content)))
|
||||
}
|
||||
|
||||
func isRootUser() bool {
|
||||
return syscall.Geteuid() == 0
|
||||
}
|
||||
|
||||
func repeatString(part string, count int) string {
|
||||
builder := core.NewBuilder()
|
||||
for i := 0; i < count; i++ {
|
||||
builder.WriteString(part)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
func substringCount(s, substr string) int {
|
||||
if substr == "" {
|
||||
return 0
|
||||
}
|
||||
return len(core.Split(s, substr)) - 1
|
||||
}
|
||||
|
||||
func decodeJSONBody(tb testing.TB, r io.Reader, target any) {
|
||||
tb.Helper()
|
||||
|
||||
data, err := io.ReadAll(r)
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, resultError(core.JSONUnmarshal(data, target)))
|
||||
}
|
||||
|
||||
func writeJSONBody(tb testing.TB, w io.Writer, value any) {
|
||||
tb.Helper()
|
||||
|
||||
_, err := io.WriteString(w, core.JSONMarshalString(value))
|
||||
require.NoError(tb, err)
|
||||
}
|
||||
|
||||
// newTestLimiter returns a RateLimiter with file path set to a temp directory.
|
||||
func newTestLimiter(t *testing.T) *RateLimiter {
|
||||
t.Helper()
|
||||
rl, err := New()
|
||||
require.NoError(t, err)
|
||||
rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml")
|
||||
rl.filePath = testPath(t.TempDir(), "ratelimits.yaml")
|
||||
return rl
|
||||
}
|
||||
|
||||
|
|
@ -41,7 +105,7 @@ func (errReader) Read([]byte) (int, error) {
|
|||
|
||||
// --- Phase 0: CanSend boundary conditions ---
|
||||
|
||||
func TestCanSend(t *testing.T) {
|
||||
func TestRatelimit_CanSend_Good(t *testing.T) {
|
||||
t.Run("fresh state allows send", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "test-model"
|
||||
|
|
@ -189,7 +253,7 @@ func TestCanSend(t *testing.T) {
|
|||
|
||||
// --- Phase 0: Sliding window / prune tests ---
|
||||
|
||||
func TestPrune(t *testing.T) {
|
||||
func TestRatelimit_Prune_Good(t *testing.T) {
|
||||
t.Run("removes old entries", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "test-prune"
|
||||
|
|
@ -304,7 +368,7 @@ func TestPrune(t *testing.T) {
|
|||
|
||||
// --- Phase 0: RecordUsage ---
|
||||
|
||||
func TestRecordUsage(t *testing.T) {
|
||||
func TestRatelimit_RecordUsage_Good(t *testing.T) {
|
||||
t.Run("records into fresh state", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "record-fresh"
|
||||
|
|
@ -375,7 +439,7 @@ func TestRecordUsage(t *testing.T) {
|
|||
|
||||
// --- Phase 0: Reset ---
|
||||
|
||||
func TestReset(t *testing.T) {
|
||||
func TestRatelimit_Reset_Good(t *testing.T) {
|
||||
t.Run("reset single model", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.RecordUsage("model-a", 10, 10)
|
||||
|
|
@ -409,7 +473,7 @@ func TestReset(t *testing.T) {
|
|||
|
||||
// --- Phase 0: WaitForCapacity ---
|
||||
|
||||
func TestWaitForCapacity(t *testing.T) {
|
||||
func TestRatelimit_WaitForCapacity_Good(t *testing.T) {
|
||||
t.Run("context cancelled returns error", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "wait-cancel"
|
||||
|
|
@ -467,7 +531,7 @@ func TestWaitForCapacity(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNilUsageStats(t *testing.T) {
|
||||
func TestRatelimit_NilUsageStats_Ugly(t *testing.T) {
|
||||
t.Run("CanSend replaces nil state without panicking", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "nil-cansend"
|
||||
|
|
@ -514,7 +578,7 @@ func TestNilUsageStats(t *testing.T) {
|
|||
|
||||
// --- Phase 0: Stats ---
|
||||
|
||||
func TestStats(t *testing.T) {
|
||||
func TestRatelimit_Stats_Good(t *testing.T) {
|
||||
t.Run("returns stats for known model with usage", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "stats-test"
|
||||
|
|
@ -554,7 +618,7 @@ func TestStats(t *testing.T) {
|
|||
|
||||
// --- Phase 0: AllStats ---
|
||||
|
||||
func TestAllStats(t *testing.T) {
|
||||
func TestRatelimit_AllStats_Good(t *testing.T) {
|
||||
t.Run("includes all default quotas plus state-only models", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.RecordUsage("gemini-3-pro-preview", 1000, 500)
|
||||
|
|
@ -612,10 +676,10 @@ func TestAllStats(t *testing.T) {
|
|||
|
||||
// --- Phase 0: Persist and Load ---
|
||||
|
||||
func TestPersistAndLoad(t *testing.T) {
|
||||
func TestRatelimit_PersistAndLoad_Ugly(t *testing.T) {
|
||||
t.Run("round-trip preserves state", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "ratelimits.yaml")
|
||||
path := testPath(tmpDir, "ratelimits.yaml")
|
||||
|
||||
rl1, err := New()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -638,7 +702,7 @@ func TestPersistAndLoad(t *testing.T) {
|
|||
|
||||
t.Run("load from non-existent file is not an error", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = filepath.Join(t.TempDir(), "does-not-exist.yaml")
|
||||
rl.filePath = testPath(t.TempDir(), "does-not-exist.yaml")
|
||||
|
||||
err := rl.Load()
|
||||
assert.NoError(t, err, "loading non-existent file should not error")
|
||||
|
|
@ -646,8 +710,8 @@ func TestPersistAndLoad(t *testing.T) {
|
|||
|
||||
t.Run("load from corrupt YAML returns error", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "corrupt.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("{{{{invalid yaml!!!!"), 0644))
|
||||
path := testPath(tmpDir, "corrupt.yaml")
|
||||
writeTestFile(t, path, "{{{{invalid yaml!!!!")
|
||||
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = path
|
||||
|
|
@ -657,13 +721,13 @@ func TestPersistAndLoad(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("load from unreadable file returns error", func(t *testing.T) {
|
||||
if os.Getuid() == 0 {
|
||||
if isRootUser() {
|
||||
t.Skip("chmod 000 does not restrict root")
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "unreadable.yaml")
|
||||
require.NoError(t, os.WriteFile(path, []byte("quotas: {}"), 0644))
|
||||
require.NoError(t, os.Chmod(path, 0000))
|
||||
path := testPath(tmpDir, "unreadable.yaml")
|
||||
writeTestFile(t, path, "quotas: {}")
|
||||
setPathMode(t, path, 0o000)
|
||||
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = path
|
||||
|
|
@ -672,12 +736,12 @@ func TestPersistAndLoad(t *testing.T) {
|
|||
assert.Error(t, err, "unreadable file should produce an error")
|
||||
|
||||
// Clean up permissions for temp dir cleanup
|
||||
_ = os.Chmod(path, 0644)
|
||||
_ = syscall.Chmod(path, 0o644)
|
||||
})
|
||||
|
||||
t.Run("persist to nested non-existent directory creates it", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "nested", "deep", "ratelimits.yaml")
|
||||
path := testPath(tmpDir, "nested", "deep", "ratelimits.yaml")
|
||||
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = path
|
||||
|
|
@ -686,32 +750,32 @@ func TestPersistAndLoad(t *testing.T) {
|
|||
err := rl.Persist()
|
||||
assert.NoError(t, err, "should create nested directories")
|
||||
|
||||
_, statErr := os.Stat(path)
|
||||
assert.NoError(t, statErr, "file should exist")
|
||||
assert.True(t, pathExists(path), "file should exist")
|
||||
})
|
||||
|
||||
t.Run("persist to unwritable directory returns error", func(t *testing.T) {
|
||||
if os.Getuid() == 0 {
|
||||
if isRootUser() {
|
||||
t.Skip("chmod 0555 does not restrict root")
|
||||
}
|
||||
tmpDir := t.TempDir()
|
||||
unwritable := filepath.Join(tmpDir, "readonly")
|
||||
require.NoError(t, os.MkdirAll(unwritable, 0555))
|
||||
unwritable := testPath(tmpDir, "readonly")
|
||||
ensureTestDir(t, unwritable)
|
||||
setPathMode(t, unwritable, 0o555)
|
||||
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = filepath.Join(unwritable, "sub", "ratelimits.yaml")
|
||||
rl.filePath = testPath(unwritable, "sub", "ratelimits.yaml")
|
||||
|
||||
err := rl.Persist()
|
||||
assert.Error(t, err, "should fail when directory is unwritable")
|
||||
|
||||
// Clean up
|
||||
_ = os.Chmod(unwritable, 0755)
|
||||
_ = syscall.Chmod(unwritable, 0o755)
|
||||
})
|
||||
}
|
||||
|
||||
// --- Phase 0: Default quotas ---
|
||||
|
||||
func TestDefaultQuotas(t *testing.T) {
|
||||
func TestRatelimit_DefaultQuotas_Good(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
|
||||
tests := []struct {
|
||||
|
|
@ -740,7 +804,7 @@ func TestDefaultQuotas(t *testing.T) {
|
|||
|
||||
// --- Phase 0: Concurrent access (race test) ---
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
func TestRatelimit_ConcurrentAccess_Good(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "concurrent-test"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 1000, MaxTPM: 10000000, MaxRPD: 10000}
|
||||
|
|
@ -766,7 +830,7 @@ func TestConcurrentAccess(t *testing.T) {
|
|||
assert.Equal(t, expected, stats.RPD, "all recordings should be counted")
|
||||
}
|
||||
|
||||
func TestConcurrentResetAndRecord(t *testing.T) {
|
||||
func TestRatelimit_ConcurrentResetAndRecord_Ugly(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "concurrent-reset"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 10000, MaxTPM: 100000000, MaxRPD: 100000}
|
||||
|
|
@ -804,7 +868,7 @@ func TestConcurrentResetAndRecord(t *testing.T) {
|
|||
// No assertion needed -- if we get here without -race flagging, mutex is sound
|
||||
}
|
||||
|
||||
func TestBackgroundPrune(t *testing.T) {
|
||||
func TestRatelimit_BackgroundPrune_Good(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "prune-me"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 100}
|
||||
|
|
@ -843,7 +907,7 @@ func TestBackgroundPrune(t *testing.T) {
|
|||
|
||||
// --- Phase 0: CountTokens (with mock HTTP server) ---
|
||||
|
||||
func TestCountTokens(t *testing.T) {
|
||||
func TestRatelimit_CountTokens_Ugly(t *testing.T) {
|
||||
t.Run("successful token count", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
|
|
@ -858,13 +922,13 @@ func TestCountTokens(t *testing.T) {
|
|||
} `json:"parts"`
|
||||
} `json:"contents"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
|
||||
decodeJSONBody(t, r.Body, &body)
|
||||
require.Len(t, body.Contents, 1)
|
||||
require.Len(t, body.Contents[0].Parts, 1)
|
||||
assert.Equal(t, "hello", body.Contents[0].Parts[0].Text)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}))
|
||||
writeJSONBody(t, w, map[string]int{"totalTokens": 42})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -878,7 +942,7 @@ func TestCountTokens(t *testing.T) {
|
|||
assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath())
|
||||
assert.Empty(t, r.URL.RawQuery)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7}))
|
||||
writeJSONBody(t, w, map[string]int{"totalTokens": 7})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -888,10 +952,10 @@ func TestCountTokens(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("API error body is truncated", func(t *testing.T) {
|
||||
largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256)
|
||||
largeBody := repeatString("x", countTokensErrorBodyLimit+256)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := fmt.Fprint(w, largeBody)
|
||||
_, err := io.WriteString(w, largeBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
|
@ -899,7 +963,7 @@ func TestCountTokens(t *testing.T) {
|
|||
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "api error status 401")
|
||||
assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded")
|
||||
assert.True(t, substringCount(err.Error(), "x") < len(largeBody), "error body should be bounded")
|
||||
assert.Contains(t, err.Error(), "...")
|
||||
})
|
||||
|
||||
|
|
@ -953,7 +1017,7 @@ func TestCountTokens(t *testing.T) {
|
|||
t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11}))
|
||||
writeJSONBody(t, w, map[string]int{"totalTokens": 11})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
|
|
@ -969,8 +1033,8 @@ func TestCountTokens(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestPersistSkipsNilState(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "nil-state.yaml")
|
||||
func TestRatelimit_PersistSkipsNilState_Good(t *testing.T) {
|
||||
path := testPath(t.TempDir(), "nil-state.yaml")
|
||||
|
||||
rl, err := New()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -986,7 +1050,7 @@ func TestPersistSkipsNilState(t *testing.T) {
|
|||
assert.NotContains(t, rl2.State, "nil-model")
|
||||
}
|
||||
|
||||
func TestTokenTotals(t *testing.T) {
|
||||
func TestRatelimit_TokenTotals_Good(t *testing.T) {
|
||||
maxInt := int(^uint(0) >> 1)
|
||||
|
||||
assert.Equal(t, 25, safeTokenSum(-100, 25))
|
||||
|
|
@ -1056,7 +1120,7 @@ func BenchmarkCanSendConcurrent(b *testing.B) {
|
|||
|
||||
// --- Phase 1: Provider profiles and NewWithConfig ---
|
||||
|
||||
func TestDefaultProfiles(t *testing.T) {
|
||||
func TestRatelimit_DefaultProfiles_Good(t *testing.T) {
|
||||
profiles := DefaultProfiles()
|
||||
|
||||
t.Run("contains all four providers", func(t *testing.T) {
|
||||
|
|
@ -1097,10 +1161,10 @@ func TestDefaultProfiles(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestNewWithConfig(t *testing.T) {
|
||||
func TestRatelimit_NewWithConfig_Ugly(t *testing.T) {
|
||||
t.Run("empty config defaults to Gemini", func(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "test.yaml"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -1110,7 +1174,7 @@ func TestNewWithConfig(t *testing.T) {
|
|||
|
||||
t.Run("single provider loads only its models", func(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "test.yaml"),
|
||||
Providers: []Provider{ProviderOpenAI},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
@ -1124,7 +1188,7 @@ func TestNewWithConfig(t *testing.T) {
|
|||
|
||||
t.Run("multiple providers merge models", func(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "test.yaml"),
|
||||
Providers: []Provider{ProviderGemini, ProviderAnthropic},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
@ -1140,7 +1204,7 @@ func TestNewWithConfig(t *testing.T) {
|
|||
|
||||
t.Run("explicit quotas override provider defaults", func(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "test.yaml"),
|
||||
Providers: []Provider{ProviderGemini},
|
||||
Quotas: map[string]ModelQuota{
|
||||
"gemini-3-pro-preview": {MaxRPM: 999, MaxTPM: 888, MaxRPD: 777},
|
||||
|
|
@ -1156,7 +1220,7 @@ func TestNewWithConfig(t *testing.T) {
|
|||
|
||||
t.Run("explicit quotas without providers", func(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "test.yaml"),
|
||||
Quotas: map[string]ModelQuota{
|
||||
"my-custom-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
|
||||
},
|
||||
|
|
@ -1169,7 +1233,7 @@ func TestNewWithConfig(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("custom file path is respected", func(t *testing.T) {
|
||||
customPath := filepath.Join(t.TempDir(), "custom", "limits.yaml")
|
||||
customPath := testPath(t.TempDir(), "custom", "limits.yaml")
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: customPath,
|
||||
Providers: []Provider{ProviderLocal},
|
||||
|
|
@ -1179,13 +1243,12 @@ func TestNewWithConfig(t *testing.T) {
|
|||
rl.RecordUsage("test", 1, 1)
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
_, statErr := os.Stat(customPath)
|
||||
assert.NoError(t, statErr, "file should be created at custom path")
|
||||
assert.True(t, pathExists(customPath), "file should be created at custom path")
|
||||
})
|
||||
|
||||
t.Run("unknown provider is silently skipped", func(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "test.yaml"),
|
||||
Providers: []Provider{"nonexistent-provider"},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
@ -1194,7 +1257,7 @@ func TestNewWithConfig(t *testing.T) {
|
|||
|
||||
t.Run("local provider with custom quotas", func(t *testing.T) {
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "test.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "test.yaml"),
|
||||
Providers: []Provider{ProviderLocal},
|
||||
Quotas: map[string]ModelQuota{
|
||||
"llama-3.3-70b": {MaxRPM: 5, MaxTPM: 50000, MaxRPD: 0},
|
||||
|
|
@ -1224,11 +1287,11 @@ func TestNewWithConfig(t *testing.T) {
|
|||
|
||||
rl, err := NewWithConfig(Config{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
|
||||
assert.Equal(t, testPath(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewBackwardCompatibility(t *testing.T) {
|
||||
func TestRatelimit_NewBackwardCompatibility_Good(t *testing.T) {
|
||||
// New() should produce the exact same result as before Phase 1
|
||||
rl, err := New()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -1251,7 +1314,7 @@ func TestNewBackwardCompatibility(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSetQuota(t *testing.T) {
|
||||
func TestRatelimit_SetQuota_Good(t *testing.T) {
|
||||
t.Run("adds new model quota", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.SetQuota("custom-model", ModelQuota{MaxRPM: 42, MaxTPM: 9999, MaxRPD: 100})
|
||||
|
|
@ -1279,7 +1342,7 @@ func TestSetQuota(t *testing.T) {
|
|||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
model := fmt.Sprintf("model-%d", n)
|
||||
model := core.Sprintf("model-%d", n)
|
||||
rl.SetQuota(model, ModelQuota{MaxRPM: n, MaxTPM: n * 100, MaxRPD: n * 10})
|
||||
}(i)
|
||||
}
|
||||
|
|
@ -1289,7 +1352,7 @@ func TestSetQuota(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestAddProvider(t *testing.T) {
|
||||
func TestRatelimit_AddProvider_Good(t *testing.T) {
|
||||
t.Run("adds OpenAI models to existing limiter", func(t *testing.T) {
|
||||
rl := newTestLimiter(t) // starts with Gemini defaults
|
||||
geminiCount := len(rl.Quotas)
|
||||
|
|
@ -1351,7 +1414,7 @@ func TestAddProvider(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestProviderConstants(t *testing.T) {
|
||||
func TestRatelimit_ProviderConstants_Good(t *testing.T) {
|
||||
// Verify the string values are stable (they may be used in YAML configs)
|
||||
assert.Equal(t, Provider("gemini"), ProviderGemini)
|
||||
assert.Equal(t, Provider("openai"), ProviderOpenAI)
|
||||
|
|
@ -1361,7 +1424,7 @@ func TestProviderConstants(t *testing.T) {
|
|||
|
||||
// --- Phase 0 addendum: Additional concurrent and multi-model race tests ---
|
||||
|
||||
func TestConcurrentMultipleModels(t *testing.T) {
|
||||
func TestRatelimit_ConcurrentMultipleModels_Good(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
models := []string{"model-a", "model-b", "model-c", "model-d", "model-e"}
|
||||
for _, m := range models {
|
||||
|
|
@ -1391,9 +1454,9 @@ func TestConcurrentMultipleModels(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConcurrentPersistAndLoad(t *testing.T) {
|
||||
func TestRatelimit_ConcurrentPersistAndLoad_Ugly(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "concurrent.yaml")
|
||||
path := testPath(tmpDir, "concurrent.yaml")
|
||||
|
||||
rl := newTestLimiter(t)
|
||||
rl.filePath = path
|
||||
|
|
@ -1425,7 +1488,7 @@ func TestConcurrentPersistAndLoad(t *testing.T) {
|
|||
// No panics or data races = pass
|
||||
}
|
||||
|
||||
func TestConcurrentAllStatsAndRecordUsage(t *testing.T) {
|
||||
func TestRatelimit_ConcurrentAllStatsAndRecordUsage_Good(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
models := []string{"stats-a", "stats-b", "stats-c"}
|
||||
for _, m := range models {
|
||||
|
|
@ -1456,7 +1519,7 @@ func TestConcurrentAllStatsAndRecordUsage(t *testing.T) {
|
|||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestConcurrentWaitForCapacityAndRecordUsage(t *testing.T) {
|
||||
func TestRatelimit_ConcurrentWaitForCapacityAndRecordUsage_Good(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "race-wait"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 10000000, MaxRPD: 10000}
|
||||
|
|
@ -1553,7 +1616,7 @@ func BenchmarkAllStats(b *testing.B) {
|
|||
|
||||
func BenchmarkPersist(b *testing.B) {
|
||||
tmpDir := b.TempDir()
|
||||
path := filepath.Join(tmpDir, "bench.yaml")
|
||||
path := testPath(tmpDir, "bench.yaml")
|
||||
|
||||
rl, _ := New()
|
||||
rl.filePath = path
|
||||
|
|
@ -1574,10 +1637,10 @@ func BenchmarkPersist(b *testing.B) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestEndToEndMultiProvider(t *testing.T) {
|
||||
func TestRatelimit_EndToEndMultiProvider_Good(t *testing.T) {
|
||||
// Simulate a real-world scenario: limiter for both Gemini and Anthropic
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "multi.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "multi.yaml"),
|
||||
Providers: []Provider{ProviderGemini, ProviderAnthropic},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
|
|||
155
sqlite_test.go
155
sqlite_test.go
|
|
@ -1,8 +1,6 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -14,18 +12,17 @@ import (
|
|||
|
||||
// --- Phase 2: SQLite basic tests ---
|
||||
|
||||
func TestNewSQLiteStore_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||
func TestSQLite_NewSQLiteStore_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "test.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
||||
// Verify the database file was created.
|
||||
_, statErr := os.Stat(dbPath)
|
||||
assert.NoError(t, statErr, "database file should exist")
|
||||
assert.True(t, pathExists(dbPath), "database file should exist")
|
||||
}
|
||||
|
||||
func TestNewSQLiteStore_Bad(t *testing.T) {
|
||||
func TestSQLite_NewSQLiteStore_Bad(t *testing.T) {
|
||||
t.Run("invalid path returns error", func(t *testing.T) {
|
||||
// Path inside a non-existent directory with no parent.
|
||||
_, err := newSQLiteStore("/nonexistent/deep/nested/dir/test.db")
|
||||
|
|
@ -33,8 +30,8 @@ func TestNewSQLiteStore_Bad(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestSQLiteQuotasRoundTrip_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "quotas.db")
|
||||
func TestSQLite_QuotasRoundTrip_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "quotas.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -60,8 +57,8 @@ func TestSQLiteQuotasRoundTrip_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSQLiteQuotasUpsert_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "upsert.db")
|
||||
func TestSQLite_QuotasUpsert_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "upsert.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -85,8 +82,8 @@ func TestSQLiteQuotasUpsert_Good(t *testing.T) {
|
|||
assert.Equal(t, 777, q.MaxRPD, "should have updated RPD")
|
||||
}
|
||||
|
||||
func TestSQLiteStateRoundTrip_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "state.db")
|
||||
func TestSQLite_StateRoundTrip_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "state.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -144,8 +141,8 @@ func TestSQLiteStateRoundTrip_Good(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSQLiteStateOverwrite_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "overwrite.db")
|
||||
func TestSQLite_StateOverwrite_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "overwrite.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -182,8 +179,8 @@ func TestSQLiteStateOverwrite_Good(t *testing.T) {
|
|||
assert.Len(t, b.Requests, 1)
|
||||
}
|
||||
|
||||
func TestSQLiteEmptyState_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "empty.db")
|
||||
func TestSQLite_EmptyState_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "empty.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -198,8 +195,8 @@ func TestSQLiteEmptyState_Good(t *testing.T) {
|
|||
assert.Empty(t, state, "should return empty state from fresh DB")
|
||||
}
|
||||
|
||||
func TestSQLiteClose_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "close.db")
|
||||
func TestSQLite_Close_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "close.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -208,8 +205,8 @@ func TestSQLiteClose_Good(t *testing.T) {
|
|||
|
||||
// --- Phase 2: SQLite integration tests ---
|
||||
|
||||
func TestNewWithSQLite_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "limiter.db")
|
||||
func TestSQLite_NewWithSQLite_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "limiter.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
|
|
@ -222,8 +219,8 @@ func TestNewWithSQLite_Good(t *testing.T) {
|
|||
assert.NotNil(t, rl.sqlite, "SQLite store should be initialised")
|
||||
}
|
||||
|
||||
func TestNewWithSQLiteConfig_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "config.db")
|
||||
func TestSQLite_NewWithSQLiteConfig_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "config.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Providers: []Provider{ProviderAnthropic},
|
||||
Quotas: map[string]ModelQuota{
|
||||
|
|
@ -243,8 +240,8 @@ func TestNewWithSQLiteConfig_Good(t *testing.T) {
|
|||
assert.False(t, hasGemini, "should not have Gemini models")
|
||||
}
|
||||
|
||||
func TestSQLitePersistAndLoad_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "persist.db")
|
||||
func TestSQLite_PersistAndLoad_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "persist.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -272,8 +269,8 @@ func TestSQLitePersistAndLoad_Good(t *testing.T) {
|
|||
assert.Equal(t, 500, stats.MaxRPD)
|
||||
}
|
||||
|
||||
func TestSQLitePersistMultipleModels_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "multi.db")
|
||||
func TestSQLite_PersistMultipleModels_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "multi.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Providers: []Provider{ProviderGemini, ProviderAnthropic},
|
||||
})
|
||||
|
|
@ -302,8 +299,8 @@ func TestSQLitePersistMultipleModels_Good(t *testing.T) {
|
|||
assert.Equal(t, 400, claude.TPM)
|
||||
}
|
||||
|
||||
func TestSQLiteRecordUsageThenPersistReload_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "record.db")
|
||||
func TestSQLite_RecordUsageThenPersistReload_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "record.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -340,7 +337,7 @@ func TestSQLiteRecordUsageThenPersistReload_Good(t *testing.T) {
|
|||
assert.Equal(t, 1000, stats2.TPM, "TPM should survive reload")
|
||||
}
|
||||
|
||||
func TestSQLiteClose_Good_NoOp(t *testing.T) {
|
||||
func TestSQLite_CloseNoOp_Good(t *testing.T) {
|
||||
// Close on YAML-backed limiter is a no-op.
|
||||
rl := newTestLimiter(t)
|
||||
assert.NoError(t, rl.Close(), "Close on YAML limiter should be no-op")
|
||||
|
|
@ -348,8 +345,8 @@ func TestSQLiteClose_Good_NoOp(t *testing.T) {
|
|||
|
||||
// --- Phase 2: Concurrent SQLite ---
|
||||
|
||||
func TestSQLiteConcurrent_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "concurrent.db")
|
||||
func TestSQLite_Concurrent_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "concurrent.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
|
|
@ -398,10 +395,10 @@ func TestSQLiteConcurrent_Good(t *testing.T) {
|
|||
|
||||
// --- Phase 2: YAML backward compatibility ---
|
||||
|
||||
func TestYAMLBackwardCompat_Good(t *testing.T) {
|
||||
func TestSQLite_YAMLBackwardCompat_Good(t *testing.T) {
|
||||
// Verify that the default YAML backend still works after SQLite additions.
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "compat.yaml")
|
||||
path := testPath(tmpDir, "compat.yaml")
|
||||
|
||||
rl1, err := New()
|
||||
require.NoError(t, err)
|
||||
|
|
@ -425,18 +422,18 @@ func TestYAMLBackwardCompat_Good(t *testing.T) {
|
|||
assert.Equal(t, 200, stats.TPM)
|
||||
}
|
||||
|
||||
func TestConfigBackendDefault_Good(t *testing.T) {
|
||||
func TestSQLite_ConfigBackendDefault_Good(t *testing.T) {
|
||||
// Empty Backend string should default to YAML behaviour.
|
||||
rl, err := NewWithConfig(Config{
|
||||
FilePath: filepath.Join(t.TempDir(), "default.yaml"),
|
||||
FilePath: testPath(t.TempDir(), "default.yaml"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
|
||||
}
|
||||
|
||||
func TestConfigBackendSQLite_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "config-backend.db")
|
||||
func TestSQLite_ConfigBackendSQLite_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "config-backend.db")
|
||||
rl, err := NewWithConfig(Config{
|
||||
Backend: backendSQLite,
|
||||
FilePath: dbPath,
|
||||
|
|
@ -451,11 +448,10 @@ func TestConfigBackendSQLite_Good(t *testing.T) {
|
|||
rl.RecordUsage("backend-model", 10, 10)
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
_, statErr := os.Stat(dbPath)
|
||||
assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path")
|
||||
assert.True(t, pathExists(dbPath), "sqlite backend should persist to the configured DB path")
|
||||
}
|
||||
|
||||
func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
|
||||
func TestSQLite_ConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", "")
|
||||
|
|
@ -470,16 +466,15 @@ func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
|
|||
require.NotNil(t, rl.sqlite)
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
_, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile))
|
||||
assert.NoError(t, statErr, "sqlite backend should use the default home DB path")
|
||||
assert.True(t, pathExists(testPath(home, defaultStateDirName, defaultSQLiteStateFile)), "sqlite backend should use the default home DB path")
|
||||
}
|
||||
|
||||
// --- Phase 2: MigrateYAMLToSQLite ---
|
||||
|
||||
func TestMigrateYAMLToSQLite_Good(t *testing.T) {
|
||||
func TestSQLite_MigrateYAMLToSQLite_Good(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "state.yaml")
|
||||
sqlitePath := filepath.Join(tmpDir, "migrated.db")
|
||||
yamlPath := testPath(tmpDir, "state.yaml")
|
||||
sqlitePath := testPath(tmpDir, "migrated.db")
|
||||
|
||||
// Create a YAML-backed limiter with state.
|
||||
rl, err := New()
|
||||
|
|
@ -515,26 +510,26 @@ func TestMigrateYAMLToSQLite_Good(t *testing.T) {
|
|||
assert.Equal(t, 2, stats.RPD, "should have 2 daily requests")
|
||||
}
|
||||
|
||||
func TestMigrateYAMLToSQLite_Bad(t *testing.T) {
|
||||
func TestSQLite_MigrateYAMLToSQLite_Bad(t *testing.T) {
|
||||
t.Run("non-existent YAML file", func(t *testing.T) {
|
||||
err := MigrateYAMLToSQLite("/nonexistent/state.yaml", filepath.Join(t.TempDir(), "out.db"))
|
||||
err := MigrateYAMLToSQLite("/nonexistent/state.yaml", testPath(t.TempDir(), "out.db"))
|
||||
assert.Error(t, err, "should fail with non-existent YAML file")
|
||||
})
|
||||
|
||||
t.Run("corrupt YAML file", func(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "corrupt.yaml")
|
||||
require.NoError(t, os.WriteFile(yamlPath, []byte("{{{{not yaml!"), 0644))
|
||||
yamlPath := testPath(tmpDir, "corrupt.yaml")
|
||||
writeTestFile(t, yamlPath, "{{{{not yaml!")
|
||||
|
||||
err := MigrateYAMLToSQLite(yamlPath, filepath.Join(tmpDir, "out.db"))
|
||||
err := MigrateYAMLToSQLite(yamlPath, testPath(tmpDir, "out.db"))
|
||||
assert.Error(t, err, "should fail with corrupt YAML")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
|
||||
func TestSQLite_MigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "atomic.yaml")
|
||||
sqlitePath := filepath.Join(tmpDir, "atomic.db")
|
||||
yamlPath := testPath(tmpDir, "atomic.yaml")
|
||||
sqlitePath := testPath(tmpDir, "atomic.db")
|
||||
now := time.Now().UTC()
|
||||
|
||||
store, err := newSQLiteStore(sqlitePath)
|
||||
|
|
@ -573,7 +568,7 @@ func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
|
|||
}
|
||||
data, err := yaml.Marshal(migrated)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(yamlPath, data, 0o644))
|
||||
writeTestFile(t, yamlPath, string(data))
|
||||
|
||||
err = MigrateYAMLToSQLite(yamlPath, sqlitePath)
|
||||
require.Error(t, err)
|
||||
|
|
@ -594,10 +589,10 @@ func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
|
|||
assert.NotContains(t, state, "new-model")
|
||||
}
|
||||
|
||||
func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
|
||||
func TestSQLite_MigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "full.yaml")
|
||||
sqlitePath := filepath.Join(tmpDir, "full.db")
|
||||
yamlPath := testPath(tmpDir, "full.yaml")
|
||||
sqlitePath := testPath(tmpDir, "full.db")
|
||||
|
||||
// Create a full YAML state with all Gemini models.
|
||||
rl, err := New()
|
||||
|
|
@ -626,12 +621,12 @@ func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
|
|||
|
||||
// --- Phase 2: Corrupt DB recovery ---
|
||||
|
||||
func TestSQLiteCorruptDB_Ugly(t *testing.T) {
|
||||
func TestSQLite_CorruptDB_Ugly(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "corrupt.db")
|
||||
dbPath := testPath(tmpDir, "corrupt.db")
|
||||
|
||||
// Write garbage to the DB file.
|
||||
require.NoError(t, os.WriteFile(dbPath, []byte("THIS IS NOT A SQLITE DATABASE"), 0644))
|
||||
writeTestFile(t, dbPath, "THIS IS NOT A SQLITE DATABASE")
|
||||
|
||||
// Opening a corrupt DB may succeed (sqlite is lazy about validation),
|
||||
// but operations on it should fail gracefully.
|
||||
|
|
@ -648,9 +643,9 @@ func TestSQLiteCorruptDB_Ugly(t *testing.T) {
|
|||
assert.Error(t, err, "loading from corrupt DB should return an error")
|
||||
}
|
||||
|
||||
func TestSQLiteTruncatedDB_Ugly(t *testing.T) {
|
||||
func TestSQLite_TruncatedDB_Ugly(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
dbPath := filepath.Join(tmpDir, "truncated.db")
|
||||
dbPath := testPath(tmpDir, "truncated.db")
|
||||
|
||||
// Create a valid DB first.
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
|
|
@ -661,11 +656,7 @@ func TestSQLiteTruncatedDB_Ugly(t *testing.T) {
|
|||
require.NoError(t, store.close())
|
||||
|
||||
// Truncate the file to simulate corruption.
|
||||
f, err := os.OpenFile(dbPath, os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
require.NoError(t, err)
|
||||
_, err = f.Write([]byte("TRUNC"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, f.Close())
|
||||
overwriteTestFile(t, dbPath, "TRUNC")
|
||||
|
||||
// Opening should either fail or operations should fail.
|
||||
store2, err := newSQLiteStore(dbPath)
|
||||
|
|
@ -679,9 +670,9 @@ func TestSQLiteTruncatedDB_Ugly(t *testing.T) {
|
|||
assert.Error(t, err, "loading from truncated DB should return an error")
|
||||
}
|
||||
|
||||
func TestSQLiteEmptyModelState_Good(t *testing.T) {
|
||||
func TestSQLite_EmptyModelState_Good(t *testing.T) {
|
||||
// State with no requests or tokens but with a daily counter.
|
||||
dbPath := filepath.Join(t.TempDir(), "empty-state.db")
|
||||
dbPath := testPath(t.TempDir(), "empty-state.db")
|
||||
store, err := newSQLiteStore(dbPath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
|
@ -708,8 +699,8 @@ func TestSQLiteEmptyModelState_Good(t *testing.T) {
|
|||
|
||||
// --- Phase 2: End-to-end with persist cycle ---
|
||||
|
||||
func TestSQLiteEndToEnd_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "e2e.db")
|
||||
func TestSQLite_EndToEnd_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "e2e.db")
|
||||
|
||||
// Session 1: Create limiter, record usage, persist.
|
||||
rl1, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
|
|
@ -752,8 +743,8 @@ func TestSQLiteEndToEnd_Good(t *testing.T) {
|
|||
assert.Equal(t, 5, custom.MaxRPM)
|
||||
}
|
||||
|
||||
func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "replace.db")
|
||||
func TestSQLite_LoadReplacesPersistedSnapshot_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "replace.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
|
||||
|
|
@ -788,8 +779,8 @@ func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
|
|||
assert.Equal(t, 1, rl2.Stats("model-b").RPD)
|
||||
}
|
||||
|
||||
func TestSQLitePersistAtomic_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "persist-atomic.db")
|
||||
func TestSQLite_PersistAtomic_Good(t *testing.T) {
|
||||
dbPath := testPath(t.TempDir(), "persist-atomic.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
|
||||
|
|
@ -827,7 +818,7 @@ func TestSQLitePersistAtomic_Good(t *testing.T) {
|
|||
// --- Phase 2: Benchmark ---
|
||||
|
||||
func BenchmarkSQLitePersist(b *testing.B) {
|
||||
dbPath := filepath.Join(b.TempDir(), "bench.db")
|
||||
dbPath := testPath(b.TempDir(), "bench.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
|
@ -852,7 +843,7 @@ func BenchmarkSQLitePersist(b *testing.B) {
|
|||
}
|
||||
|
||||
func BenchmarkSQLiteLoad(b *testing.B) {
|
||||
dbPath := filepath.Join(b.TempDir(), "bench-load.db")
|
||||
dbPath := testPath(b.TempDir(), "bench-load.db")
|
||||
rl, err := NewWithSQLite(dbPath)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
|
@ -883,10 +874,10 @@ func BenchmarkSQLiteLoad(b *testing.B) {
|
|||
|
||||
// TestMigrateYAMLToSQLiteWithFullState tests migration of a realistic YAML
|
||||
// file that contains the full serialised RateLimiter struct.
|
||||
func TestMigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
|
||||
func TestSQLite_MigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "realistic.yaml")
|
||||
sqlitePath := filepath.Join(tmpDir, "realistic.db")
|
||||
yamlPath := testPath(tmpDir, "realistic.yaml")
|
||||
sqlitePath := testPath(tmpDir, "realistic.db")
|
||||
|
||||
now := time.Now()
|
||||
|
||||
|
|
@ -919,7 +910,7 @@ func TestMigrateYAMLToSQLiteWithFullState_Good(t *testing.T) {
|
|||
|
||||
data, err := yaml.Marshal(rl)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(yamlPath, data, 0644))
|
||||
writeTestFile(t, yamlPath, string(data))
|
||||
|
||||
// Migrate.
|
||||
require.NoError(t, MigrateYAMLToSQLite(yamlPath, sqlitePath))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue