[agent/codex] Fix ALL findings from issue #3. Read CLAUDE.md first. Token ... #4
8 changed files with 782 additions and 143 deletions
54
.github/workflows/ci.yml
vendored
Normal file
54
.github/workflows/ci.yml
vendored
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, dev]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
pull_request_review:
|
||||
types: [submitted]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
if: github.event_name != 'pull_request_review'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: dAppCore/build/actions/build/core@dev
|
||||
with:
|
||||
go-version: "1.26"
|
||||
run-vet: "true"
|
||||
|
||||
auto-fix:
|
||||
if: >
|
||||
github.event_name == 'pull_request_review' &&
|
||||
github.event.review.user.login == 'coderabbitai' &&
|
||||
github.event.review.state == 'changes_requested'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
fetch-depth: 0
|
||||
- uses: dAppCore/build/actions/fix@dev
|
||||
with:
|
||||
go-version: "1.26"
|
||||
|
||||
auto-merge:
|
||||
if: >
|
||||
github.event_name == 'pull_request_review' &&
|
||||
github.event.review.user.login == 'coderabbitai' &&
|
||||
github.event.review.state == 'approved'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Merge PR
|
||||
run: gh pr merge ${{ github.event.pull_request.number }} --merge --delete-branch
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
7
go.mod
7
go.mod
|
|
@ -3,13 +3,14 @@ module forge.lthn.ai/core/go-ratelimit
|
|||
go 1.26.0
|
||||
|
||||
require (
|
||||
forge.lthn.ai/core/go-io v0.1.5
|
||||
forge.lthn.ai/core/go-log v0.0.4
|
||||
dappco.re/go/core/io v0.2.0
|
||||
dappco.re/go/core/log v0.1.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
modernc.org/sqlite v1.46.2
|
||||
modernc.org/sqlite v1.47.0
|
||||
)
|
||||
|
||||
require (
|
||||
forge.lthn.ai/core/go-log v0.0.4 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
|
|
|
|||
10
go.sum
10
go.sum
|
|
@ -1,5 +1,7 @@
|
|||
forge.lthn.ai/core/go-io v0.1.5 h1:+XJ1YhaGGFLGtcNbPtVlndTjk+pO0Ydi2hRDj5/cHOM=
|
||||
forge.lthn.ai/core/go-io v0.1.5/go.mod h1:FRtXSsi8W+U9vewCU+LBAqqbIj3wjXA4dBdSv3SAtWI=
|
||||
dappco.re/go/core/io v0.2.0 h1:zuudgIiTsQQ5ipVt97saWdGLROovbEB/zdVyy9/l+I4=
|
||||
dappco.re/go/core/io v0.2.0/go.mod h1:1QnQV6X9LNgFKfm8SkOtR9LLaj3bDcsOIeJOOyjbL5E=
|
||||
dappco.re/go/core/log v0.1.0 h1:pa71Vq2TD2aoEUQWFKwNcaJ3GBY8HbaNGqtE688Unyc=
|
||||
dappco.re/go/core/log v0.1.0/go.mod h1:Nkqb8gsXhZAO8VLpx7B8i1iAmohhzqA20b9Zr8VUcJs=
|
||||
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
|
||||
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
|
|
@ -64,8 +66,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
|||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.46.2 h1:gkXQ6R0+AjxFC/fTDaeIVLbNLNrRoOK7YYVz5BKhTcE=
|
||||
modernc.org/sqlite v1.46.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
|
||||
modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
|
|
|
|||
39
iter_test.go
39
iter_test.go
|
|
@ -31,13 +31,19 @@ func TestIterators(t *testing.T) {
|
|||
assert.Contains(t, models, "model-a")
|
||||
assert.Contains(t, models, "model-b")
|
||||
assert.Contains(t, models, "model-c")
|
||||
|
||||
|
||||
// Check sorting of our specific models
|
||||
foundA, foundB, foundC := -1, -1, -1
|
||||
for i, m := range models {
|
||||
if m == "model-a" { foundA = i }
|
||||
if m == "model-b" { foundB = i }
|
||||
if m == "model-c" { foundC = i }
|
||||
if m == "model-a" {
|
||||
foundA = i
|
||||
}
|
||||
if m == "model-b" {
|
||||
foundB = i
|
||||
}
|
||||
if m == "model-c" {
|
||||
foundC = i
|
||||
}
|
||||
}
|
||||
assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c")
|
||||
})
|
||||
|
|
@ -57,9 +63,15 @@ func TestIterators(t *testing.T) {
|
|||
// Check sorting
|
||||
foundA, foundB, foundC := -1, -1, -1
|
||||
for i, m := range models {
|
||||
if m == "model-a" { foundA = i }
|
||||
if m == "model-b" { foundB = i }
|
||||
if m == "model-c" { foundC = i }
|
||||
if m == "model-a" {
|
||||
foundA = i
|
||||
}
|
||||
if m == "model-b" {
|
||||
foundB = i
|
||||
}
|
||||
if m == "model-c" {
|
||||
foundC = i
|
||||
}
|
||||
}
|
||||
assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c")
|
||||
})
|
||||
|
|
@ -99,9 +111,8 @@ func TestIterEarlyBreak(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCountTokensFull(t *testing.T) {
|
||||
t.Run("invalid URL/network error", func(t *testing.T) {
|
||||
// Using an invalid character in model name to trigger URL error or similar
|
||||
_, err := CountTokens(context.Background(), "key", "invalid model", "text")
|
||||
t.Run("empty model is rejected", func(t *testing.T) {
|
||||
_, err := CountTokens(context.Background(), "key", "", "text")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
|
|
@ -112,15 +123,15 @@ func TestCountTokensFull(t *testing.T) {
|
|||
}))
|
||||
defer server.Close()
|
||||
|
||||
// We can't easily override the URL in CountTokens without changing the code,
|
||||
// but we can test the logic if we make it slightly more testable.
|
||||
// For now, I've already updated ratelimit_test.go with some of this.
|
||||
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "key", "model", "text")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "status 400")
|
||||
})
|
||||
|
||||
t.Run("context cancelled", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
_, err := CountTokens(ctx, "key", "model", "text")
|
||||
_, err := countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", "key", "model", "text")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "do request")
|
||||
})
|
||||
|
|
|
|||
291
ratelimit.go
291
ratelimit.go
|
|
@ -9,14 +9,16 @@ import (
|
|||
"iter"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
coreio "forge.lthn.ai/core/go-io"
|
||||
coreio "dappco.re/go/core/io"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
|
|
@ -34,6 +36,16 @@ const (
|
|||
ProviderLocal Provider = "local"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultStateDirName = ".core"
|
||||
defaultYAMLStateFile = "ratelimits.yaml"
|
||||
defaultSQLiteStateFile = "ratelimits.db"
|
||||
backendYAML = "yaml"
|
||||
backendSQLite = "sqlite"
|
||||
countTokensErrorBodyLimit = 8 * 1024
|
||||
countTokensSuccessBodyLimit = 1 * 1024 * 1024
|
||||
)
|
||||
|
||||
// ModelQuota defines the rate limits for a specific model.
|
||||
type ModelQuota struct {
|
||||
MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
|
||||
|
|
@ -143,39 +155,30 @@ func New() (*RateLimiter, error) {
|
|||
// NewWithConfig creates a RateLimiter from explicit configuration.
|
||||
// If no providers or quotas are specified, Gemini defaults are used.
|
||||
func NewWithConfig(cfg Config) (*RateLimiter, error) {
|
||||
backend, err := normaliseBackend(cfg.Backend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filePath := cfg.FilePath
|
||||
if filePath == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
filePath, err = defaultStatePath(backend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filePath = filepath.Join(home, ".core", "ratelimits.yaml")
|
||||
}
|
||||
|
||||
rl := &RateLimiter{
|
||||
Quotas: make(map[string]ModelQuota),
|
||||
State: make(map[string]*UsageStats),
|
||||
filePath: filePath,
|
||||
}
|
||||
|
||||
// Load provider profiles
|
||||
profiles := DefaultProfiles()
|
||||
providers := cfg.Providers
|
||||
|
||||
// If nothing specified at all, default to Gemini
|
||||
if len(providers) == 0 && len(cfg.Quotas) == 0 {
|
||||
providers = []Provider{ProviderGemini}
|
||||
}
|
||||
|
||||
for _, p := range providers {
|
||||
if profile, ok := profiles[p]; ok {
|
||||
maps.Copy(rl.Quotas, profile.Models)
|
||||
if backend == backendSQLite {
|
||||
if cfg.FilePath == "" {
|
||||
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
||||
return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err)
|
||||
}
|
||||
}
|
||||
return NewWithSQLiteConfig(filePath, cfg)
|
||||
}
|
||||
|
||||
// Merge explicit quotas on top (allows overrides)
|
||||
maps.Copy(rl.Quotas, cfg.Quotas)
|
||||
|
||||
rl := newConfiguredRateLimiter(cfg)
|
||||
rl.filePath = filePath
|
||||
return rl, nil
|
||||
}
|
||||
|
||||
|
|
@ -225,15 +228,17 @@ func (rl *RateLimiter) loadSQLite() error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Merge loaded quotas (loaded quotas override in-memory defaults).
|
||||
maps.Copy(rl.Quotas, quotas)
|
||||
// Persisted quotas are authoritative when present; otherwise keep config defaults.
|
||||
if len(quotas) > 0 {
|
||||
rl.Quotas = maps.Clone(quotas)
|
||||
}
|
||||
|
||||
state, err := rl.sqlite.loadState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Replace in-memory state with persisted state.
|
||||
maps.Copy(rl.State, state)
|
||||
// Replace in-memory state with the persisted snapshot.
|
||||
rl.State = state
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -244,6 +249,9 @@ func (rl *RateLimiter) Persist() error {
|
|||
quotas := maps.Clone(rl.Quotas)
|
||||
state := make(map[string]*UsageStats, len(rl.State))
|
||||
for k, v := range rl.State {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
state[k] = &UsageStats{
|
||||
Requests: slices.Clone(v.Requests),
|
||||
Tokens: slices.Clone(v.Tokens),
|
||||
|
|
@ -256,11 +264,8 @@ func (rl *RateLimiter) Persist() error {
|
|||
rl.mu.Unlock()
|
||||
|
||||
if sqlite != nil {
|
||||
if err := sqlite.saveQuotas(quotas); err != nil {
|
||||
return coreerr.E("ratelimit.Persist", "sqlite quotas", err)
|
||||
}
|
||||
if err := sqlite.saveState(state); err != nil {
|
||||
return coreerr.E("ratelimit.Persist", "sqlite state", err)
|
||||
if err := sqlite.saveSnapshot(quotas, state); err != nil {
|
||||
return coreerr.E("ratelimit.Persist", "sqlite snapshot", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
@ -292,6 +297,10 @@ func (rl *RateLimiter) prune(model string) {
|
|||
if !ok {
|
||||
return
|
||||
}
|
||||
if stats == nil {
|
||||
delete(rl.State, model)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
window := now.Add(-1 * time.Minute)
|
||||
|
|
@ -326,6 +335,10 @@ func (rl *RateLimiter) prune(model string) {
|
|||
// BackgroundPrune starts a goroutine that periodically prunes all model states.
|
||||
// It returns a function to stop the pruner.
|
||||
func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
|
||||
if interval <= 0 {
|
||||
return func() {}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
|
|
@ -348,6 +361,10 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() {
|
|||
|
||||
// CanSend checks if a request can be sent without violating limits.
|
||||
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
||||
if estimatedTokens < 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
|
|
@ -380,11 +397,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
|||
|
||||
// Check TPM
|
||||
if quota.MaxTPM > 0 {
|
||||
currentTokens := 0
|
||||
for _, t := range stats.Tokens {
|
||||
currentTokens += t.Count
|
||||
}
|
||||
if currentTokens+estimatedTokens > quota.MaxTPM {
|
||||
currentTokens := totalTokenCount(stats.Tokens)
|
||||
if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
@ -405,13 +419,18 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
|
|||
}
|
||||
|
||||
now := time.Now()
|
||||
totalTokens := safeTokenSum(promptTokens, outputTokens)
|
||||
stats.Requests = append(stats.Requests, now)
|
||||
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens})
|
||||
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: totalTokens})
|
||||
stats.DayCount++
|
||||
}
|
||||
|
||||
// WaitForCapacity blocks until capacity is available or context is cancelled.
|
||||
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
|
||||
if tokens < 0 {
|
||||
return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil)
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
|
@ -522,24 +541,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
|||
result[m] = ModelStats{}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
window := now.Add(-1 * time.Minute)
|
||||
|
||||
for m := range result {
|
||||
// Prune inline
|
||||
if s, ok := rl.State[m]; ok {
|
||||
s.Requests = slices.DeleteFunc(s.Requests, func(t time.Time) bool {
|
||||
return !t.After(window)
|
||||
})
|
||||
s.Tokens = slices.DeleteFunc(s.Tokens, func(t TokenEntry) bool {
|
||||
return !t.Time.After(window)
|
||||
})
|
||||
|
||||
if now.Sub(s.DayStart) >= 24*time.Hour {
|
||||
s.DayStart = now
|
||||
s.DayCount = 0
|
||||
}
|
||||
}
|
||||
rl.prune(m)
|
||||
|
||||
ms := ModelStats{}
|
||||
if q, ok := rl.Quotas[m]; ok {
|
||||
|
|
@ -547,13 +550,11 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
|||
ms.MaxTPM = q.MaxTPM
|
||||
ms.MaxRPD = q.MaxRPD
|
||||
}
|
||||
if s, ok := rl.State[m]; ok {
|
||||
if s, ok := rl.State[m]; ok && s != nil {
|
||||
ms.RPM = len(s.Requests)
|
||||
ms.RPD = s.DayCount
|
||||
ms.DayStart = s.DayStart
|
||||
for _, t := range s.Tokens {
|
||||
ms.TPM += t.Count
|
||||
}
|
||||
ms.TPM = totalTokenCount(s.Tokens)
|
||||
}
|
||||
result[m] = ms
|
||||
}
|
||||
|
|
@ -579,25 +580,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
rl := &RateLimiter{
|
||||
Quotas: make(map[string]ModelQuota),
|
||||
State: make(map[string]*UsageStats),
|
||||
sqlite: store,
|
||||
}
|
||||
|
||||
// Load provider profiles.
|
||||
profiles := DefaultProfiles()
|
||||
providers := cfg.Providers
|
||||
if len(providers) == 0 && len(cfg.Quotas) == 0 {
|
||||
providers = []Provider{ProviderGemini}
|
||||
}
|
||||
for _, p := range providers {
|
||||
if profile, ok := profiles[p]; ok {
|
||||
maps.Copy(rl.Quotas, profile.Models)
|
||||
}
|
||||
}
|
||||
maps.Copy(rl.Quotas, cfg.Quotas)
|
||||
|
||||
rl := newConfiguredRateLimiter(cfg)
|
||||
rl.sqlite = store
|
||||
return rl, nil
|
||||
}
|
||||
|
||||
|
|
@ -633,22 +617,22 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
|
|||
}
|
||||
defer store.close()
|
||||
|
||||
if rl.Quotas != nil {
|
||||
if err := store.saveQuotas(rl.Quotas); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if rl.State != nil {
|
||||
if err := store.saveState(rl.State); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := store.saveSnapshot(rl.Quotas, rl.State); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountTokens calls the Google API to count tokens for a prompt.
|
||||
func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
|
||||
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model)
|
||||
return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text)
|
||||
}
|
||||
|
||||
func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) {
|
||||
requestURL, err := countTokensURL(baseURL, model)
|
||||
if err != nil {
|
||||
return 0, coreerr.E("ratelimit.CountTokens", "build url", err)
|
||||
}
|
||||
|
||||
reqBody := map[string]any{
|
||||
"contents": []any{
|
||||
|
|
@ -665,30 +649,147 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
|
|||
return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody))
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody))
|
||||
if err != nil {
|
||||
return 0, coreerr.E("ratelimit.CountTokens", "new request", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if client == nil {
|
||||
client = http.DefaultClient
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return 0, coreerr.E("ratelimit.CountTokens", "do request", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("API error (status %d): %s", resp.StatusCode, string(body)), nil)
|
||||
body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit)
|
||||
if err != nil {
|
||||
return 0, coreerr.E("ratelimit.CountTokens", "read error body", err)
|
||||
}
|
||||
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil {
|
||||
return 0, coreerr.E("ratelimit.CountTokens", "decode response", err)
|
||||
}
|
||||
|
||||
return result.TotalTokens, nil
|
||||
}
|
||||
|
||||
func newConfiguredRateLimiter(cfg Config) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
Quotas: make(map[string]ModelQuota),
|
||||
State: make(map[string]*UsageStats),
|
||||
}
|
||||
applyConfig(rl, cfg)
|
||||
return rl
|
||||
}
|
||||
|
||||
func applyConfig(rl *RateLimiter, cfg Config) {
|
||||
profiles := DefaultProfiles()
|
||||
providers := cfg.Providers
|
||||
|
||||
if len(providers) == 0 && len(cfg.Quotas) == 0 {
|
||||
providers = []Provider{ProviderGemini}
|
||||
}
|
||||
|
||||
for _, p := range providers {
|
||||
if profile, ok := profiles[p]; ok {
|
||||
maps.Copy(rl.Quotas, profile.Models)
|
||||
}
|
||||
}
|
||||
|
||||
maps.Copy(rl.Quotas, cfg.Quotas)
|
||||
}
|
||||
|
||||
func normaliseBackend(backend string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(backend)) {
|
||||
case "", backendYAML:
|
||||
return backendYAML, nil
|
||||
case backendSQLite:
|
||||
return backendSQLite, nil
|
||||
default:
|
||||
return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil)
|
||||
}
|
||||
}
|
||||
|
||||
func defaultStatePath(backend string) (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := defaultYAMLStateFile
|
||||
if backend == backendSQLite {
|
||||
fileName = defaultSQLiteStateFile
|
||||
}
|
||||
|
||||
return filepath.Join(home, defaultStateDirName, fileName), nil
|
||||
}
|
||||
|
||||
func safeTokenSum(a, b int) int {
|
||||
return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}})
|
||||
}
|
||||
|
||||
func totalTokenCount(tokens []TokenEntry) int {
|
||||
return safeTokenTotal(tokens)
|
||||
}
|
||||
|
||||
func safeTokenTotal(tokens []TokenEntry) int {
|
||||
const maxInt = int(^uint(0) >> 1)
|
||||
|
||||
total := 0
|
||||
for _, token := range tokens {
|
||||
count := token.Count
|
||||
if count < 0 {
|
||||
continue
|
||||
}
|
||||
if total > maxInt-count {
|
||||
return maxInt
|
||||
}
|
||||
total += count
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func countTokensURL(baseURL, model string) (string, error) {
|
||||
if strings.TrimSpace(model) == "" {
|
||||
return "", fmt.Errorf("empty model")
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed.Scheme == "" || parsed.Host == "" {
|
||||
return "", fmt.Errorf("invalid base url")
|
||||
}
|
||||
|
||||
return strings.TrimRight(parsed.String(), "/") + "/v1beta/models/" + url.PathEscape(model) + ":countTokens", nil
|
||||
}
|
||||
|
||||
func readLimitedBody(r io.Reader, limit int64) (string, error) {
|
||||
body, err := io.ReadAll(io.LimitReader(r, limit+1))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
truncated := int64(len(body)) > limit
|
||||
if truncated {
|
||||
body = body[:limit]
|
||||
}
|
||||
|
||||
result := string(body)
|
||||
if truncated {
|
||||
result += "..."
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -25,6 +27,18 @@ func newTestLimiter(t *testing.T) *RateLimiter {
|
|||
return rl
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
type errReader struct{}
|
||||
|
||||
func (errReader) Read([]byte) (int, error) {
|
||||
return 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
// --- Phase 0: CanSend boundary conditions ---
|
||||
|
||||
func TestCanSend(t *testing.T) {
|
||||
|
|
@ -162,6 +176,15 @@ func TestCanSend(t *testing.T) {
|
|||
rl.RecordUsage(model, 50, 50) // exactly 100 tokens
|
||||
assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit")
|
||||
})
|
||||
|
||||
t.Run("negative estimated tokens are rejected", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "negative-est"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 100}
|
||||
rl.RecordUsage(model, 50, 50)
|
||||
|
||||
assert.False(t, rl.CanSend(model, -1), "negative estimated tokens should not bypass TPM limits")
|
||||
})
|
||||
}
|
||||
|
||||
// --- Phase 0: Sliding window / prune tests ---
|
||||
|
|
@ -335,6 +358,19 @@ func TestRecordUsage(t *testing.T) {
|
|||
assert.Len(t, stats.Tokens, 2)
|
||||
assert.Equal(t, 6, stats.DayCount)
|
||||
})
|
||||
|
||||
t.Run("negative token inputs are clamped to zero", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "record-negative"
|
||||
|
||||
rl.RecordUsage(model, -100, 25)
|
||||
|
||||
stats := rl.State[model]
|
||||
require.NotNil(t, stats)
|
||||
assert.Len(t, stats.Requests, 1)
|
||||
assert.Len(t, stats.Tokens, 1)
|
||||
assert.Equal(t, 25, stats.Tokens[0].Count)
|
||||
})
|
||||
}
|
||||
|
||||
// --- Phase 0: Reset ---
|
||||
|
|
@ -422,6 +458,58 @@ func TestWaitForCapacity(t *testing.T) {
|
|||
err := rl.WaitForCapacity(ctx, model, 10)
|
||||
assert.Error(t, err, "should return error for already-cancelled context")
|
||||
})
|
||||
|
||||
t.Run("negative tokens return error", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
err := rl.WaitForCapacity(context.Background(), "wait-negative", -1)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "negative tokens")
|
||||
})
|
||||
}
|
||||
|
||||
func TestNilUsageStats(t *testing.T) {
|
||||
t.Run("CanSend replaces nil state without panicking", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "nil-cansend"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 10}
|
||||
rl.State[model] = nil
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
assert.True(t, rl.CanSend(model, 10))
|
||||
})
|
||||
assert.NotNil(t, rl.State[model])
|
||||
})
|
||||
|
||||
t.Run("RecordUsage replaces nil state without panicking", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
model := "nil-record"
|
||||
rl.State[model] = nil
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
rl.RecordUsage(model, 10, 10)
|
||||
})
|
||||
require.NotNil(t, rl.State[model])
|
||||
assert.Equal(t, 1, rl.State[model].DayCount)
|
||||
})
|
||||
|
||||
t.Run("Stats and AllStats tolerate nil state entries", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.Quotas["nil-stats"] = ModelQuota{MaxRPM: 1, MaxTPM: 2, MaxRPD: 3}
|
||||
rl.State["nil-stats"] = nil
|
||||
rl.State["nil-all-stats"] = nil
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
stats := rl.Stats("nil-stats")
|
||||
assert.Equal(t, 1, stats.MaxRPM)
|
||||
assert.Equal(t, 0, stats.TPM)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
all := rl.AllStats()
|
||||
assert.Contains(t, all, "nil-stats")
|
||||
assert.Contains(t, all, "nil-all-stats")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --- Phase 0: Stats ---
|
||||
|
|
@ -738,6 +826,19 @@ func TestBackgroundPrune(t *testing.T) {
|
|||
_, exists := rl.State[model]
|
||||
return !exists
|
||||
}, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned")
|
||||
|
||||
t.Run("non-positive interval is a safe no-op", func(t *testing.T) {
|
||||
rl := newTestLimiter(t)
|
||||
rl.State["still-here"] = &UsageStats{
|
||||
Requests: []time.Time{time.Now().Add(-2 * time.Minute)},
|
||||
}
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
stop := rl.BackgroundPrune(0)
|
||||
stop()
|
||||
})
|
||||
assert.Contains(t, rl.State, "still-here")
|
||||
})
|
||||
}
|
||||
|
||||
// --- Phase 0: CountTokens (with mock HTTP server) ---
|
||||
|
|
@ -748,26 +849,149 @@ func TestCountTokens(t *testing.T) {
|
|||
assert.Equal(t, http.MethodPost, r.Method)
|
||||
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key"))
|
||||
assert.Equal(t, "/v1beta/models/test-model:countTokens", r.URL.EscapedPath())
|
||||
|
||||
var body struct {
|
||||
Contents []struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"contents"`
|
||||
}
|
||||
require.NoError(t, json.NewDecoder(r.Body).Decode(&body))
|
||||
require.Len(t, body.Contents, 1)
|
||||
require.Len(t, body.Contents[0].Parts, 1)
|
||||
assert.Equal(t, "hello", body.Contents[0].Parts[0].Text)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42})
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
// For testing purposes, we would need to make the base URL configurable.
|
||||
// Since we're just checking the signature and basic logic, we test the error paths.
|
||||
tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "test-model", "hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 42, tokens)
|
||||
})
|
||||
|
||||
t.Run("API error returns error", func(t *testing.T) {
|
||||
t.Run("model name is path escaped", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath())
|
||||
assert.Empty(t, r.URL.RawQuery)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7}))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "folder/model?debug=1", "hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 7, tokens)
|
||||
})
|
||||
|
||||
t.Run("API error body is truncated", func(t *testing.T) {
|
||||
largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
fmt.Fprint(w, `{"error": "invalid API key"}`)
|
||||
_, err := fmt.Fprint(w, largeBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := CountTokens(context.Background(), "fake-key", "test-model", "hello")
|
||||
assert.Error(t, err, "should fail with invalid API endpoint")
|
||||
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "api error status 401")
|
||||
assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded")
|
||||
assert.Contains(t, err.Error(), "...")
|
||||
})
|
||||
|
||||
t.Run("empty model is rejected before request", func(t *testing.T) {
|
||||
_, err := CountTokens(context.Background(), "fake-key", "", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "build url")
|
||||
})
|
||||
|
||||
t.Run("invalid base URL returns error", func(t *testing.T) {
|
||||
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "://bad-url", "fake-key", "test-model", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "build url")
|
||||
})
|
||||
|
||||
t.Run("base URL without host returns error", func(t *testing.T) {
|
||||
_, err := countTokensWithClient(context.Background(), http.DefaultClient, "/relative", "fake-key", "test-model", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "build url")
|
||||
})
|
||||
|
||||
t.Run("invalid JSON response returns error", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err := w.Write([]byte(`{"totalTokens":`))
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decode response")
|
||||
})
|
||||
|
||||
t.Run("error body read failures are returned", func(t *testing.T) {
|
||||
client := &http.Client{
|
||||
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Body: io.NopCloser(errReader{}),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}),
|
||||
}
|
||||
|
||||
_, err := countTokensWithClient(context.Background(), client, "https://generativelanguage.googleapis.com", "fake-key", "test-model", "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "read error body")
|
||||
})
|
||||
|
||||
t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11}))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
originalClient := http.DefaultClient
|
||||
http.DefaultClient = server.Client()
|
||||
defer func() {
|
||||
http.DefaultClient = originalClient
|
||||
}()
|
||||
|
||||
tokens, err := countTokensWithClient(context.Background(), nil, server.URL, "fake-key", "test-model", "hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 11, tokens)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersistSkipsNilState(t *testing.T) {
|
||||
path := filepath.Join(t.TempDir(), "nil-state.yaml")
|
||||
|
||||
rl, err := New()
|
||||
require.NoError(t, err)
|
||||
rl.filePath = path
|
||||
rl.State["nil-model"] = nil
|
||||
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
rl2, err := New()
|
||||
require.NoError(t, err)
|
||||
rl2.filePath = path
|
||||
require.NoError(t, rl2.Load())
|
||||
assert.NotContains(t, rl2.State, "nil-model")
|
||||
}
|
||||
|
||||
func TestTokenTotals(t *testing.T) {
|
||||
maxInt := int(^uint(0) >> 1)
|
||||
|
||||
assert.Equal(t, 25, safeTokenSum(-100, 25))
|
||||
assert.Equal(t, maxInt, safeTokenSum(maxInt, 1))
|
||||
assert.Equal(t, 10, totalTokenCount([]TokenEntry{{Count: -5}, {Count: 10}}))
|
||||
}
|
||||
|
||||
// --- Phase 0: Benchmarks ---
|
||||
|
|
@ -983,6 +1207,25 @@ func TestNewWithConfig(t *testing.T) {
|
|||
assert.Equal(t, 5, q.MaxRPM)
|
||||
assert.Equal(t, 50000, q.MaxTPM)
|
||||
})
|
||||
|
||||
t.Run("invalid backend returns error", func(t *testing.T) {
|
||||
_, err := NewWithConfig(Config{
|
||||
Backend: "bogus",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unknown backend")
|
||||
})
|
||||
|
||||
t.Run("default YAML path uses home directory", func(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", "")
|
||||
t.Setenv("home", "")
|
||||
|
||||
rl, err := NewWithConfig(Config{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewBackwardCompatibility(t *testing.T) {
|
||||
|
|
|
|||
93
sqlite.go
93
sqlite.go
|
|
@ -5,7 +5,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
coreerr "forge.lthn.ai/core/go-log"
|
||||
coreerr "dappco.re/go/core/log"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
|
|
@ -78,7 +78,7 @@ func createSchema(db *sql.DB) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// saveQuotas upserts all quotas into the quotas table.
|
||||
// saveQuotas writes a complete quota snapshot to the quotas table.
|
||||
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
|
|
@ -86,24 +86,15 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
stmt, err := tx.Prepare(`INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(model) DO UPDATE SET
|
||||
max_rpm = excluded.max_rpm,
|
||||
max_tpm = excluded.max_tpm,
|
||||
max_rpd = excluded.max_rpd`)
|
||||
if err != nil {
|
||||
return coreerr.E("ratelimit.saveQuotas", "prepare", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for model, q := range quotas {
|
||||
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
|
||||
return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
|
||||
}
|
||||
if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
|
||||
return coreerr.E("ratelimit.saveQuotas", "clear", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
if err := insertQuotas(tx, quotas); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return commitTx(tx, "ratelimit.saveQuotas")
|
||||
}
|
||||
|
||||
// loadQuotas reads all rows from the quotas table.
|
||||
|
|
@ -129,6 +120,28 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// saveSnapshot writes quotas and state as a single atomic snapshot.
|
||||
func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error {
|
||||
tx, err := s.db.Begin()
|
||||
if err != nil {
|
||||
return coreerr.E("ratelimit.saveSnapshot", "begin", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if err := clearSnapshotTables(tx, true); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := insertQuotas(tx, quotas); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := insertState(tx, state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return commitTx(tx, "ratelimit.saveSnapshot")
|
||||
}
|
||||
|
||||
// saveState writes all usage state to SQLite in a single transaction.
|
||||
// It uses a truncate-and-insert approach for simplicity in this version,
|
||||
// but ensures atomicity via a single transaction.
|
||||
|
|
@ -139,7 +152,23 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
|||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Clear existing state in the transaction.
|
||||
if err := clearSnapshotTables(tx, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := insertState(tx, state); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return commitTx(tx, "ratelimit.saveState")
|
||||
}
|
||||
|
||||
func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
|
||||
if includeQuotas {
|
||||
if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
|
||||
return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err)
|
||||
}
|
||||
}
|
||||
if _, err := tx.Exec("DELETE FROM requests"); err != nil {
|
||||
return coreerr.E("ratelimit.saveState", "clear requests", err)
|
||||
}
|
||||
|
|
@ -149,7 +178,25 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
|||
if _, err := tx.Exec("DELETE FROM daily"); err != nil {
|
||||
return coreerr.E("ratelimit.saveState", "clear daily", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
|
||||
stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
return coreerr.E("ratelimit.saveQuotas", "prepare", err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for model, q := range quotas {
|
||||
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
|
||||
return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
|
||||
reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
|
||||
if err != nil {
|
||||
return coreerr.E("ratelimit.saveState", "prepare requests", err)
|
||||
|
|
@ -169,6 +216,9 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
|||
defer dayStmt.Close()
|
||||
|
||||
for model, stats := range state {
|
||||
if stats == nil {
|
||||
continue
|
||||
}
|
||||
for _, t := range stats.Requests {
|
||||
if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
|
||||
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err)
|
||||
|
|
@ -183,9 +233,12 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
|||
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func commitTx(tx *sql.Tx, scope string) error {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return coreerr.E("ratelimit.saveState", "commit", err)
|
||||
return coreerr.E(scope, "commit", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
174
sqlite_test.go
174
sqlite_test.go
|
|
@ -435,6 +435,45 @@ func TestConfigBackendDefault_Good(t *testing.T) {
|
|||
assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)")
|
||||
}
|
||||
|
||||
func TestConfigBackendSQLite_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "config-backend.db")
|
||||
rl, err := NewWithConfig(Config{
|
||||
Backend: backendSQLite,
|
||||
FilePath: dbPath,
|
||||
Quotas: map[string]ModelQuota{
|
||||
"backend-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
|
||||
require.NotNil(t, rl.sqlite)
|
||||
rl.RecordUsage("backend-model", 10, 10)
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
_, statErr := os.Stat(dbPath)
|
||||
assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path")
|
||||
}
|
||||
|
||||
func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) {
|
||||
home := t.TempDir()
|
||||
t.Setenv("HOME", home)
|
||||
t.Setenv("USERPROFILE", "")
|
||||
t.Setenv("home", "")
|
||||
|
||||
rl, err := NewWithConfig(Config{
|
||||
Backend: backendSQLite,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer rl.Close()
|
||||
|
||||
require.NotNil(t, rl.sqlite)
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
_, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile))
|
||||
assert.NoError(t, statErr, "sqlite backend should use the default home DB path")
|
||||
}
|
||||
|
||||
// --- Phase 2: MigrateYAMLToSQLite ---
|
||||
|
||||
func TestMigrateYAMLToSQLite_Good(t *testing.T) {
|
||||
|
|
@ -492,6 +531,69 @@ func TestMigrateYAMLToSQLite_Bad(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "atomic.yaml")
|
||||
sqlitePath := filepath.Join(tmpDir, "atomic.db")
|
||||
now := time.Now().UTC()
|
||||
|
||||
store, err := newSQLiteStore(sqlitePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
originalQuotas := map[string]ModelQuota{
|
||||
"old-model": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3},
|
||||
}
|
||||
originalState := map[string]*UsageStats{
|
||||
"old-model": {
|
||||
Requests: []time.Time{now},
|
||||
Tokens: []TokenEntry{{Time: now, Count: 9}},
|
||||
DayStart: now,
|
||||
DayCount: 1,
|
||||
},
|
||||
}
|
||||
require.NoError(t, store.saveSnapshot(originalQuotas, originalState))
|
||||
|
||||
_, err = store.db.Exec(`CREATE TRIGGER fail_daily_migrate BEFORE INSERT ON daily
|
||||
BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, store.close())
|
||||
|
||||
migrated := &RateLimiter{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"new-model": {MaxRPM: 10, MaxTPM: 20, MaxRPD: 30},
|
||||
},
|
||||
State: map[string]*UsageStats{
|
||||
"new-model": {
|
||||
Requests: []time.Time{now.Add(5 * time.Second)},
|
||||
Tokens: []TokenEntry{{Time: now.Add(5 * time.Second), Count: 99}},
|
||||
DayStart: now.Add(5 * time.Second),
|
||||
DayCount: 2,
|
||||
},
|
||||
},
|
||||
}
|
||||
data, err := yaml.Marshal(migrated)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(yamlPath, data, 0o644))
|
||||
|
||||
err = MigrateYAMLToSQLite(yamlPath, sqlitePath)
|
||||
require.Error(t, err)
|
||||
|
||||
store, err = newSQLiteStore(sqlitePath)
|
||||
require.NoError(t, err)
|
||||
defer store.close()
|
||||
|
||||
quotas, err := store.loadQuotas()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalQuotas, quotas)
|
||||
|
||||
state, err := store.loadState()
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, state, "old-model")
|
||||
assert.Equal(t, originalState["old-model"].DayCount, state["old-model"].DayCount)
|
||||
assert.Equal(t, originalState["old-model"].Tokens[0].Count, state["old-model"].Tokens[0].Count)
|
||||
assert.NotContains(t, state, "new-model")
|
||||
}
|
||||
|
||||
func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
yamlPath := filepath.Join(tmpDir, "full.yaml")
|
||||
|
|
@ -650,6 +752,78 @@ func TestSQLiteEndToEnd_Good(t *testing.T) {
|
|||
assert.Equal(t, 5, custom.MaxRPM)
|
||||
}
|
||||
|
||||
func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "replace.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rl.RecordUsage("model-a", 10, 10)
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
delete(rl.Quotas, "model-a")
|
||||
rl.Quotas["model-b"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
|
||||
rl.Reset("")
|
||||
rl.RecordUsage("model-b", 5, 5)
|
||||
require.NoError(t, rl.Persist())
|
||||
require.NoError(t, rl.Close())
|
||||
|
||||
rl2, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Providers: []Provider{ProviderGemini},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer rl2.Close()
|
||||
|
||||
rl2.State["stale-memory"] = &UsageStats{DayStart: time.Now(), DayCount: 99}
|
||||
require.NoError(t, rl2.Load())
|
||||
|
||||
assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview")
|
||||
assert.NotContains(t, rl2.Quotas, "model-a")
|
||||
assert.Contains(t, rl2.Quotas, "model-b")
|
||||
assert.NotContains(t, rl2.State, "stale-memory")
|
||||
assert.NotContains(t, rl2.State, "model-a")
|
||||
assert.Equal(t, 1, rl2.Stats("model-b").RPD)
|
||||
}
|
||||
|
||||
func TestSQLitePersistAtomic_Good(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "persist-atomic.db")
|
||||
rl, err := NewWithSQLiteConfig(dbPath, Config{
|
||||
Quotas: map[string]ModelQuota{
|
||||
"old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
rl.RecordUsage("old-model", 10, 10)
|
||||
require.NoError(t, rl.Persist())
|
||||
|
||||
_, err = rl.sqlite.db.Exec(`CREATE TRIGGER fail_daily_persist BEFORE INSERT ON daily
|
||||
BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`)
|
||||
require.NoError(t, err)
|
||||
|
||||
delete(rl.Quotas, "old-model")
|
||||
rl.Quotas["new-model"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20}
|
||||
rl.Reset("")
|
||||
rl.RecordUsage("new-model", 50, 50)
|
||||
|
||||
err = rl.Persist()
|
||||
require.Error(t, err)
|
||||
require.NoError(t, rl.Close())
|
||||
|
||||
rl2, err := NewWithSQLiteConfig(dbPath, Config{})
|
||||
require.NoError(t, err)
|
||||
defer rl2.Close()
|
||||
require.NoError(t, rl2.Load())
|
||||
|
||||
assert.Contains(t, rl2.Quotas, "old-model")
|
||||
assert.NotContains(t, rl2.Quotas, "new-model")
|
||||
assert.Equal(t, 1, rl2.Stats("old-model").RPD)
|
||||
assert.Equal(t, 0, rl2.Stats("new-model").RPD)
|
||||
}
|
||||
|
||||
// --- Phase 2: Benchmark ---
|
||||
|
||||
func BenchmarkSQLitePersist(b *testing.B) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue