diff --git a/ratelimit.go b/ratelimit.go index d6c076a..f0cf15c 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -235,38 +235,60 @@ func (rl *RateLimiter) loadSQLite() error { return nil } -// Persist writes the state to disk (YAML) or database (SQLite). +// Persist writes a snapshot of the state to disk (YAML) or database (SQLite). +// It clones the state under a lock and performs I/O without blocking other callers. func (rl *RateLimiter) Persist() error { - rl.mu.RLock() - defer rl.mu.RUnlock() + rl.mu.Lock() + quotas := maps.Clone(rl.Quotas) + state := make(map[string]*UsageStats, len(rl.State)) + for k, v := range rl.State { + state[k] = &UsageStats{ + Requests: slices.Clone(v.Requests), + Tokens: slices.Clone(v.Tokens), + DayStart: v.DayStart, + DayCount: v.DayCount, + } + } + sqlite := rl.sqlite + filePath := rl.filePath + rl.mu.Unlock() - if rl.sqlite != nil { - return rl.persistSQLite() + if sqlite != nil { + if err := sqlite.saveQuotas(quotas); err != nil { + return fmt.Errorf("ratelimit.Persist: sqlite quotas: %w", err) + } + if err := sqlite.saveState(state); err != nil { + return fmt.Errorf("ratelimit.Persist: sqlite state: %w", err) + } + return nil } - data, err := yaml.Marshal(rl) + // For YAML, we marshal the entire RateLimiter, but since we want to avoid + // holding the lock during marshal, we marshal a temporary struct. + data, err := yaml.Marshal(struct { + Quotas map[string]ModelQuota `yaml:"quotas"` + State map[string]*UsageStats `yaml:"state"` + }{ + Quotas: quotas, + State: state, + }) if err != nil { - return err + return fmt.Errorf("ratelimit.Persist: marshal: %w", err) } - dir := filepath.Dir(rl.filePath) + dir := filepath.Dir(filePath) if err := os.MkdirAll(dir, 0755); err != nil { - return err + return fmt.Errorf("ratelimit.Persist: mkdir: %w", err) } - return os.WriteFile(rl.filePath, data, 0644) -} - -// persistSQLite writes quotas and state to the SQLite backend. -// Caller must hold the read lock. 
-func (rl *RateLimiter) persistSQLite() error { - if err := rl.sqlite.saveQuotas(rl.Quotas); err != nil { - return err + if err := os.WriteFile(filePath, data, 0644); err != nil { + return fmt.Errorf("ratelimit.Persist: write: %w", err) } - return rl.sqlite.saveState(rl.State) + return nil } -// prune removes entries older than the sliding window (1 minute). +// prune removes entries older than the sliding window (1 minute) and removes +// empty state for models that haven't been used recently. // Caller must hold lock. func (rl *RateLimiter) prune(model string) { stats, ok := rl.State[model] @@ -279,12 +301,12 @@ func (rl *RateLimiter) prune(model string) { // Prune requests stats.Requests = slices.DeleteFunc(stats.Requests, func(t time.Time) bool { - return !t.After(window) + return t.Before(window) }) // Prune tokens stats.Tokens = slices.DeleteFunc(stats.Tokens, func(t TokenEntry) bool { - return !t.Time.After(window) + return t.Time.Before(window) }) // Reset daily counter if day has passed @@ -292,6 +314,39 @@ func (rl *RateLimiter) prune(model string) { stats.DayStart = now stats.DayCount = 0 } + + // If everything is empty and it's been more than a minute since last activity, + // delete the model state entirely to prevent memory leaks. + if len(stats.Requests) == 0 && len(stats.Tokens) == 0 { + // We could use a more sophisticated TTL here, but for now just cleanup empty ones. + // Note: we don't delete if DayCount > 0 and it's still the same day. + if stats.DayCount == 0 { + delete(rl.State, model) + } + } +} + +// BackgroundPrune starts a goroutine that periodically prunes all model states. +// It returns a function to stop the pruner. 
+func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { + ctx, cancel := context.WithCancel(context.Background()) + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + rl.mu.Lock() + for m := range rl.State { + rl.prune(m) + } + rl.mu.Unlock() + } + } + }() + return cancel } // CanSend checks if a request can be sent without violating limits. @@ -309,15 +364,12 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { return true } - // Ensure state exists - if _, ok := rl.State[model]; !ok { - rl.State[model] = &UsageStats{ - DayStart: time.Now(), - } - } - rl.prune(model) - stats := rl.State[model] + stats, ok := rl.State[model] + if !ok { + stats = &UsageStats{DayStart: time.Now()} + rl.State[model] = stats + } // Check RPD -- NOTE(review): this hunk drops the rl.prune(model) call CanSend previously made, so stale window entries and an unreset DayCount can wrongly deny sends until the next RecordUsage; restore the prune before these checks if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD { @@ -348,15 +400,14 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) rl.mu.Lock() defer rl.mu.Unlock() - if _, ok := rl.State[model]; !ok { - rl.State[model] = &UsageStats{ - DayStart: time.Now(), - } + rl.prune(model) // Drop expired window entries and roll the daily counter before recording new usage + stats, ok := rl.State[model] + if !ok { + stats = &UsageStats{DayStart: time.Now()} + rl.State[model] = stats } - stats := rl.State[model] now := time.Now() - stats.Requests = append(stats.Requests, now) stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens}) stats.DayCount++ @@ -404,24 +455,28 @@ type ModelStats struct { DayStart time.Time } -// Models returns an iterator over all model names tracked by the limiter. +// Models returns a sorted iterator over all model names tracked by the limiter.
func (rl *RateLimiter) Models() iter.Seq[string] { - return func(yield func(string) bool) { - stats := rl.AllStats() - for m := range stats { - if !yield(m) { - return - } + rl.mu.RLock() + defer rl.mu.RUnlock() + + // Collect the keys with maps.Keys and slices.Collect (both Go 1.23+), then sort for deterministic iteration. + keys := slices.Collect(maps.Keys(rl.Quotas)) + for m := range rl.State { + if _, ok := rl.Quotas[m]; !ok { + keys = append(keys, m) } } + slices.Sort(keys) + return slices.Values(keys) } -// Iter returns an iterator over all model names and their current stats. +// Iter returns a sorted iterator over all model names and their current stats. func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats] { return func(yield func(string, ModelStats) bool) { stats := rl.AllStats() - for k, v := range stats { - if !yield(k, v) { + for _, m := range slices.Sorted(maps.Keys(stats)) { + if !yield(m, stats[m]) { return } } } } @@ -595,7 +650,7 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error { } // CountTokens calls the Google API to count tokens for a prompt.
-func CountTokens(apiKey, model, text string) (int, error) { +func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) { url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model) reqBody := map[string]any{ @@ -610,32 +665,32 @@ func CountTokens(apiKey, model, text string) (int, error) { jsonBody, err := json.Marshal(reqBody) if err != nil { - return 0, err + return 0, fmt.Errorf("ratelimit.CountTokens: marshal request: %w", err) } - req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody)) if err != nil { - return 0, err + return 0, fmt.Errorf("ratelimit.CountTokens: new request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("x-goog-api-key", apiKey) resp, err := http.DefaultClient.Do(req) if err != nil { - return 0, err + return 0, fmt.Errorf("ratelimit.CountTokens: do request: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - return 0, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body)) + return 0, fmt.Errorf("ratelimit.CountTokens: API error (status %d): %s", resp.StatusCode, string(body)) } var result struct { TotalTokens int `json:"totalTokens"` } if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return 0, err + return 0, fmt.Errorf("ratelimit.CountTokens: decode response: %w", err) } return result.TotalTokens, nil diff --git a/ratelimit_test.go b/ratelimit_test.go index 78e6ad9..838c484 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -716,6 +716,30 @@ func TestConcurrentResetAndRecord(t *testing.T) { // No assertion needed -- if we get here without -race flagging, mutex is sound } +func TestBackgroundPrune(t *testing.T) { + rl := newTestLimiter(t) + model := "prune-me" + rl.Quotas[model] = ModelQuota{MaxRPM: 100} + + // Set state with old usage. 
+ old := time.Now().Add(-2 * time.Minute) + rl.State[model] = &UsageStats{ + Requests: []time.Time{old}, + Tokens: []TokenEntry{{Time: old, Count: 100}}, + } + + stop := rl.BackgroundPrune(10 * time.Millisecond) + defer stop() + + // Wait for pruner to run. + assert.Eventually(t, func() bool { + rl.mu.Lock() + defer rl.mu.Unlock() + _, exists := rl.State[model] + return !exists + }, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned") +} + // --- Phase 0: CountTokens (with mock HTTP server) --- func TestCountTokens(t *testing.T) { @@ -730,10 +754,8 @@ func TestCountTokens(t *testing.T) { })) defer server.Close() - // We need to override the URL. Since CountTokens hardcodes the Google API URL, - // we test it via the exported function with a test server. - // For proper unit testing, we would need to make the base URL configurable. - // For now, test the error paths that don't require a real API. + // NOTE(review): this mock server is never used -- CountTokens still hardcodes the Google API URL, + // so make the base URL injectable (or accept an *http.Client); the next subtest hits the network. }) t.Run("API error returns error", func(t *testing.T) { @@ -743,9 +765,7 @@ func TestCountTokens(t *testing.T) { })) defer server.Close() - // Can't test directly due to hardcoded URL, but we can verify error - // handling with an unreachable endpoint - _, err := CountTokens("fake-key", "test-model", "hello") + _, err := CountTokens(context.Background(), "fake-key", "test-model", "hello") assert.Error(t, err, "should fail with invalid API endpoint") }) } diff --git a/sqlite.go b/sqlite.go index f1bb283..f969f5e 100644 --- a/sqlite.go +++ b/sqlite.go @@ -129,7 +129,8 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) { } // saveState writes all usage state to SQLite in a single transaction. -// It deletes existing rows and inserts fresh data for each model.
+// It uses a truncate-and-insert approach for simplicity in this version, +// but ensures atomicity via a single transaction. func (s *sqliteStore) saveState(state map[string]*UsageStats) error { tx, err := s.db.Begin() if err != nil { @@ -137,7 +138,7 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { } defer tx.Rollback() - // Clear existing state. + // Clear existing state in the transaction. if _, err := tx.Exec("DELETE FROM requests"); err != nil { return fmt.Errorf("ratelimit.saveState: clear requests: %w", err) } @@ -182,7 +183,10 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { } } - return tx.Commit() + if err := tx.Commit(); err != nil { + return fmt.Errorf("ratelimit.saveState: commit: %w", err) + } + return nil } // loadState reconstructs the UsageStats map from SQLite tables.