From d1c90b937d675dd91c65166d25060fee9fa1f059 Mon Sep 17 00:00:00 2001 From: Virgil Date: Mon, 23 Mar 2026 07:26:15 +0000 Subject: [PATCH 1/4] fix(ratelimit): harden audit edge cases Co-Authored-By: Virgil --- iter_test.go | 39 ++++--- ratelimit.go | 287 +++++++++++++++++++++++++++++++--------------- ratelimit_test.go | 257 +++++++++++++++++++++++++++++++++++++++-- sqlite.go | 91 ++++++++++++--- sqlite_test.go | 174 ++++++++++++++++++++++++++++ 5 files changed, 715 insertions(+), 133 deletions(-) diff --git a/iter_test.go b/iter_test.go index aaa59be..f353e16 100644 --- a/iter_test.go +++ b/iter_test.go @@ -31,13 +31,19 @@ func TestIterators(t *testing.T) { assert.Contains(t, models, "model-a") assert.Contains(t, models, "model-b") assert.Contains(t, models, "model-c") - + // Check sorting of our specific models foundA, foundB, foundC := -1, -1, -1 for i, m := range models { - if m == "model-a" { foundA = i } - if m == "model-b" { foundB = i } - if m == "model-c" { foundC = i } + if m == "model-a" { + foundA = i + } + if m == "model-b" { + foundB = i + } + if m == "model-c" { + foundC = i + } } assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c") }) @@ -57,9 +63,15 @@ func TestIterators(t *testing.T) { // Check sorting foundA, foundB, foundC := -1, -1, -1 for i, m := range models { - if m == "model-a" { foundA = i } - if m == "model-b" { foundB = i } - if m == "model-c" { foundC = i } + if m == "model-a" { + foundA = i + } + if m == "model-b" { + foundB = i + } + if m == "model-c" { + foundC = i + } } assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c") }) @@ -99,9 +111,8 @@ func TestIterEarlyBreak(t *testing.T) { } func TestCountTokensFull(t *testing.T) { - t.Run("invalid URL/network error", func(t *testing.T) { - // Using an invalid character in model name to trigger URL error or similar - _, err := CountTokens(context.Background(), "key", "invalid model", "text") + t.Run("empty model is 
rejected", func(t *testing.T) { + _, err := CountTokens(context.Background(), "key", "", "text") assert.Error(t, err) }) @@ -112,15 +123,15 @@ func TestCountTokensFull(t *testing.T) { })) defer server.Close() - // We can't easily override the URL in CountTokens without changing the code, - // but we can test the logic if we make it slightly more testable. - // For now, I've already updated ratelimit_test.go with some of this. + _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "key", "model", "text") + require.Error(t, err) + assert.Contains(t, err.Error(), "status 400") }) t.Run("context cancelled", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := CountTokens(ctx, "key", "model", "text") + _, err := countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", "key", "model", "text") assert.Error(t, err) assert.Contains(t, err.Error(), "do request") }) diff --git a/ratelimit.go b/ratelimit.go index d2d0007..0e34709 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -9,9 +9,11 @@ import ( "iter" "maps" "net/http" + "net/url" "os" "path/filepath" "slices" + "strings" "sync" "time" @@ -34,6 +36,16 @@ const ( ProviderLocal Provider = "local" ) +const ( + defaultStateDirName = ".core" + defaultYAMLStateFile = "ratelimits.yaml" + defaultSQLiteStateFile = "ratelimits.db" + backendYAML = "yaml" + backendSQLite = "sqlite" + countTokensErrorBodyLimit = 8 * 1024 + countTokensSuccessBodyLimit = 1 * 1024 * 1024 +) + // ModelQuota defines the rate limits for a specific model. type ModelQuota struct { MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited) @@ -143,39 +155,30 @@ func New() (*RateLimiter, error) { // NewWithConfig creates a RateLimiter from explicit configuration. // If no providers or quotas are specified, Gemini defaults are used. 
func NewWithConfig(cfg Config) (*RateLimiter, error) { + backend, err := normaliseBackend(cfg.Backend) + if err != nil { + return nil, err + } + filePath := cfg.FilePath if filePath == "" { - home, err := os.UserHomeDir() + filePath, err = defaultStatePath(backend) if err != nil { return nil, err } - filePath = filepath.Join(home, ".core", "ratelimits.yaml") } - rl := &RateLimiter{ - Quotas: make(map[string]ModelQuota), - State: make(map[string]*UsageStats), - filePath: filePath, - } - - // Load provider profiles - profiles := DefaultProfiles() - providers := cfg.Providers - - // If nothing specified at all, default to Gemini - if len(providers) == 0 && len(cfg.Quotas) == 0 { - providers = []Provider{ProviderGemini} - } - - for _, p := range providers { - if profile, ok := profiles[p]; ok { - maps.Copy(rl.Quotas, profile.Models) + if backend == backendSQLite { + if cfg.FilePath == "" { + if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { + return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err) + } } + return NewWithSQLiteConfig(filePath, cfg) } - // Merge explicit quotas on top (allows overrides) - maps.Copy(rl.Quotas, cfg.Quotas) - + rl := newConfiguredRateLimiter(cfg) + rl.filePath = filePath return rl, nil } @@ -225,15 +228,17 @@ func (rl *RateLimiter) loadSQLite() error { if err != nil { return err } - // Merge loaded quotas (loaded quotas override in-memory defaults). - maps.Copy(rl.Quotas, quotas) + // Persisted quotas are authoritative when present; otherwise keep config defaults. + if len(quotas) > 0 { + rl.Quotas = maps.Clone(quotas) + } state, err := rl.sqlite.loadState() if err != nil { return err } - // Replace in-memory state with persisted state. - maps.Copy(rl.State, state) + // Replace in-memory state with the persisted snapshot. 
+ rl.State = state return nil } @@ -244,6 +249,9 @@ func (rl *RateLimiter) Persist() error { quotas := maps.Clone(rl.Quotas) state := make(map[string]*UsageStats, len(rl.State)) for k, v := range rl.State { + if v == nil { + continue + } state[k] = &UsageStats{ Requests: slices.Clone(v.Requests), Tokens: slices.Clone(v.Tokens), @@ -256,11 +264,8 @@ func (rl *RateLimiter) Persist() error { rl.mu.Unlock() if sqlite != nil { - if err := sqlite.saveQuotas(quotas); err != nil { - return coreerr.E("ratelimit.Persist", "sqlite quotas", err) - } - if err := sqlite.saveState(state); err != nil { - return coreerr.E("ratelimit.Persist", "sqlite state", err) + if err := sqlite.saveSnapshot(quotas, state); err != nil { + return coreerr.E("ratelimit.Persist", "sqlite snapshot", err) } return nil } @@ -292,6 +297,10 @@ func (rl *RateLimiter) prune(model string) { if !ok { return } + if stats == nil { + delete(rl.State, model) + return + } now := time.Now() window := now.Add(-1 * time.Minute) @@ -326,6 +335,10 @@ func (rl *RateLimiter) prune(model string) { // BackgroundPrune starts a goroutine that periodically prunes all model states. // It returns a function to stop the pruner. func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { + if interval <= 0 { + return func() {} + } + ctx, cancel := context.WithCancel(context.Background()) go func() { ticker := time.NewTicker(interval) @@ -348,6 +361,10 @@ func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func() { // CanSend checks if a request can be sent without violating limits. 
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { + if estimatedTokens < 0 { + return false + } + rl.mu.Lock() defer rl.mu.Unlock() @@ -380,11 +397,8 @@ func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool { // Check TPM if quota.MaxTPM > 0 { - currentTokens := 0 - for _, t := range stats.Tokens { - currentTokens += t.Count - } - if currentTokens+estimatedTokens > quota.MaxTPM { + currentTokens := totalTokenCount(stats.Tokens) + if estimatedTokens > quota.MaxTPM || currentTokens > quota.MaxTPM-estimatedTokens { return false } } @@ -405,13 +419,18 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) } now := time.Now() + totalTokens := safeTokenSum(promptTokens, outputTokens) stats.Requests = append(stats.Requests, now) - stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens}) + stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: totalTokens}) stats.DayCount++ } // WaitForCapacity blocks until capacity is available or context is cancelled. 
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error { + if tokens < 0 { + return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil) + } + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() @@ -522,24 +541,8 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats { result[m] = ModelStats{} } - now := time.Now() - window := now.Add(-1 * time.Minute) - for m := range result { - // Prune inline - if s, ok := rl.State[m]; ok { - s.Requests = slices.DeleteFunc(s.Requests, func(t time.Time) bool { - return !t.After(window) - }) - s.Tokens = slices.DeleteFunc(s.Tokens, func(t TokenEntry) bool { - return !t.Time.After(window) - }) - - if now.Sub(s.DayStart) >= 24*time.Hour { - s.DayStart = now - s.DayCount = 0 - } - } + rl.prune(m) ms := ModelStats{} if q, ok := rl.Quotas[m]; ok { @@ -547,13 +550,11 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats { ms.MaxTPM = q.MaxTPM ms.MaxRPD = q.MaxRPD } - if s, ok := rl.State[m]; ok { + if s, ok := rl.State[m]; ok && s != nil { ms.RPM = len(s.Requests) ms.RPD = s.DayCount ms.DayStart = s.DayStart - for _, t := range s.Tokens { - ms.TPM += t.Count - } + ms.TPM = totalTokenCount(s.Tokens) } result[m] = ms } @@ -579,25 +580,8 @@ func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error) { return nil, err } - rl := &RateLimiter{ - Quotas: make(map[string]ModelQuota), - State: make(map[string]*UsageStats), - sqlite: store, - } - - // Load provider profiles. 
- profiles := DefaultProfiles() - providers := cfg.Providers - if len(providers) == 0 && len(cfg.Quotas) == 0 { - providers = []Provider{ProviderGemini} - } - for _, p := range providers { - if profile, ok := profiles[p]; ok { - maps.Copy(rl.Quotas, profile.Models) - } - } - maps.Copy(rl.Quotas, cfg.Quotas) - + rl := newConfiguredRateLimiter(cfg) + rl.sqlite = store return rl, nil } @@ -633,22 +617,22 @@ func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error { } defer store.close() - if rl.Quotas != nil { - if err := store.saveQuotas(rl.Quotas); err != nil { - return err - } - } - if rl.State != nil { - if err := store.saveState(rl.State); err != nil { - return err - } + if err := store.saveSnapshot(rl.Quotas, rl.State); err != nil { + return err } return nil } // CountTokens calls the Google API to count tokens for a prompt. func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) { - url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model) + return countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", apiKey, model, text) +} + +func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) { + requestURL, err := countTokensURL(baseURL, model) + if err != nil { + return 0, coreerr.E("ratelimit.CountTokens", "build url", err) + } reqBody := map[string]any{ "contents": []any{ @@ -665,30 +649,147 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) { return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody)) if err != nil { return 0, coreerr.E("ratelimit.CountTokens", "new request", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("x-goog-api-key", 
apiKey) - resp, err := http.DefaultClient.Do(req) + if client == nil { + client = http.DefaultClient + } + + resp, err := client.Do(req) if err != nil { return 0, coreerr.E("ratelimit.CountTokens", "do request", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("API error (status %d): %s", resp.StatusCode, string(body)), nil) + body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit) + if err != nil { + return 0, coreerr.E("ratelimit.CountTokens", "read error body", err) + } + return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil) } var result struct { TotalTokens int `json:"totalTokens"` } - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil { return 0, coreerr.E("ratelimit.CountTokens", "decode response", err) } return result.TotalTokens, nil } + +func newConfiguredRateLimiter(cfg Config) *RateLimiter { + rl := &RateLimiter{ + Quotas: make(map[string]ModelQuota), + State: make(map[string]*UsageStats), + } + applyConfig(rl, cfg) + return rl +} + +func applyConfig(rl *RateLimiter, cfg Config) { + profiles := DefaultProfiles() + providers := cfg.Providers + + if len(providers) == 0 && len(cfg.Quotas) == 0 { + providers = []Provider{ProviderGemini} + } + + for _, p := range providers { + if profile, ok := profiles[p]; ok { + maps.Copy(rl.Quotas, profile.Models) + } + } + + maps.Copy(rl.Quotas, cfg.Quotas) +} + +func normaliseBackend(backend string) (string, error) { + switch strings.ToLower(strings.TrimSpace(backend)) { + case "", backendYAML: + return backendYAML, nil + case backendSQLite: + return backendSQLite, nil + default: + return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil) + } +} + +func 
defaultStatePath(backend string) (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + fileName := defaultYAMLStateFile + if backend == backendSQLite { + fileName = defaultSQLiteStateFile + } + + return filepath.Join(home, defaultStateDirName, fileName), nil +} + +func safeTokenSum(a, b int) int { + return safeTokenTotal([]TokenEntry{{Count: a}, {Count: b}}) +} + +func totalTokenCount(tokens []TokenEntry) int { + return safeTokenTotal(tokens) +} + +func safeTokenTotal(tokens []TokenEntry) int { + const maxInt = int(^uint(0) >> 1) + + total := 0 + for _, token := range tokens { + count := token.Count + if count < 0 { + continue + } + if total > maxInt-count { + return maxInt + } + total += count + } + return total +} + +func countTokensURL(baseURL, model string) (string, error) { + if strings.TrimSpace(model) == "" { + return "", fmt.Errorf("empty model") + } + + parsed, err := url.Parse(baseURL) + if err != nil { + return "", err + } + if parsed.Scheme == "" || parsed.Host == "" { + return "", fmt.Errorf("invalid base url") + } + + return strings.TrimRight(parsed.String(), "/") + "/v1beta/models/" + url.PathEscape(model) + ":countTokens", nil +} + +func readLimitedBody(r io.Reader, limit int64) (string, error) { + body, err := io.ReadAll(io.LimitReader(r, limit+1)) + if err != nil { + return "", err + } + + truncated := int64(len(body)) > limit + if truncated { + body = body[:limit] + } + + result := string(body) + if truncated { + result += "..." 
+ } + return result, nil +} diff --git a/ratelimit_test.go b/ratelimit_test.go index 838c484..73aac70 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -4,10 +4,12 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -25,6 +27,18 @@ func newTestLimiter(t *testing.T) *RateLimiter { return rl } +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +type errReader struct{} + +func (errReader) Read([]byte) (int, error) { + return 0, io.ErrUnexpectedEOF +} + // --- Phase 0: CanSend boundary conditions --- func TestCanSend(t *testing.T) { @@ -162,6 +176,15 @@ func TestCanSend(t *testing.T) { rl.RecordUsage(model, 50, 50) // exactly 100 tokens assert.True(t, rl.CanSend(model, 0), "zero estimated tokens should fit even at limit") }) + + t.Run("negative estimated tokens are rejected", func(t *testing.T) { + rl := newTestLimiter(t) + model := "negative-est" + rl.Quotas[model] = ModelQuota{MaxRPM: 100, MaxTPM: 100, MaxRPD: 100} + rl.RecordUsage(model, 50, 50) + + assert.False(t, rl.CanSend(model, -1), "negative estimated tokens should not bypass TPM limits") + }) } // --- Phase 0: Sliding window / prune tests --- @@ -335,6 +358,19 @@ func TestRecordUsage(t *testing.T) { assert.Len(t, stats.Tokens, 2) assert.Equal(t, 6, stats.DayCount) }) + + t.Run("negative token inputs are clamped to zero", func(t *testing.T) { + rl := newTestLimiter(t) + model := "record-negative" + + rl.RecordUsage(model, -100, 25) + + stats := rl.State[model] + require.NotNil(t, stats) + assert.Len(t, stats.Requests, 1) + assert.Len(t, stats.Tokens, 1) + assert.Equal(t, 25, stats.Tokens[0].Count) + }) } // --- Phase 0: Reset --- @@ -422,6 +458,58 @@ func TestWaitForCapacity(t *testing.T) { err := rl.WaitForCapacity(ctx, model, 10) assert.Error(t, err, "should return error for already-cancelled 
context") }) + + t.Run("negative tokens return error", func(t *testing.T) { + rl := newTestLimiter(t) + err := rl.WaitForCapacity(context.Background(), "wait-negative", -1) + assert.Error(t, err) + assert.Contains(t, err.Error(), "negative tokens") + }) +} + +func TestNilUsageStats(t *testing.T) { + t.Run("CanSend replaces nil state without panicking", func(t *testing.T) { + rl := newTestLimiter(t) + model := "nil-cansend" + rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 10} + rl.State[model] = nil + + assert.NotPanics(t, func() { + assert.True(t, rl.CanSend(model, 10)) + }) + assert.NotNil(t, rl.State[model]) + }) + + t.Run("RecordUsage replaces nil state without panicking", func(t *testing.T) { + rl := newTestLimiter(t) + model := "nil-record" + rl.State[model] = nil + + assert.NotPanics(t, func() { + rl.RecordUsage(model, 10, 10) + }) + require.NotNil(t, rl.State[model]) + assert.Equal(t, 1, rl.State[model].DayCount) + }) + + t.Run("Stats and AllStats tolerate nil state entries", func(t *testing.T) { + rl := newTestLimiter(t) + rl.Quotas["nil-stats"] = ModelQuota{MaxRPM: 1, MaxTPM: 2, MaxRPD: 3} + rl.State["nil-stats"] = nil + rl.State["nil-all-stats"] = nil + + assert.NotPanics(t, func() { + stats := rl.Stats("nil-stats") + assert.Equal(t, 1, stats.MaxRPM) + assert.Equal(t, 0, stats.TPM) + }) + + assert.NotPanics(t, func() { + all := rl.AllStats() + assert.Contains(t, all, "nil-stats") + assert.Contains(t, all, "nil-all-stats") + }) + }) } // --- Phase 0: Stats --- @@ -738,6 +826,19 @@ func TestBackgroundPrune(t *testing.T) { _, exists := rl.State[model] return !exists }, 1*time.Second, 20*time.Millisecond, "old empty state should be pruned") + + t.Run("non-positive interval is a safe no-op", func(t *testing.T) { + rl := newTestLimiter(t) + rl.State["still-here"] = &UsageStats{ + Requests: []time.Time{time.Now().Add(-2 * time.Minute)}, + } + + assert.NotPanics(t, func() { + stop := rl.BackgroundPrune(0) + stop() + }) + assert.Contains(t, 
rl.State, "still-here") + }) } // --- Phase 0: CountTokens (with mock HTTP server) --- @@ -748,26 +849,149 @@ func TestCountTokens(t *testing.T) { assert.Equal(t, http.MethodPost, r.Method) assert.Equal(t, "application/json", r.Header.Get("Content-Type")) assert.Equal(t, "test-api-key", r.Header.Get("x-goog-api-key")) + assert.Equal(t, "/v1beta/models/test-model:countTokens", r.URL.EscapedPath()) + + var body struct { + Contents []struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + } `json:"contents"` + } + require.NoError(t, json.NewDecoder(r.Body).Decode(&body)) + require.Len(t, body.Contents, 1) + require.Len(t, body.Contents[0].Parts, 1) + assert.Equal(t, "hello", body.Contents[0].Parts[0].Text) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42}) + require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 42})) })) defer server.Close() - // For testing purposes, we would need to make the base URL configurable. - // Since we're just checking the signature and basic logic, we test the error paths. 
+ tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "test-model", "hello") + require.NoError(t, err) + assert.Equal(t, 42, tokens) }) - t.Run("API error returns error", func(t *testing.T) { + t.Run("model name is path escaped", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/v1beta/models/folder%2Fmodel%3Fdebug=1:countTokens", r.URL.EscapedPath()) + assert.Empty(t, r.URL.RawQuery) + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 7})) + })) + defer server.Close() + + tokens, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "test-api-key", "folder/model?debug=1", "hello") + require.NoError(t, err) + assert.Equal(t, 7, tokens) + }) + + t.Run("API error body is truncated", func(t *testing.T) { + largeBody := strings.Repeat("x", countTokensErrorBodyLimit+256) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(w, `{"error": "invalid API key"}`) + _, err := fmt.Fprint(w, largeBody) + require.NoError(t, err) })) defer server.Close() - _, err := CountTokens(context.Background(), "fake-key", "test-model", "hello") - assert.Error(t, err, "should fail with invalid API endpoint") + _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "api error status 401") + assert.True(t, strings.Count(err.Error(), "x") < len(largeBody), "error body should be bounded") + assert.Contains(t, err.Error(), "...") }) + + t.Run("empty model is rejected before request", func(t *testing.T) { + _, err := CountTokens(context.Background(), "fake-key", "", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "build url") + }) + 
+ t.Run("invalid base URL returns error", func(t *testing.T) { + _, err := countTokensWithClient(context.Background(), http.DefaultClient, "://bad-url", "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "build url") + }) + + t.Run("base URL without host returns error", func(t *testing.T) { + _, err := countTokensWithClient(context.Background(), http.DefaultClient, "/relative", "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "build url") + }) + + t.Run("invalid JSON response returns error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, err := w.Write([]byte(`{"totalTokens":`)) + require.NoError(t, err) + })) + defer server.Close() + + _, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "decode response") + }) + + t.Run("error body read failures are returned", func(t *testing.T) { + client := &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: io.NopCloser(errReader{}), + Header: make(http.Header), + }, nil + }), + } + + _, err := countTokensWithClient(context.Background(), client, "https://generativelanguage.googleapis.com", "fake-key", "test-model", "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "read error body") + }) + + t.Run("nil client falls back to http.DefaultClient", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]int{"totalTokens": 11})) + })) + defer server.Close() + + originalClient := http.DefaultClient + 
http.DefaultClient = server.Client() + defer func() { + http.DefaultClient = originalClient + }() + + tokens, err := countTokensWithClient(context.Background(), nil, server.URL, "fake-key", "test-model", "hello") + require.NoError(t, err) + assert.Equal(t, 11, tokens) + }) +} + +func TestPersistSkipsNilState(t *testing.T) { + path := filepath.Join(t.TempDir(), "nil-state.yaml") + + rl, err := New() + require.NoError(t, err) + rl.filePath = path + rl.State["nil-model"] = nil + + require.NoError(t, rl.Persist()) + + rl2, err := New() + require.NoError(t, err) + rl2.filePath = path + require.NoError(t, rl2.Load()) + assert.NotContains(t, rl2.State, "nil-model") +} + +func TestTokenTotals(t *testing.T) { + maxInt := int(^uint(0) >> 1) + + assert.Equal(t, 25, safeTokenSum(-100, 25)) + assert.Equal(t, maxInt, safeTokenSum(maxInt, 1)) + assert.Equal(t, 10, totalTokenCount([]TokenEntry{{Count: -5}, {Count: 10}})) } // --- Phase 0: Benchmarks --- @@ -983,6 +1207,25 @@ func TestNewWithConfig(t *testing.T) { assert.Equal(t, 5, q.MaxRPM) assert.Equal(t, 50000, q.MaxTPM) }) + + t.Run("invalid backend returns error", func(t *testing.T) { + _, err := NewWithConfig(Config{ + Backend: "bogus", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "unknown backend") + }) + + t.Run("default YAML path uses home directory", func(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", "") + t.Setenv("home", "") + + rl, err := NewWithConfig(Config{}) + require.NoError(t, err) + assert.Equal(t, filepath.Join(home, defaultStateDirName, defaultYAMLStateFile), rl.filePath) + }) } func TestNewBackwardCompatibility(t *testing.T) { diff --git a/sqlite.go b/sqlite.go index de6c89d..863ede4 100644 --- a/sqlite.go +++ b/sqlite.go @@ -78,7 +78,7 @@ func createSchema(db *sql.DB) error { return nil } -// saveQuotas upserts all quotas into the quotas table. +// saveQuotas writes a complete quota snapshot to the quotas table. 
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error { tx, err := s.db.Begin() if err != nil { @@ -86,24 +86,15 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error { } defer tx.Rollback() - stmt, err := tx.Prepare(`INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) - VALUES (?, ?, ?, ?) - ON CONFLICT(model) DO UPDATE SET - max_rpm = excluded.max_rpm, - max_tpm = excluded.max_tpm, - max_rpd = excluded.max_rpd`) - if err != nil { - return coreerr.E("ratelimit.saveQuotas", "prepare", err) - } - defer stmt.Close() - - for model, q := range quotas { - if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil { - return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err) - } + if _, err := tx.Exec("DELETE FROM quotas"); err != nil { + return coreerr.E("ratelimit.saveQuotas", "clear", err) } - return tx.Commit() + if err := insertQuotas(tx, quotas); err != nil { + return err + } + + return commitTx(tx, "ratelimit.saveQuotas") } // loadQuotas reads all rows from the quotas table. @@ -129,6 +120,28 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) { return result, nil } +// saveSnapshot writes quotas and state as a single atomic snapshot. +func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error { + tx, err := s.db.Begin() + if err != nil { + return coreerr.E("ratelimit.saveSnapshot", "begin", err) + } + defer tx.Rollback() + + if err := clearSnapshotTables(tx, true); err != nil { + return err + } + + if err := insertQuotas(tx, quotas); err != nil { + return err + } + if err := insertState(tx, state); err != nil { + return err + } + + return commitTx(tx, "ratelimit.saveSnapshot") +} + // saveState writes all usage state to SQLite in a single transaction. // It uses a truncate-and-insert approach for simplicity in this version, // but ensures atomicity via a single transaction. 
@@ -139,7 +152,23 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { } defer tx.Rollback() - // Clear existing state in the transaction. + if err := clearSnapshotTables(tx, false); err != nil { + return err + } + + if err := insertState(tx, state); err != nil { + return err + } + + return commitTx(tx, "ratelimit.saveState") +} + +func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error { + if includeQuotas { + if _, err := tx.Exec("DELETE FROM quotas"); err != nil { + return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err) + } + } if _, err := tx.Exec("DELETE FROM requests"); err != nil { return coreerr.E("ratelimit.saveState", "clear requests", err) } @@ -149,7 +178,25 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { if _, err := tx.Exec("DELETE FROM daily"); err != nil { return coreerr.E("ratelimit.saveState", "clear daily", err) } + return nil +} +func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error { + stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)") + if err != nil { + return coreerr.E("ratelimit.saveQuotas", "prepare", err) + } + defer stmt.Close() + + for model, q := range quotas { + if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil { + return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err) + } + } + return nil +} + +func insertState(tx *sql.Tx, state map[string]*UsageStats) error { reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)") if err != nil { return coreerr.E("ratelimit.saveState", "prepare requests", err) @@ -169,6 +216,9 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { defer dayStmt.Close() for model, stats := range state { + if stats == nil { + continue + } for _, t := range stats.Requests { if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil { return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), 
err) @@ -183,9 +233,12 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error { return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err) } } + return nil +} +func commitTx(tx *sql.Tx, scope string) error { if err := tx.Commit(); err != nil { - return coreerr.E("ratelimit.saveState", "commit", err) + return coreerr.E(scope, "commit", err) } return nil } diff --git a/sqlite_test.go b/sqlite_test.go index e2e13a6..7026261 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -435,6 +435,45 @@ func TestConfigBackendDefault_Good(t *testing.T) { assert.Nil(t, rl.sqlite, "empty backend should use YAML (no sqlite)") } +func TestConfigBackendSQLite_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "config-backend.db") + rl, err := NewWithConfig(Config{ + Backend: backendSQLite, + FilePath: dbPath, + Quotas: map[string]ModelQuota{ + "backend-model": {MaxRPM: 10, MaxTPM: 1000, MaxRPD: 50}, + }, + }) + require.NoError(t, err) + defer rl.Close() + + require.NotNil(t, rl.sqlite) + rl.RecordUsage("backend-model", 10, 10) + require.NoError(t, rl.Persist()) + + _, statErr := os.Stat(dbPath) + assert.NoError(t, statErr, "sqlite backend should persist to the configured DB path") +} + +func TestConfigBackendSQLiteDefaultPath_Good(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", "") + t.Setenv("home", "") + + rl, err := NewWithConfig(Config{ + Backend: backendSQLite, + }) + require.NoError(t, err) + defer rl.Close() + + require.NotNil(t, rl.sqlite) + require.NoError(t, rl.Persist()) + + _, statErr := os.Stat(filepath.Join(home, defaultStateDirName, defaultSQLiteStateFile)) + assert.NoError(t, statErr, "sqlite backend should use the default home DB path") +} + // --- Phase 2: MigrateYAMLToSQLite --- func TestMigrateYAMLToSQLite_Good(t *testing.T) { @@ -492,6 +531,69 @@ func TestMigrateYAMLToSQLite_Bad(t *testing.T) { }) } +func TestMigrateYAMLToSQLiteAtomic_Good(t *testing.T) { + tmpDir := 
t.TempDir() + yamlPath := filepath.Join(tmpDir, "atomic.yaml") + sqlitePath := filepath.Join(tmpDir, "atomic.db") + now := time.Now().UTC() + + store, err := newSQLiteStore(sqlitePath) + require.NoError(t, err) + + originalQuotas := map[string]ModelQuota{ + "old-model": {MaxRPM: 1, MaxTPM: 2, MaxRPD: 3}, + } + originalState := map[string]*UsageStats{ + "old-model": { + Requests: []time.Time{now}, + Tokens: []TokenEntry{{Time: now, Count: 9}}, + DayStart: now, + DayCount: 1, + }, + } + require.NoError(t, store.saveSnapshot(originalQuotas, originalState)) + + _, err = store.db.Exec(`CREATE TRIGGER fail_daily_migrate BEFORE INSERT ON daily + BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`) + require.NoError(t, err) + require.NoError(t, store.close()) + + migrated := &RateLimiter{ + Quotas: map[string]ModelQuota{ + "new-model": {MaxRPM: 10, MaxTPM: 20, MaxRPD: 30}, + }, + State: map[string]*UsageStats{ + "new-model": { + Requests: []time.Time{now.Add(5 * time.Second)}, + Tokens: []TokenEntry{{Time: now.Add(5 * time.Second), Count: 99}}, + DayStart: now.Add(5 * time.Second), + DayCount: 2, + }, + }, + } + data, err := yaml.Marshal(migrated) + require.NoError(t, err) + require.NoError(t, os.WriteFile(yamlPath, data, 0o644)) + + err = MigrateYAMLToSQLite(yamlPath, sqlitePath) + require.Error(t, err) + + store, err = newSQLiteStore(sqlitePath) + require.NoError(t, err) + defer store.close() + + quotas, err := store.loadQuotas() + require.NoError(t, err) + assert.Equal(t, originalQuotas, quotas) + + state, err := store.loadState() + require.NoError(t, err) + require.Contains(t, state, "old-model") + assert.Equal(t, originalState["old-model"].DayCount, state["old-model"].DayCount) + assert.Equal(t, originalState["old-model"].Tokens[0].Count, state["old-model"].Tokens[0].Count) + assert.NotContains(t, state, "new-model") +} + func TestMigrateYAMLToSQLitePreservesAllGeminiModels_Good(t *testing.T) { tmpDir := t.TempDir() yamlPath := filepath.Join(tmpDir, "full.yaml") 
@@ -650,6 +752,78 @@ func TestSQLiteEndToEnd_Good(t *testing.T) { assert.Equal(t, 5, custom.MaxRPM) } +func TestSQLiteLoadReplacesPersistedSnapshot_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "replace.db") + rl, err := NewWithSQLiteConfig(dbPath, Config{ + Quotas: map[string]ModelQuota{ + "model-a": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10}, + }, + }) + require.NoError(t, err) + + rl.RecordUsage("model-a", 10, 10) + require.NoError(t, rl.Persist()) + + delete(rl.Quotas, "model-a") + rl.Quotas["model-b"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20} + rl.Reset("") + rl.RecordUsage("model-b", 5, 5) + require.NoError(t, rl.Persist()) + require.NoError(t, rl.Close()) + + rl2, err := NewWithSQLiteConfig(dbPath, Config{ + Providers: []Provider{ProviderGemini}, + }) + require.NoError(t, err) + defer rl2.Close() + + rl2.State["stale-memory"] = &UsageStats{DayStart: time.Now(), DayCount: 99} + require.NoError(t, rl2.Load()) + + assert.NotContains(t, rl2.Quotas, "gemini-3-pro-preview") + assert.NotContains(t, rl2.Quotas, "model-a") + assert.Contains(t, rl2.Quotas, "model-b") + assert.NotContains(t, rl2.State, "stale-memory") + assert.NotContains(t, rl2.State, "model-a") + assert.Equal(t, 1, rl2.Stats("model-b").RPD) +} + +func TestSQLitePersistAtomic_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "persist-atomic.db") + rl, err := NewWithSQLiteConfig(dbPath, Config{ + Quotas: map[string]ModelQuota{ + "old-model": {MaxRPM: 1, MaxTPM: 100, MaxRPD: 10}, + }, + }) + require.NoError(t, err) + + rl.RecordUsage("old-model", 10, 10) + require.NoError(t, rl.Persist()) + + _, err = rl.sqlite.db.Exec(`CREATE TRIGGER fail_daily_persist BEFORE INSERT ON daily + BEGIN SELECT RAISE(ABORT, 'forced daily failure'); END`) + require.NoError(t, err) + + delete(rl.Quotas, "old-model") + rl.Quotas["new-model"] = ModelQuota{MaxRPM: 2, MaxTPM: 200, MaxRPD: 20} + rl.Reset("") + rl.RecordUsage("new-model", 50, 50) + + err = rl.Persist() + require.Error(t, err) + 
require.NoError(t, rl.Close()) + + rl2, err := NewWithSQLiteConfig(dbPath, Config{}) + require.NoError(t, err) + defer rl2.Close() + require.NoError(t, rl2.Load()) + + assert.Contains(t, rl2.Quotas, "old-model") + assert.NotContains(t, rl2.Quotas, "new-model") + assert.Equal(t, 1, rl2.Stats("old-model").RPD) + assert.Equal(t, 0, rl2.Stats("new-model").RPD) +} + // --- Phase 2: Benchmark --- func BenchmarkSQLitePersist(b *testing.B) { -- 2.45.3 From 22ab4edc86ac95b8510e771974c2a7a77c844c27 Mon Sep 17 00:00:00 2001 From: Virgil Date: Mon, 23 Mar 2026 13:29:15 +0000 Subject: [PATCH 2/4] docs(ratelimit): map external attack vectors Co-Authored-By: Virgil --- docs/security-attack-vector-mapping.md | 44 ++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 docs/security-attack-vector-mapping.md diff --git a/docs/security-attack-vector-mapping.md b/docs/security-attack-vector-mapping.md new file mode 100644 index 0000000..5768d0a --- /dev/null +++ b/docs/security-attack-vector-mapping.md @@ -0,0 +1,44 @@ +# Security Attack Vector Mapping + +Scope: external inputs that cross into this package from callers, persisted storage, or the network. This is a mapping only; it does not propose or apply fixes. + +Note: `CODEX.md` was not present anywhere under `/workspace` during this scan, so conventions were taken from `CLAUDE.md` and the existing repository layout. 
+ +## Caller-Controlled API Inputs + +| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector | +| --- | --- | --- | --- | --- | --- | +| `NewWithConfig(cfg Config)` | `ratelimit.go:145` | Caller-controlled `cfg.FilePath`, `cfg.Backend`, `cfg.Providers`, `cfg.Quotas` | `cfg.FilePath` is stored in `rl.filePath` and later used by `Load()` / `Persist()`; `cfg.Providers` selects built-in profiles; `cfg.Quotas` is copied straight into `rl.Quotas` | Empty `FilePath` defaults to `~/.core/ratelimits.yaml` for the YAML backend and `~/.core/ratelimits.db` for the SQLite backend; unknown providers are ignored; `cfg.Backend` selects the persistence backend (empty means YAML, `sqlite` opens a SQLite store); quota values and model keys are copied verbatim | Untrusted `FilePath` can later steer YAML or SQLite reads and writes to arbitrary local paths because the package uses unsandboxed `coreio.Local`; negative quota values act like "unlimited" because enforcement only checks `> 0`; very large maps or model strings can drive memory and disk growth | +| `(*RateLimiter).SetQuota(model string, quota ModelQuota)` | `ratelimit.go:183` | Caller-controlled `model` and `quota` | Direct write to `rl.Quotas[model]`; later consumed by `CanSend()`, `Stats()`, `Persist()`, and SQLite/YAML persistence | None beyond Go type checking | Negative quota fields disable enforcement for that dimension; extreme values can skew blocking behaviour; arbitrary model names expand the in-memory and persisted keyspace | +| `(*RateLimiter).AddProvider(provider Provider)` | `ratelimit.go:191` | Caller-controlled `provider` selector | Looks up `DefaultProfiles()[provider]` and overwrites matching entries in `rl.Quotas` via `maps.Copy` | Unknown providers are silently ignored; no authorisation or policy guard in this package | If a higher-level service exposes this call, an attacker can overwrite stricter runtime quotas for a provider's models with the shipped defaults and relax the intended rate-limit policy | +| `(*RateLimiter).BackgroundPrune(interval time.Duration)` | `ratelimit.go:328` | Caller-controlled `interval` |
Passed to `time.NewTicker(interval)` and drives a background goroutine that repeatedly locks and prunes state | None | `interval <= 0` causes a panic; very small intervals can create CPU and lock-contention DoS; repeated calls without using the returned cancel function leak goroutines | +| `(*RateLimiter).CanSend(model string, estimatedTokens int)` | `ratelimit.go:350` | Caller-controlled `model` and `estimatedTokens` | `model` indexes `rl.Quotas` / `rl.State`; `estimatedTokens` is added to the current token total before the TPM comparison | Unknown models are allowed immediately; no non-negative or range checks on `estimatedTokens` | Passing an unconfigured model name bypasses throttling entirely; negative or overflowed token values can undercount the TPM check and permit oversend | +| `(*RateLimiter).RecordUsage(model string, promptTokens, outputTokens int)` | `ratelimit.go:396` | Caller-controlled `model`, `promptTokens`, `outputTokens` | Creates or updates `rl.State[model]`; stores `promptTokens + outputTokens` in the token window and increments `DayCount` | None | Arbitrary model names create unbounded state that will later persist to YAML/SQLite; negative or overflowed token totals poison accounting and can reduce future TPM totals below the real usage | +| `(*RateLimiter).WaitForCapacity(ctx context.Context, model string, tokens int)` | `ratelimit.go:414` | Caller-controlled `ctx`, `model`, `tokens` | Calls `CanSend(model, tokens)` once per second until capacity is available or `ctx.Done()` fires | No direct validation; relies on downstream `CanSend()` and caller-supplied context cancellation | Inherits the unknown-model and negative-token bypasses from `CanSend()`; repeated calls with long-lived contexts can accumulate goroutines and lock pressure | +| `(*RateLimiter).Reset(model string)` | `ratelimit.go:433` | Caller-controlled `model` | `model == ""` replaces the entire `rl.State` map; otherwise `delete(rl.State, model)` | Empty string is treated as a 
wildcard reset | If reachable by an untrusted actor, an empty string clears all rate-limit history and targeted resets erase throttling state for chosen models | +| `(*RateLimiter).Stats(model string)` | `ratelimit.go:484` | Caller-controlled `model` | Prunes `rl.State[model]`, reads `rl.Quotas[model]`, and returns a usage snapshot | None | If exposed through a service boundary, it discloses per-model quota ceilings and live usage counts that can help an attacker tune evasion or timing | +| `NewWithSQLite(dbPath string)` | `ratelimit.go:567` | Caller-controlled `dbPath` | Thin wrapper that forwards `dbPath` into `NewWithSQLiteConfig()` and then `newSQLiteStore()` | No additional validation in the wrapper | Untrusted `dbPath` can steer database creation/opening to unintended local filesystem locations, including companion `-wal` and `-shm` files | +| `NewWithSQLiteConfig(dbPath string, cfg Config)` | `ratelimit.go:576` | Caller-controlled `dbPath`, `cfg.Providers`, `cfg.Quotas` | `dbPath` goes straight to `newSQLiteStore()`; provider and quota inputs are copied into `rl.Quotas` exactly as in `NewWithConfig()` | `cfg.Backend` is ignored; unknown providers are ignored; no path, range, or size checks | Combines arbitrary database-path selection with the same quota poisoning risks as `NewWithConfig()` and `SetQuota()` | + +## Filesystem And YAML Inputs + +| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector | +| --- | --- | --- | --- | --- | --- | +| `(*RateLimiter).Load()` (YAML backend) | `ratelimit.go:210` | File bytes read from `rl.filePath` | `coreio.Local.Read(rl.filePath)` feeds `yaml.Unmarshal(..., rl)`, which overwrites `rl.Quotas` and `rl.State` | Missing file is ignored; YAML syntax and type mismatches return an error; no semantic checks on counts, limits, model names, or slice sizes | A malicious YAML file can inject negative quotas or counters to bypass enforcement, preload very large maps/slices for 
memory DoS, or replace the in-memory policy/state with attacker-chosen values | +| `MigrateYAMLToSQLite(yamlPath, sqlitePath string)` | `ratelimit.go:617` | Caller-controlled `yamlPath` and `sqlitePath`, plus YAML file bytes from `yamlPath` | Reads YAML with `coreio.Local.Read(yamlPath)`, unmarshals into a temporary `RateLimiter`, then opens `sqlitePath` and writes the imported quotas/state into SQLite | Read/open errors and YAML syntax/type errors are surfaced; no path restrictions and no semantic validation of imported quotas/state | Untrusted paths enable arbitrary local file reads and database creation/clobbering; attacker-controlled YAML can permanently seed the SQLite backend with quota-bypass values or oversized state | + +## SQLite Inputs + +| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector | +| --- | --- | --- | --- | --- | --- | +| `newSQLiteStore(dbPath string)` | `sqlite.go:20` | Caller-controlled database path passed from `NewWithSQLite*()` or `MigrateYAMLToSQLite()` | `sql.Open("sqlite", dbPath)` opens or creates the database, then applies PRAGMAs and creates tables and indexes | Only driver/open/PRAGMA/schema errors are checked; there is no allowlist, path normalisation, or sandboxing in this package | An attacker-chosen `dbPath` can redirect state into unintended files or directories and create matching SQLite sidecar files, which is useful for tampering, data placement, or storage-exhaustion attacks | +| `(*RateLimiter).Load()` (SQLite backend via `loadSQLite()`) | `ratelimit.go:223` | SQLite content reachable through the already-open `rl.sqlite` handle | `loadQuotas()` and `loadState()` results are copied into `rl.Quotas` and `rl.State`; loaded quotas override in-memory defaults | Only lower-level scan/query errors are checked; no semantic validation after load | A tampered SQLite database can override intended quotas/state, including negative limits and poisoned counters, before any 
enforcement call runs | +| `(*sqliteStore).loadQuotas()` | `sqlite.go:110` | Rows from the `quotas` table (`model`, `max_rpm`, `max_tpm`, `max_rpd`) | Scanned into `map[string]ModelQuota`, then copied into `rl.Quotas` by `loadSQLite()` | SQL scan and row iteration errors only; no range or length checks | Negative or extreme values disable or destabilise later quota enforcement; a large number of rows or large model strings can cause memory growth | +| `(*sqliteStore).loadState()` | `sqlite.go:194` | Rows from the `daily`, `requests`, and `tokens` tables | Scanned into `UsageStats` maps/slices and then copied into `rl.State` by `loadSQLite()` | SQL scan and row iteration errors only; no bounds, ordering, or semantic checks on timestamps and counts | Crafted counts or timestamps can poison later `CanSend()` / `Stats()` results; oversized tables can drive memory and CPU exhaustion during load | + +## Network Inputs + +| Function | File:Line | Input Source | What It Flows Into | Current Validation | Potential Attack Vector | +| --- | --- | --- | --- | --- | --- | +| `CountTokens(ctx, apiKey, model, text)` (request construction) | `ratelimit.go:650` | Caller-controlled `ctx`, `apiKey`, `model`, `text` | `model` is interpolated directly into the Google API URL path; `text` is marshalled into JSON request body; `apiKey` goes into `x-goog-api-key`; `ctx` governs request lifetime | JSON marshalling must succeed; `http.NewRequestWithContext()` rejects some malformed URLs; there is no path escaping, length check, or output-size cap on the request body | Untrusted prompt text is exfiltrated to a remote API; very large text can consume memory/bandwidth; unescaped model strings can alter the path or query on the fixed Google host; repeated calls burn external quota | +| `CountTokens(ctx, apiKey, model, text)` (response handling) | `ratelimit.go:681` | Remote HTTP status, response body, and JSON `totalTokens` value from `generativelanguage.googleapis.com` | Non-200 bodies are 
read fully with `io.ReadAll()` and embedded into the returned error; 200 responses are decoded into `result.TotalTokens` and returned to the caller | Checks for HTTP 200 and JSON decode errors only; no response-body size limit and no sanity check on `TotalTokens` | A very large error body can cause memory pressure and log/telemetry pollution; a negative or extreme `totalTokens` value would be returned unchanged and could poison downstream rate-limit accounting if the caller trusts it | -- 2.45.3 From 86dc04258a7d95a66d2c932964f3dab99f941b83 Mon Sep 17 00:00:00 2001 From: Virgil Date: Mon, 23 Mar 2026 15:03:29 +0000 Subject: [PATCH 3/4] docs(ratelimit): extract api contract Co-Authored-By: Virgil --- docs/api-contract.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 docs/api-contract.md diff --git a/docs/api-contract.md b/docs/api-contract.md new file mode 100644 index 0000000..f5640cc --- /dev/null +++ b/docs/api-contract.md @@ -0,0 +1,35 @@ +# API Contract + +Test coverage is marked `yes` when the symbol is exercised by the existing test suite in `ratelimit_test.go`, `sqlite_test.go`, `error_test.go`, or `iter_test.go`. + +| Kind | Name | Signature | Description | Test coverage | +| --- | --- | --- | --- | --- | +| Type | `Provider` | `type Provider string` | Identifies an LLM provider used to select default quota profiles. | yes | +| Type | `ModelQuota` | `type ModelQuota struct { MaxRPM int; MaxTPM int; MaxRPD int }` | Declares per-model RPM, TPM, and RPD limits; `0` means unlimited. | yes | +| Type | `ProviderProfile` | `type ProviderProfile struct { Provider Provider; Models map[string]ModelQuota }` | Bundles a provider identifier with its default model quota map. | yes | +| Type | `Config` | `type Config struct { FilePath string; Backend string; Quotas map[string]ModelQuota; Providers []Provider }` | Configures limiter initialisation, persistence settings, explicit quotas, and provider defaults. 
| yes | +| Type | `TokenEntry` | `type TokenEntry struct { Time time.Time; Count int }` | Records a timestamped token-usage event. | yes | +| Type | `UsageStats` | `type UsageStats struct { Requests []time.Time; Tokens []TokenEntry; DayStart time.Time; DayCount int }` | Stores per-model sliding-window request and token history plus rolling daily usage state. | yes | +| Type | `RateLimiter` | `type RateLimiter struct { Quotas map[string]ModelQuota; State map[string]*UsageStats }` | Manages quotas, usage state, persistence, and concurrency across models. | yes | +| Type | `ModelStats` | `type ModelStats struct { RPM int; MaxRPM int; TPM int; MaxTPM int; RPD int; MaxRPD int; DayStart time.Time }` | Represents a snapshot of current usage and configured limits for a model. | yes | +| Function | `DefaultProfiles` | `func DefaultProfiles() map[Provider]ProviderProfile` | Returns the built-in quota profiles for the supported providers. | yes | +| Function | `New` | `func New() (*RateLimiter, error)` | Creates a new limiter with Gemini defaults for backward-compatible YAML-backed usage. | yes | +| Function | `NewWithConfig` | `func NewWithConfig(cfg Config) (*RateLimiter, error)` | Creates a limiter from explicit configuration, using YAML persistence unless `Backend` is `sqlite`, and defaulting to Gemini when config is empty. | yes | +| Function | `NewWithSQLite` | `func NewWithSQLite(dbPath string) (*RateLimiter, error)` | Creates a SQLite-backed limiter with Gemini defaults and opens or creates the database. | yes | +| Function | `NewWithSQLiteConfig` | `func NewWithSQLiteConfig(dbPath string, cfg Config) (*RateLimiter, error)` | Creates a SQLite-backed limiter from explicit configuration and always uses SQLite persistence. | yes | +| Function | `MigrateYAMLToSQLite` | `func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error` | Migrates quotas and usage state from a YAML state file into a SQLite database.
| yes | +| Function | `CountTokens` | `func CountTokens(ctx context.Context, apiKey, model, text string) (int, error)` | Calls the Gemini `countTokens` API for a prompt and returns the reported token count. | yes | +| Method | `SetQuota` | `func (rl *RateLimiter) SetQuota(model string, quota ModelQuota)` | Sets or replaces a model quota at runtime. | yes | +| Method | `AddProvider` | `func (rl *RateLimiter) AddProvider(provider Provider)` | Loads a built-in provider profile and overwrites any matching model quotas. | yes | +| Method | `Load` | `func (rl *RateLimiter) Load() error` | Loads quotas and usage state from YAML or SQLite into memory. | yes | +| Method | `Persist` | `func (rl *RateLimiter) Persist() error` | Persists a snapshot of quotas and usage state to YAML or SQLite. | yes | +| Method | `BackgroundPrune` | `func (rl *RateLimiter) BackgroundPrune(interval time.Duration) func()` | Starts periodic pruning of expired usage state and returns a stop function. | yes | +| Method | `CanSend` | `func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool` | Reports whether a request with the estimated token count fits within current limits. | yes | +| Method | `RecordUsage` | `func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)` | Records a successful request into the sliding-window and daily counters. | yes | +| Method | `WaitForCapacity` | `func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error` | Blocks until `CanSend` succeeds or the context is cancelled. | yes | +| Method | `Reset` | `func (rl *RateLimiter) Reset(model string)` | Clears usage state for one model or for all models when `model` is empty. | yes | +| Method | `Models` | `func (rl *RateLimiter) Models() iter.Seq[string]` | Returns a sorted iterator of all model names known from quotas or state. 
| yes | +| Method | `Iter` | `func (rl *RateLimiter) Iter() iter.Seq2[string, ModelStats]` | Returns a sorted iterator of model names paired with current stats snapshots. | yes | +| Method | `Stats` | `func (rl *RateLimiter) Stats(model string) ModelStats` | Returns current stats for a single model after pruning expired usage. | yes | +| Method | `AllStats` | `func (rl *RateLimiter) AllStats() map[string]ModelStats` | Returns stats snapshots for every tracked model. | yes | +| Method | `Close` | `func (rl *RateLimiter) Close() error` | Closes SQLite resources for SQLite-backed limiters and is a no-op for YAML-backed limiters. | yes | -- 2.45.3 From c5e2ed8b7e7c999c4812e7dd006ae9bc0e44a207 Mon Sep 17 00:00:00 2001 From: Virgil Date: Mon, 23 Mar 2026 15:06:08 +0000 Subject: [PATCH 4/4] docs(ratelimit): add convention drift report Co-Authored-By: Virgil --- docs/convention-drift-check-2026-03-23.md | 75 +++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 docs/convention-drift-check-2026-03-23.md diff --git a/docs/convention-drift-check-2026-03-23.md b/docs/convention-drift-check-2026-03-23.md new file mode 100644 index 0000000..0623b04 --- /dev/null +++ b/docs/convention-drift-check-2026-03-23.md @@ -0,0 +1,75 @@ + + +# Convention Drift Check + +Date: 2026-03-23 + +`CODEX.md` was not present anywhere under `/workspace`, so this check used +`CLAUDE.md`, `docs/development.md`, and the current tree. + +## stdlib -> core.* + +- `CLAUDE.md:65` still documents `forge.lthn.ai/core/go-io`, while the direct + dependency in `go.mod:6` is `dappco.re/go/core/io`. +- `CLAUDE.md:66` still documents `forge.lthn.ai/core/go-log`, while the direct + dependency in `go.mod:7` is `dappco.re/go/core/log`. +- `docs/development.md:175` still mandates `fmt.Errorf("ratelimit.Function: what: %w", err)`, + but the implementation and repo guidance use `coreerr.E(...)` + (`ratelimit.go:19`, `sqlite.go:8`, `CLAUDE.md:31`). 
+- `docs/development.md:195` omits the `core.*` direct dependencies entirely; + `go.mod:6-7` now declares both `dappco.re/go/core/io` and + `dappco.re/go/core/log`. +- `docs/index.md:102` likewise omits the `core.*` direct dependencies that are + present in `go.mod:6-7`. + +## UK English + +- `README.md:2` uses `License` in the badge alt text. +- `CONTRIBUTING.md:34` uses `License` as the section heading. + +## Missing Tests + +- `docs/development.md:150` says current coverage is `95.1%` and the floor is + `95%`, but `go test -coverprofile=/tmp/convention-cover.out ./...` currently + reports `94.4%`. +- `docs/history.md:28` still records coverage as `95.1%`; the current tree is + below that figure at `94.4%`. +- `docs/history.md:66` and `docs/development.md:157` say the + `os.UserHomeDir()` branch in `NewWithConfig()` is untestable, but + `error_test.go:648` now exercises that path. +- `ratelimit.go:202` (`Load`) is only `90.0%` covered. +- `ratelimit.go:242` (`Persist`) is only `90.0%` covered. +- `ratelimit.go:650` (`CountTokens`) is only `71.4%` covered. +- `sqlite.go:20` (`newSQLiteStore`) is only `64.3%` covered. +- `sqlite.go:110` (`loadQuotas`) is only `92.9%` covered. +- `sqlite.go:194` (`loadState`) is only `88.6%` covered. +- `ratelimit_test.go:746` starts a mock server for `CountTokens`, but the test + contains no assertion that exercises the success path. +- `iter_test.go:108` starts a second mock server for `CountTokens`, but again + does not exercise the mocked path. +- `error_test.go:42` defines `TestSQLiteInitErrors`, but the `WAL pragma failure` + subtest is still an empty placeholder. + +## SPDX Headers + +- `CLAUDE.md:1` is missing an SPDX header. +- `CONTRIBUTING.md:1` is missing an SPDX header. +- `README.md:1` is missing an SPDX header. +- `docs/architecture.md:1` is missing an SPDX header. +- `docs/development.md:1` is missing an SPDX header. +- `docs/history.md:1` is missing an SPDX header. +- `docs/index.md:1` is missing an SPDX header. 
+- `error_test.go:1` is missing an SPDX header. +- `go.mod:1` is missing an SPDX header. +- `iter_test.go:1` is missing an SPDX header. +- `ratelimit.go:1` is missing an SPDX header. +- `ratelimit_test.go:1` is missing an SPDX header. +- `sqlite.go:1` is missing an SPDX header. +- `sqlite_test.go:1` is missing an SPDX header. + +## Verification + +- `go test ./...` +- `go test -coverprofile=/tmp/convention-cover.out ./...` +- `go test -race ./...` +- `go vet ./...` -- 2.45.3