diff --git a/iter_test.go b/iter_test.go index aaa59be..f353e16 100644 --- a/iter_test.go +++ b/iter_test.go @@ -31,13 +31,19 @@ func TestIterators(t *testing.T) { assert.Contains(t, models, "model-a") assert.Contains(t, models, "model-b") assert.Contains(t, models, "model-c") - + // Check sorting of our specific models foundA, foundB, foundC := -1, -1, -1 for i, m := range models { - if m == "model-a" { foundA = i } - if m == "model-b" { foundB = i } - if m == "model-c" { foundC = i } + if m == "model-a" { + foundA = i + } + if m == "model-b" { + foundB = i + } + if m == "model-c" { + foundC = i + } } assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c") }) @@ -57,9 +63,15 @@ func TestIterators(t *testing.T) { // Check sorting foundA, foundB, foundC := -1, -1, -1 for i, m := range models { - if m == "model-a" { foundA = i } - if m == "model-b" { foundB = i } - if m == "model-c" { foundC = i } + if m == "model-a" { + foundA = i + } + if m == "model-b" { + foundB = i + } + if m == "model-c" { + foundC = i + } } assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c") }) @@ -99,9 +111,8 @@ func TestIterEarlyBreak(t *testing.T) { } func TestCountTokensFull(t *testing.T) { - t.Run("invalid URL/network error", func(t *testing.T) { - // Using an invalid character in model name to trigger URL error or similar - _, err := CountTokens(context.Background(), "key", "invalid model", "text") + t.Run("empty model is rejected", func(t *testing.T) { + _, err := CountTokens(context.Background(), "key", "", "text") assert.Error(t, err) }) @@ -112,15 +123,15 @@ func TestCountTokensFull(t *testing.T) { })) defer server.Close() - // We can't easily override the URL in CountTokens without changing the code, - // but we can test the logic if we make it slightly more testable. - // For now, I've already updated ratelimit_test.go with some of this. + _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "key", "model", "text") + require.Error(t, err) + assert.Contains(t, err.Error(), "status 400") }) t.Run("context cancelled", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := CountTokens(ctx, "key", "model", "text") + _, err := countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", "key", "model", "text") assert.Error(t, err) assert.Contains(t, err.Error(), "do request") }) diff --git a/ratelimit.go b/ratelimit.go index d2d0007..0e34709 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -9,9 +9,11 @@ import ( "iter" "maps" "net/http" + "net/url" "os" "path/filepath" "slices" + "strings" "sync" "time" @@ -34,6 +36,16 @@ const ( ProviderLocal Provider = "local" ) +const ( + defaultStateDirName = ".core" + defaultYAMLStateFile = "ratelimits.yaml" + defaultSQLiteStateFile = "ratelimits.db" + backendYAML = "yaml" + backendSQLite = "sqlite" + countTokensErrorBodyLimit = 8 * 1024 + countTokensSuccessBodyLimit = 1 * 1024 * 1024 +) + // ModelQuota defines the rate limits for a specific model. type ModelQuota struct { MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited) @@ -143,39 +155,30 @@ func New() (*RateLimiter, error) { // NewWithConfig creates a RateLimiter from explicit configuration. // If no providers or quotas are specified, Gemini defaults are used. func NewWithConfig(cfg Config) (*RateLimiter, error) { + backend, err := normaliseBackend(cfg.Backend) + if err != nil { + return nil, err + } + filePath := cfg.FilePath if filePath == "" { - home, err := os.UserHomeDir() + filePath, err = defaultStatePath(backend) if err != nil { return nil, err } - filePath = filepath.Join(home, ".core", "ratelimits.yaml") } - rl := &RateLimiter{ - Quotas: make(map[string]ModelQuota), - State: make(map[string]*UsageStats), - filePath: filePath, - } - - // Load provider profiles - profiles := DefaultProfiles() - providers := cfg.Providers - - // If nothing specified at all, default to Gemini - if len(providers) == 0 && len(cfg.Quotas) == 0 { - providers = []Provider{ProviderGemini} - } - - for _, p := range providers { - if profile, ok := profiles[p]; ok { - maps.Copy(rl.Quotas, profile.Models) + if backend == backendSQLite { + if cfg.FilePath == "" { + if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err) + } } + return NewWithSQLiteConfig(filePath, cfg) } - // Merge explicit quotas on top (allows overrides) - maps.Copy(rl.Quotas, cfg.Quotas) - + rl := newConfiguredRateLimiter(cfg) + rl.filePath = filePath return rl, nil } @@ -225,15 +228,17 @@ func (rl *RateLimiter) loadSQLite() error { if err != nil { return err } - // Merge loaded quotas (loaded quotas override in-memory defaults). - maps.Copy(rl.Quotas, quotas) + // Persisted quotas are authoritative when present; otherwise keep config defaults. + if len(quotas) > 0 { + rl.Quotas = maps.Clone(quotas) + } state, err := rl.sqlite.loadState() if err != nil { return err } - // Replace in-memory state with persisted state. - maps.Copy(rl.State, state) + // Replace in-memory state with the persisted snapshot. + rl.State = state return nil } @@ -244,6 +249,9 @@ func (rl *RateLimiter) Persist() error { quotas := maps.Clone(rl.Quotas) state := make(map[string]*UsageStats, len(rl.State)) for k, v := range rl.State { + if v == nil { + continue + } state[k] = &UsageStats{ Requests: slices.Clone(v.Requests), Tokens: slices.Clone(v.Tokens), @@ -256,11 +264,8 @@ func (rl *RateLimiter) Persist() error { rl.mu.Unlock() if sqlite != nil { - if err := sqlite.saveQuotas(quotas); err != nil { - return coreerr.E("ratelimit.Persist", "sqlite quotas", err) - } - if err := sqlite.saveState(state); err != nil { - return coreerr.E("ratelimit.Persist", "sqlite state", err) + if err := sqlite.saveSnapshot(quotas, state); err != nil { + return coreerr.E("ratelimit.Persist", "sqlite snapshot", err) } return nil } @@ -292,6 +297,10 @@ func (rl *RateLimiter) prune(model string) { if !ok { return } + if stats == nil { + delete(rl.State, model) + return + } now := time.Now() window := now.Add(-1 * time.Minute) @@ -326,6 +335,10 @@ func (rl *RateLimiter) prune(model string) { // BackgroundPrune starts a goroutine that periodically prunes all model states. // It returns a function to stop the pruner. func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { + if interval <= 0 { + return func() {} + } + ctx, cancel := context.WithCancel(context.Background()) go func() { ticker := time.NewTicker(interval) @@ -348,6 +361,10 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { // CanSend checks if a request can be sent without violating limits. func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { + if estimatedTokens < 0 { + return false + } + rl.mu.Lock() defer rl.mu.Unlock() @@ -380,11 +397,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { // Check TPM if quota.MaxTPM > 0 { - currentTokens := 0 - for _, t := range stats.Tokens { - currentTokens += t.Count - } - if currentTokens+estimatedTokens > quota.MaxTPM { + currentTokens := totalTokenCount(stats.Tokens) + if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens { return false } } @@ -405,13 +419,18 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) } now := time.Now() + totalTokens := safeTokenSum(promptTokens, outputTokens) stats.Requests = append(stats.Requests, now) - stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens}) + stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: totalTokens}) stats.DayCount++ } // WaitForCapacity blocks until capacity is available or context is cancelled. func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error { + if tokens < 0 { + return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil) + } + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() @@ -522,24 +541,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats { result[m] = ModelStats{} } - now := time.Now() - window := now.Add(-1 * time.Minute) - for m := range result { - // Prune inline - if s, ok := rl.State[m]; ok { - s.Requests = slices.DeleteFunc(s.Requests, func(t time.Time) bool { - return !t.After(window) - }) - s.Tokens = slices.DeleteFunc(s.Tokens, func(t TokenEntry) bool { - return !t.Time.After(window) - }) - - if now.Sub(s.DayStart) >= 24*time.Hour { - s.DayStart = now - s.DayCount = 0 - } - } + rl.prune(m) ms := ModelStats{} if q, ok := rl.Quotas[m]; ok { @@ -547,13 +550,11 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats { ms.MaxTPM = q.MaxTPM ms.MaxRPD = q.MaxRPD } - if s, ok := rl.State[m]; ok { + if s, ok := rl.State[m]; ok && s != nil { ms.RPM = len(s.Requests) ms.RPD = s.DayCount ms.DayStart = s.DayStart - for _, t := range s.Tokens { - ms.TPM += t.Count - } + ms.TPM = totalTokenCount(s.Tokens) } result[m] = ms } @@ -579,25 +580,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) { return nil, err } - rl := &RateLimiter{ - Quotas: make(map[string]ModelQuota), - State: make(map[string]*UsageStats), - sqlite: store, - } - - // Load provider profiles. - profiles := DefaultProfiles() - providers := cfg.Providers - if len(providers) == 0 && len(cfg.Quotas) == 0 { - providers = []Provider{ProviderGemini} - } - for _, p := range providers { - if profile, ok := profiles[p]; ok { - maps.Copy(rl.Quotas, profile.Models) - } - } - maps.Copy(rl.Quotas, cfg.Quotas) - + rl := newConfiguredRateLimiter(cfg) + rl.sqlite = store return rl, nil } @@ -633,22 +617,22 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error { } defer store.close() - if rl.Quotas != nil { - if err := store.saveQuotas(rl.Quotas); err != nil { - return err - } - } - if rl.State != nil { - if err := store.saveState(rl.State); err != nil { - return err - } + if err := store.saveSnapshot(rl.Quotas, rl.State); err != nil { + return err } return nil } // CountTokens calls the Google API to count tokens for a prompt. func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) { - url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model) + return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text) +} + +func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) { + requestURL, err := countTokensURL(baseURL, model) + if err != nil { + return 0, coreerr.E("ratelimit.CountTokens", "build url", err) + } reqBody := map[string]any{ "contents": []any{ @@ -665,30 +649,147 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) { return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody)) if err != nil { return 0, coreerr.E("ratelimit.CountTokens", "new request", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("x-goog-api-key", apiKey) - resp, err := http.DefaultClient.Do(req) + if client == nil { + client = http.DefaultClient + } + + resp, err := client.Do(req) if err != nil { return 0, coreerr.E("ratelimit.CountTokens", "do request", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("API error (status %d): %s", resp.StatusCode, string(body)), nil) + body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit) + if err != nil { + return 0, coreerr.E("ratelimit.CountTokens", "read error body", err) + } + return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil) } var result struct { TotalTokens int `json:"totalTokens"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil { return 0, coreerr.E("ratelimit.CountTokens", "decode response", err) } return result.TotalTokens, nil } + +func newConfiguredRateLimiter(cfg Config) *RateLimiter { + rl := &RateLimiter{ + Quotas: make(map[string]ModelQuota), + State: make(map[string]*UsageStats), + } + applyConfig(rl, cfg) + return rl +} + +func applyConfig(rl *RateLimiter, cfg Config) { + profiles := DefaultProfiles() + providers := cfg.Providers + + if len(providers) == 0 && len(cfg.Quotas) == 0 { + providers = []Provider{ProviderGemini} + } + + for _, p := range providers { + if profile, ok := profiles[p]; ok { + maps.Copy(rl.Quotas, profile.Models) + } + } + + maps.Copy(rl.Quotas, cfg.Quotas) +} + +func normaliseBackend(backend string) (string, error) { + switch strings.ToLower(strings.TrimSpace(backend)) { + case "", backendYAML: + return backendYAML, nil + case backendSQLite: + return backendSQLite, nil + default: + return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil) + } +} + +func defaultStatePath(backend string) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + fileName := defaultYAMLStateFile + if backend == backendSQLite { + fileName = defaultSQLiteStateFile + } + + return filepath.Join(home, defaultStateDirName, fileName), nil +} + +func safeTokenSum(a, b int) int { + return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}}) +} + +func totalTokenCount(tokens []TokenEntry) int { + return safeTokenTotal(tokens) +} + +func safeTokenTotal(tokens []TokenEntry) int { + const maxInt = int(^uint(0) >> 1) + + total := 0 + for _, token := range tokens { + count := token.Count + if count < 0 { + continue + } + if total > maxInt-count { + return maxInt + } + total += count + } + return total +} + +func countTokensURL(baseURL, model string) (string, error) { + if strings.TrimSpace(model) == "" { + return "", fmt.Errorf("empty model") + } + + parsed, err := url.Parse(baseURL) + if err != nil { + return "", err + } + if parsed.Scheme == "" || parsed.Host == "" { + return "", fmt.Errorf("invalid base url") + } + + return strings.TrimRight(parsed.String(), "/") + "/v1beta/models/" + url.PathEscape(model) + ":countTokens", nil +} + +func readLimitedBody(r io.Reader, limit int64) (string, error) { + body, err := io.ReadAll(io.LimitReader(r, limit+1)) + if err != nil { + return "", err + } + + truncated := int64(len(body)) > limit + if truncated { + body = body[:limit] + } + + result := string(body) + if truncated { + result += "..." + } + return result, nil +} diff --git a/ratelimit_test.go b/ratelimit_test.go index 838c484..73aac70 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -4,10 +4,12 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -25,6 +27,18 @@ func newTestLimiter(t *testing.T) *RateLimiter { return rl } +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + // --- Phase 0: CanSend boundary conditions --- func TestCanSend(t *testing.T) { @@ -162,6 +176,15 @@ func TestCanSend(t *testing.T) { rl.RecordUsage(model, 50, 50) // exactly 100 tokens assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit") }) + + t.Run("negative estimated tokens are rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "negative-est" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 100} + rl.RecordUsage(model, 50, 50) + + assert.False(t, rl.CanSend(model, -1), "negative estimated tokens should not bypass TPM limits") + }) } // --- Phase 0: Sliding window / prune tests --- @@ -335,6 +358,19 @@ func TestRecordUsage(t *testing.T) { assert.Len(t, stats.Tokens, 2) assert.Equal(t, 6, stats.DayCount) }) + + t.Run("negative token inputs are clamped to zero", func(t *testing.T) { + rl := newTestLimiter(t) + model := "record-negative" + + rl.RecordUsage(model, -100, 25) + + stats := rl.State[model] + require.NotNil(t, stats) + assert.Len(t, stats.Requests, 1) + assert.Len(t, stats.Tokens, 1) + assert.Equal(t, 25, stats.Tokens[0].Count) + }) } // --- Phase 0: Reset --- @@ -422,6 +458,58 @@ func TestWaitForCapacity(t *testing.T) { err := rl.WaitForCapacity(ctx, model, 10) assert.Error(t, err, "should return error for already-cancelled context") }) + + t.Run("negative tokens return error", func(t *testing.T) { + rl := newTestLimiter(t) + err := rl.WaitForCapacity(context.Background(), "wait-negative", -1) + assert.Error(t, err) + assert.Contains(t, err.Error(), "negative tokens") + }) +} + +func TestNilUsageStats(t *testing.T) { + t.Run("CanSend replaces nil state without panicking", func(t *testing.T) { + rl := newTestLimiter(t) + model := "nil-cansend" + rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 10} + rl.State[model] = nil + + assert.NotPanics(t, func() { + assert.True(t, rl.CanSend(model, 10)) + }) + assert.NotNil(t, rl.State[model]) + }) + + t.Run("RecordUsage replaces nil state without panicking", func(t *testing.T) { + rl := newTestLimiter(t) + model := "nil-record" + rl.State[model] = nil + + assert.NotPanics(t, func() { + rl.RecordUsage(model, 10, 10) + }) + require.NotNil(t, rl.State[model]) + assert.Equal(t, 1, rl.State[model].DayCount) + }) + + t.Run("Stats and AllStats tolerate nil state entries", func(t *testing.T) { + rl := newTestLimiter(t) + rl.Quotas["nil-stats"] = ModelQuota{MaxRPM: 1, MaxTPM: 2, MaxRPD: 3} + rl.State["nil-stats"] = nil + rl.State["nil-all-stats"] = nil + + assert.NotPanics(t, func() { + stats := rl.Stats("nil-stats") + assert.Equal(t, 1, stats.MaxRPM) + assert.Equal(t, 0, stats.TPM) + }) + + assert.NotPanics(t, func() { + all := rl.AllStats() + assert.Contains(t, all, "nil-stats") + assert.Contains(t, all, "nil-all-stats") + }) + }) } // --- Phase 0: Stats --- @@ -738,6 +826,19 @@ func TestBackgroundPrune(t *testing.T) { _, exists := rl.State[model] return !exists }, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned") + + t.Run("non-positive interval is a safe no-op", func(t *testing.T) { + rl := newTestLimiter(t) + rl.State["still-here"] = &UsageStats{ + Requests: []time.Time{time.Now().Add(-2 * time.Minute)}, + } + + assert.NotPanics(t, func() { + stop := rl.BackgroundPrune(0) + stop() + }) + assert.Contains(t, rl.State, "still-here") + }) } // --- Phase 0: CountTokens (with mock HTTP server) --- @@ -748,26 +849,149 @@ func TestCountTokens(t *testing.T) { assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "application/json", r.Header.Get("Content-Type")) assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key")) + assert.Equal(t, "/v1beta/models/test-model:countTokens", r.URL.EscapedPath()) + + var body struct { + Contents []struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"contents"` + } + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + require.Len(t, body.Contents, 1) + require.Len(t, body.Contents[0].Parts, 1) + assert.Equal(t, "hello", body.Contents[0].Parts[0].Text) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}) + require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42})) })) defer server.Close() - // For testing purposes, we would need to make the base URL configurable. - // Since we're just checking the signature and basic logic, we test the error paths. + tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "test-model", "hello") + require.NoError(t, err) + assert.Equal(t, 42, tokens) }) - t.Run("API error returns error", func(t *testing.T) { + t.Run("model name is path escaped", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath()) + assert.Empty(t, r.URL.RawQuery) + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7})) + })) + defer server.Close() + + tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "folder/model?debug=1", "hello") + require.NoError(t, err) + assert.Equal(t, 7, tokens) + }) + + t.Run("API error body is truncated", func(t *testing.T) { + largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(w, `{"error": "invalid API key"}`) + _, err := fmt.Fprint(w, largeBody) + require.NoError(t, err) })) defer server.Close() - _, err := CountTokens(context.Background(), "fake-key", "test-model", "hello") - assert.Error(t, err, "should fail with invalid API endpoint") + _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "api error status 401") + assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded") + assert.Contains(t, err.Error(), "...") }) + + t.Run("empty model is rejected before request", func(t *testing.T) { + _, err := CountTokens(context.Background(), "fake-key", "", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "build url") + }) + + t.Run("invalid base URL returns error", func(t *testing.T) { + _, err := countTokensWithClient(context.Background(), http.DefaultClient, "://bad-url", "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "build url") + }) + + t.Run("base URL without host returns error", func(t *testing.T) { + _, err := countTokensWithClient(context.Background(), http.DefaultClient, "/relative", "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "build url") + }) + + t.Run("invalid JSON response returns error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{"totalTokens":`)) + require.NoError(t, err) + })) + defer server.Close() + + _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "decode response") + }) + + t.Run("error body read failures are returned", func(t *testing.T) { + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: io.NopCloser(errReader{}), + Header: make(http.Header), + }, nil + }), + } + + _, err := countTokensWithClient(context.Background(), client, "https://generativelanguage.googleapis.com", "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "read error body") + }) + + t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11})) + })) + defer server.Close() + + originalClient := http.DefaultClient + http.DefaultClient = server.Client() + defer func() { + http.DefaultClient = originalClient + }() + + tokens, err := countTokensWithClient(context.Background(), nil, server.URL, "fake-key", "test-model", "hello") + require.NoError(t, err) + assert.Equal(t, 11, tokens) + }) +} + +func TestPersistSkipsNilState(t *testing.T) { + path := filepath.Join(t.TempDir(), "nil-state.yaml") + + rl, err := New() + require.NoError(t, err) + rl.filePath = path + rl.State["nil-model"] = nil + + require.NoError(t, rl.Persist()) + + rl2, err := New() + require.NoError(t, err) + rl2.filePath = path + require.NoError(t, rl2.Load()) + assert.NotContains(t, rl2.State, "nil-model") +} + +func TestTokenTotals(t *testing.T) { + maxInt := int(^uint(0) >> 1) + + assert.Equal(t, 25, safeTokenSum(-100, 25)) + assert.Equal(t, maxInt, safeTokenSum(maxInt, 1)) + assert.Equal(t, 10, totalTokenCount([]TokenEntry{{Count: -5}, {Count: 10}})) } // --- Phase 0: Benchmarks --- @@ -983,6 +1207,25 @@ func TestNewWithConfig(t *testing.T) { assert.Equal(t, 5, q.MaxRPM) assert.Equal(t, 50000, q.MaxTPM) }) + + t.Run("invalid backend returns error", func(t *testing.T) { + _, err := NewWithConfig(Config{ + Backend: "bogus", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown backend") + }) + + t.Run("default YAML path uses home directory", func(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", "") + t.Setenv("home", "") + + rl, err := NewWithConfig(Config{}) + require.NoError(t, err) + assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath) + }) } func TestNewBackwardCompatibility(t *testing.T) { diff --git a/sqlite.go b/sqlite.go index de6c89d..863ede4 100644 --- a/sqlite.go +++ b/sqlite.go @@ -78,7 +78,7 @@ func createSchema(db *sql.DB) error { return nil } -// saveQuotas upserts all quotas into the quotas table. +// saveQuotas writes a complete quota snapshot to the quotas table. func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error { tx, err := s.db.Begin() if err != nil { @@ -86,24 +86,15 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error { } defer tx.Rollback() - stmt, err := tx.Prepare(`INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) - VALUES (?, ?, ?, ?) - ON CONFLICT(model) DO UPDATE SET - max_rpm = excluded.max_rpm, - max_tpm = excluded.max_tpm, - max_rpd = excluded.max_rpd`) - if err != nil { - return coreerr.E("ratelimit.saveQuotas", "prepare", err) - } - defer stmt.Close() - - for model, q := range quotas { - if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil { - return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err) - } + if _, err := tx.Exec("DELETE FROM quotas"); err != nil { + return coreerr.E("ratelimit.saveQuotas", "clear", err) } - return tx.Commit() + if err := insertQuotas(tx, quotas); err != nil { + return err + } + + return commitTx(tx, "ratelimit.saveQuotas") } // loadQuotas reads all rows from the quotas table. @@ -129,6 +120,28 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) { return result, nil } +// saveSnapshot writes quotas and state as a single atomic snapshot. +func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error { + tx, err := s.db.Begin() + if err != nil { + return coreerr.E("ratelimit.saveSnapshot", "begin", err) + } + defer tx.Rollback() + + if err := clearSnapshotTables(tx, true); err != nil { + return err + } + + if err := insertQuotas(tx, quotas); err != nil { + return err + } + if err := insertState(tx, state); err != nil { + return err + } + + return commitTx(tx, "ratelimit.saveSnapshot") +} + // saveState writes all usage state to SQLite in a single transaction. // It uses a truncate-and-insert approach for simplicity in this version, // but ensures atomicity via a single transaction. @@ -139,7 +152,23 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { } defer tx.Rollback() - // Clear existing state in the transaction. + if err := clearSnapshotTables(tx, false); err != nil { + return err + } + + if err := insertState(tx, state); err != nil { + return err + } + + return commitTx(tx, "ratelimit.saveState") +} + +func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error { + if includeQuotas { + if _, err := tx.Exec("DELETE FROM quotas"); err != nil { + return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err) + } + } if _, err := tx.Exec("DELETE FROM requests"); err != nil { return coreerr.E("ratelimit.saveState", "clear requests", err) } @@ -149,7 +178,25 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { if _, err := tx.Exec("DELETE FROM daily"); err != nil { return coreerr.E("ratelimit.saveState", "clear daily", err) } + return nil +} +func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error { + stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)") + if err != nil { + return coreerr.E("ratelimit.saveQuotas", "prepare", err) + } + defer stmt.Close() + + for model, q := range quotas { + if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil { + return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err) + } + } + return nil +} + +func insertState(tx *sql.Tx, state map[string]*UsageStats) error { reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)") if err != nil { return coreerr.E("ratelimit.saveState", "prepare requests", err) @@ -169,6 +216,9 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { defer dayStmt.Close() for model, stats := range state { + if stats == nil { + continue + } for _, t := range stats.Requests { if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil { return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err) @@ -183,9 +233,12 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err) } } + return nil +} +func commitTx(tx *sql.Tx, scope string) error { if err := tx.Commit(); err != nil { - return coreerr.E("ratelimit.saveState", "commit", err) + return coreerr.E(scope, "commit", err) } return nil } diff --git a/sqlite_test.go b/sqlite_test.go index e2e13a6..7026261 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -435,6 +435,45 @@ func TestConfigBackendDefault_Good(t *testing.T) { assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)") } +func TestConfigBackendSQLite_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "config-backend.db") + rl, err := NewWithConfig(Config{ + Backend: backendSQLite, + FilePath: dbPath, + Quotas: map[string]ModelQuota{ + "backend-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, + }, + }) + require.NoError(t, err) + defer rl.Close() + + require.NotNil(t, rl.sqlite) + rl.RecordUsage("backend-model", 10, 10) + require.NoError(t, rl.Persist()) + + _, statErr := os.Stat(dbPath) + assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path") +} + +func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", "") + t.Setenv("home", "") + + rl, err := NewWithConfig(Config{ + Backend: backendSQLite, + }) + require.NoError(t, err) + defer rl.Close() + + require.NotNil(t, rl.sqlite) + require.NoError(t, rl.Persist()) + + _, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile)) + assert.NoError(t, statErr, "sqlite backend should use the default home DB path") +} + // --- Phase 2: MigrateYAMLToSQLite --- func TestMigrateYAMLToSQLite_Good(t *testing.T) { @@ -492,6 +531,69 @@ func TestMigrateYAMLToSQLite_Bad(t *testing.T) { }) } +func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) { + tmpDir := t.TempDir() + yamlPath := filepath.Join(tmpDir, "atomic.yaml") + sqlitePath := filepath.Join(tmpDir, "atomic.db") + now := time.Now().UTC() + + store, err := newSQLiteStore(sqlitePath) + require.NoError(t, err) + + originalQuotas := map[string]ModelQuota{ + "old-model": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3}, + } + originalState := map[string]*UsageStats{ + "old-model": { + Requests: []time.Time{now}, + Tokens: []TokenEntry{{Time: now, Count: 9}}, + DayStart: now, + DayCount: 1, + }, + } + require.NoError(t, store.saveSnapshot(originalQuotas, originalState)) + + _, err = store.db.Exec(`CREATE TRIGGER fail_daily_migrate BEFORE INSERT ON daily + BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`) + require.NoError(t, err) + require.NoError(t, store.close()) + + migrated := &RateLimiter{ + Quotas: map[string]ModelQuota{ + "new-model": {MaxRPM: 10, MaxTPM: 20, MaxRPD: 30}, + }, + State: map[string]*UsageStats{ + "new-model": { + Requests: []time.Time{now.Add(5 * time.Second)}, + Tokens: []TokenEntry{{Time: now.Add(5 * time.Second), Count: 99}}, + DayStart: now.Add(5 * time.Second), + DayCount: 2, + }, + }, + } + data, err := yaml.Marshal(migrated) + require.NoError(t, err) + require.NoError(t, os.WriteFile(yamlPath, data, 0o644)) + + err = MigrateYAMLToSQLite(yamlPath, sqlitePath) + require.Error(t, err) + + store, err = newSQLiteStore(sqlitePath) + require.NoError(t, err) + defer store.close() + + quotas, err := store.loadQuotas() + require.NoError(t, err) + assert.Equal(t, originalQuotas, quotas) + + state, err := store.loadState() + require.NoError(t, err) + require.Contains(t, state, "old-model") + assert.Equal(t, originalState["old-model"].DayCount, state["old-model"].DayCount) + assert.Equal(t, originalState["old-model"].Tokens[0].Count, state["old-model"].Tokens[0].Count) + assert.NotContains(t, state, "new-model") +} + func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) { tmpDir := t.TempDir() yamlPath := filepath.Join(tmpDir, "full.yaml") @@ -650,6 +752,78 @@ func TestSQLiteEndToEnd_Good(t *testing.T) { assert.Equal(t, 5, custom.MaxRPM) } +func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "replace.db") + rl, err := NewWithSQLiteConfig(dbPath, Config{ + Quotas: map[string]ModelQuota{ + "model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10}, + }, + }) + require.NoError(t, err) + + rl.RecordUsage("model-a", 10, 10) + require.NoError(t, rl.Persist()) + + delete(rl.Quotas, "model-a") + rl.Quotas["model-b"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20} + rl.Reset("") + rl.RecordUsage("model-b", 5, 5) + require.NoError(t, rl.Persist()) + require.NoError(t, rl.Close()) + + rl2, err := NewWithSQLiteConfig(dbPath, Config{ + Providers: []Provider{ProviderGemini}, + }) + require.NoError(t, err) + defer rl2.Close() + + rl2.State["stale-memory"] = &UsageStats{DayStart: time.Now(), DayCount: 99} + require.NoError(t, rl2.Load()) + + assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview") + assert.NotContains(t, rl2.Quotas, "model-a") + assert.Contains(t, rl2.Quotas, "model-b") + assert.NotContains(t, rl2.State, "stale-memory") + assert.NotContains(t, rl2.State, "model-a") + assert.Equal(t, 1, rl2.Stats("model-b").RPD) +} + +func TestSQLitePersistAtomic_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "persist-atomic.db") + rl, err := NewWithSQLiteConfig(dbPath, Config{ + Quotas: map[string]ModelQuota{ + "old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10}, + }, + }) + require.NoError(t, err) + + rl.RecordUsage("old-model", 10, 10) + require.NoError(t, rl.Persist()) + + _, err = rl.sqlite.db.Exec(`CREATE TRIGGER fail_daily_persist BEFORE INSERT ON daily + BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`) + require.NoError(t, err) + + delete(rl.Quotas, "old-model") + rl.Quotas["new-model"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20} + rl.Reset("") + rl.RecordUsage("new-model", 50, 50) + + err = rl.Persist() + require.Error(t, err) + require.NoError(t, rl.Close()) + + rl2, err := NewWithSQLiteConfig(dbPath, Config{}) + require.NoError(t, err) + defer rl2.Close() + require.NoError(t, rl2.Load()) + + assert.Contains(t, rl2.Quotas, "old-model") + assert.NotContains(t, rl2.Quotas, "new-model") + assert.Equal(t, 1, rl2.Stats("old-model").RPD) + assert.Equal(t, 0, rl2.Stats("new-model").RPD) +} + // --- Phase 2: Benchmark --- func BenchmarkSQLitePersist(b *testing.B) {