feat: extract go-ratelimit from core/go pkg/ratelimit
Token counting, model quotas, sliding window rate limiter. One external dependency: gopkg.in/yaml.v3 for state persistence. Module: forge.lthn.ai/core/go-ratelimit Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
commit
fa1a6fc030
4 changed files with 587 additions and 0 deletions
19
CLAUDE.md
Normal file
19
CLAUDE.md
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# CLAUDE.md
|
||||
|
||||
## What This Is
|
||||
|
||||
Token counting, model quotas, and sliding window rate limiter. Module: `forge.lthn.ai/core/go-ratelimit`
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
go test ./... # Run all tests
|
||||
go test -v -run Name # Run single test
|
||||
```
|
||||
|
||||
## Coding Standards
|
||||
|
||||
- UK English
|
||||
- `go test ./...` must pass before commit
|
||||
- Conventional commits: `type(scope): description`
|
||||
- Co-Author: `Co-Authored-By: Virgil <virgil@lethean.io>`
|
||||
3
go.mod
Normal file
3
go.mod
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
module forge.lthn.ai/core/go-ratelimit

go 1.25.5

// ratelimit.go imports gopkg.in/yaml.v3; without this require the
// module does not build.
require gopkg.in/yaml.v3 v3.0.1
|
||||
389
ratelimit.go
Normal file
389
ratelimit.go
Normal file
|
|
@ -0,0 +1,389 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ModelQuota defines the rate limits for a specific model.
// A zero value for any field disables that particular limit;
// an all-zero quota is treated as fully unlimited by CanSend.
type ModelQuota struct {
	MaxRPM int `yaml:"max_rpm"` // Requests per minute (0 = unlimited)
	MaxTPM int `yaml:"max_tpm"` // Tokens per minute (0 = unlimited)
	MaxRPD int `yaml:"max_rpd"` // Requests per day (0 = unlimited)
}
|
||||
|
||||
// TokenEntry records a token usage event: how many tokens were
// consumed and when, so the sliding window can expire the entry.
type TokenEntry struct {
	Time  time.Time `yaml:"time"`  // When the tokens were consumed.
	Count int       `yaml:"count"` // Number of tokens (prompt + output combined).
}
|
||||
|
||||
// UsageStats tracks usage history for a model.
// Requests and Tokens hold only events inside the one-minute sliding
// window (older entries are removed by prune); DayStart and DayCount
// implement a rolling 24-hour request counter.
type UsageStats struct {
	Requests []time.Time  `yaml:"requests"`  // Sliding window (1m)
	Tokens   []TokenEntry `yaml:"tokens"`    // Sliding window (1m)
	DayStart time.Time    `yaml:"day_start"` // Start of the current 24h accounting period.
	DayCount int          `yaml:"day_count"` // Requests recorded since DayStart.
}
|
||||
|
||||
// RateLimiter manages rate limits across multiple models.
// mu guards both maps, so the exported methods can be used
// concurrently. The struct itself is marshalled to YAML for
// persistence: the exported maps carry yaml tags, while mu and
// filePath are unexported and therefore skipped.
type RateLimiter struct {
	mu       sync.RWMutex
	Quotas   map[string]ModelQuota  `yaml:"quotas"` // Per-model limits, keyed by model name.
	State    map[string]*UsageStats `yaml:"state"`  // Per-model usage history, keyed by model name.
	filePath string                 // Path of the persisted state file (~/.core/ratelimits.yaml).
}
|
||||
|
||||
// New creates a new RateLimiter with default quotas.
|
||||
func New() (*RateLimiter, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rl := &RateLimiter{
|
||||
Quotas: make(map[string]ModelQuota),
|
||||
State: make(map[string]*UsageStats),
|
||||
filePath: filepath.Join(home, ".core", "ratelimits.yaml"),
|
||||
}
|
||||
|
||||
// Default quotas based on Tier 1 observations (Feb 2026)
|
||||
rl.Quotas["gemini-3-pro-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
|
||||
rl.Quotas["gemini-3-flash-preview"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
|
||||
rl.Quotas["gemini-2.5-pro"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 1000}
|
||||
rl.Quotas["gemini-2.0-flash"] = ModelQuota{MaxRPM: 150, MaxTPM: 1000000, MaxRPD: 0} // Unlimited RPD
|
||||
rl.Quotas["gemini-2.0-flash-lite"] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0} // Unlimited
|
||||
|
||||
return rl, nil
|
||||
}
|
||||
|
||||
// Load reads the state from disk.
|
||||
func (rl *RateLimiter) Load() error {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
data, err := os.ReadFile(rl.filePath)
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(data, rl)
|
||||
}
|
||||
|
||||
// Persist writes the state to disk.
|
||||
func (rl *RateLimiter) Persist() error {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
data, err := yaml.Marshal(rl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dir := filepath.Dir(rl.filePath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(rl.filePath, data, 0644)
|
||||
}
|
||||
|
||||
// prune removes entries older than the sliding window (1 minute).
|
||||
// Caller must hold lock.
|
||||
func (rl *RateLimiter) prune(model string) {
|
||||
stats, ok := rl.State[model]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
window := now.Add(-1 * time.Minute)
|
||||
|
||||
// Prune requests
|
||||
validReqs := 0
|
||||
for _, t := range stats.Requests {
|
||||
if t.After(window) {
|
||||
stats.Requests[validReqs] = t
|
||||
validReqs++
|
||||
}
|
||||
}
|
||||
stats.Requests = stats.Requests[:validReqs]
|
||||
|
||||
// Prune tokens
|
||||
validTokens := 0
|
||||
for _, t := range stats.Tokens {
|
||||
if t.Time.After(window) {
|
||||
stats.Tokens[validTokens] = t
|
||||
validTokens++
|
||||
}
|
||||
}
|
||||
stats.Tokens = stats.Tokens[:validTokens]
|
||||
|
||||
// Reset daily counter if day has passed
|
||||
if now.Sub(stats.DayStart) >= 24*time.Hour {
|
||||
stats.DayStart = now
|
||||
stats.DayCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
// CanSend checks if a request can be sent without violating limits.
|
||||
func (rl *RateLimiter) CanSend(model string, estimatedTokens int) bool {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
quota, ok := rl.Quotas[model]
|
||||
if !ok {
|
||||
return true // Unknown models are allowed
|
||||
}
|
||||
|
||||
// Unlimited check
|
||||
if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Ensure state exists
|
||||
if _, ok := rl.State[model]; !ok {
|
||||
rl.State[model] = &UsageStats{
|
||||
DayStart: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
rl.prune(model)
|
||||
stats := rl.State[model]
|
||||
|
||||
// Check RPD
|
||||
if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check RPM
|
||||
if quota.MaxRPM > 0 && len(stats.Requests) >= quota.MaxRPM {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check TPM
|
||||
if quota.MaxTPM > 0 {
|
||||
currentTokens := 0
|
||||
for _, t := range stats.Tokens {
|
||||
currentTokens += t.Count
|
||||
}
|
||||
if currentTokens+estimatedTokens > quota.MaxTPM {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// RecordUsage records a successful API call.
|
||||
func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if _, ok := rl.State[model]; !ok {
|
||||
rl.State[model] = &UsageStats{
|
||||
DayStart: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
stats := rl.State[model]
|
||||
now := time.Now()
|
||||
|
||||
stats.Requests = append(stats.Requests, now)
|
||||
stats.Tokens = append(stats.Tokens, TokenEntry{Time: now, Count: promptTokens + outputTokens})
|
||||
stats.DayCount++
|
||||
}
|
||||
|
||||
// WaitForCapacity blocks until capacity is available or context is cancelled.
|
||||
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
if rl.CanSend(model, tokens) {
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
// check again
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears stats for a model (or all if model is empty).
|
||||
func (rl *RateLimiter) Reset(model string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if model == "" {
|
||||
rl.State = make(map[string]*UsageStats)
|
||||
} else {
|
||||
delete(rl.State, model)
|
||||
}
|
||||
}
|
||||
|
||||
// ModelStats represents a point-in-time snapshot of usage for one
// model alongside its configured limits. The Max* fields are zero
// when no quota is configured or that limit is unlimited.
type ModelStats struct {
	RPM      int       // Requests in the current 1-minute window.
	MaxRPM   int       // Configured requests-per-minute limit.
	TPM      int       // Tokens in the current 1-minute window.
	MaxTPM   int       // Configured tokens-per-minute limit.
	RPD      int       // Requests recorded in the current 24h period.
	MaxRPD   int       // Configured requests-per-day limit.
	DayStart time.Time // Start of the current 24h period.
}
|
||||
|
||||
// Stats returns current stats for a model.
|
||||
func (rl *RateLimiter) Stats(model string) ModelStats {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
rl.prune(model)
|
||||
|
||||
stats := ModelStats{}
|
||||
quota, ok := rl.Quotas[model]
|
||||
if ok {
|
||||
stats.MaxRPM = quota.MaxRPM
|
||||
stats.MaxTPM = quota.MaxTPM
|
||||
stats.MaxRPD = quota.MaxRPD
|
||||
}
|
||||
|
||||
if s, ok := rl.State[model]; ok {
|
||||
stats.RPM = len(s.Requests)
|
||||
stats.RPD = s.DayCount
|
||||
stats.DayStart = s.DayStart
|
||||
for _, t := range s.Tokens {
|
||||
stats.TPM += t.Count
|
||||
}
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// AllStats returns stats for all tracked models.
|
||||
func (rl *RateLimiter) AllStats() map[string]ModelStats {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
result := make(map[string]ModelStats)
|
||||
|
||||
// Collect all model names
|
||||
for m := range rl.Quotas {
|
||||
result[m] = ModelStats{}
|
||||
}
|
||||
for m := range rl.State {
|
||||
result[m] = ModelStats{}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
window := now.Add(-1 * time.Minute)
|
||||
|
||||
for m := range result {
|
||||
// Prune inline
|
||||
if s, ok := rl.State[m]; ok {
|
||||
validReqs := 0
|
||||
for _, t := range s.Requests {
|
||||
if t.After(window) {
|
||||
s.Requests[validReqs] = t
|
||||
validReqs++
|
||||
}
|
||||
}
|
||||
s.Requests = s.Requests[:validReqs]
|
||||
|
||||
validTokens := 0
|
||||
for _, t := range s.Tokens {
|
||||
if t.Time.After(window) {
|
||||
s.Tokens[validTokens] = t
|
||||
validTokens++
|
||||
}
|
||||
}
|
||||
s.Tokens = s.Tokens[:validTokens]
|
||||
|
||||
if now.Sub(s.DayStart) >= 24*time.Hour {
|
||||
s.DayStart = now
|
||||
s.DayCount = 0
|
||||
}
|
||||
}
|
||||
|
||||
ms := ModelStats{}
|
||||
if q, ok := rl.Quotas[m]; ok {
|
||||
ms.MaxRPM = q.MaxRPM
|
||||
ms.MaxTPM = q.MaxTPM
|
||||
ms.MaxRPD = q.MaxRPD
|
||||
}
|
||||
if s, ok := rl.State[m]; ok {
|
||||
ms.RPM = len(s.Requests)
|
||||
ms.RPD = s.DayCount
|
||||
ms.DayStart = s.DayStart
|
||||
for _, t := range s.Tokens {
|
||||
ms.TPM += t.Count
|
||||
}
|
||||
}
|
||||
result[m] = ms
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// CountTokens calls the Google Generative Language API to count the
// tokens in text for the given model. It returns the total token
// count, or an error that includes the response body on a non-200
// status.
//
// A bounded client timeout is applied so a stalled connection cannot
// block the caller indefinitely (http.DefaultClient has no timeout).
func CountTokens(apiKey, model, text string) (int, error) {
	url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:countTokens", model)

	reqBody := map[string]any{
		"contents": []any{
			map[string]any{
				"parts": []any{
					map[string]string{"text": text},
				},
			},
		},
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return 0, err
	}

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))
	if err != nil {
		return 0, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", apiKey)

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: include the body in the error for diagnosis.
		body, _ := io.ReadAll(resp.Body)
		return 0, fmt.Errorf("API error %d: %s", resp.StatusCode, string(body))
	}

	var result struct {
		TotalTokens int `json:"totalTokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return 0, err
	}

	return result.TotalTokens, nil
}
|
||||
176
ratelimit_test.go
Normal file
176
ratelimit_test.go
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCanSend_Good(t *testing.T) {
|
||||
rl, _ := New()
|
||||
rl.filePath = filepath.Join(t.TempDir(), "ratelimits.yaml")
|
||||
|
||||
model := "test-model"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000, MaxRPD: 100}
|
||||
|
||||
if !rl.CanSend(model, 100) {
|
||||
t.Errorf("Expected CanSend to return true for fresh state")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanSend_RPMExceeded_Bad(t *testing.T) {
|
||||
rl, _ := New()
|
||||
model := "test-rpm"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 2, MaxTPM: 1000000, MaxRPD: 100}
|
||||
|
||||
rl.RecordUsage(model, 10, 10)
|
||||
rl.RecordUsage(model, 10, 10)
|
||||
|
||||
if rl.CanSend(model, 10) {
|
||||
t.Errorf("Expected CanSend to return false after exceeding RPM")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanSend_TPMExceeded_Bad(t *testing.T) {
|
||||
rl, _ := New()
|
||||
model := "test-tpm"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 100, MaxRPD: 100}
|
||||
|
||||
rl.RecordUsage(model, 50, 40) // 90 tokens used
|
||||
|
||||
if rl.CanSend(model, 20) { // 90 + 20 = 110 > 100
|
||||
t.Errorf("Expected CanSend to return false when estimated tokens exceed TPM")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanSend_RPDExceeded_Bad(t *testing.T) {
|
||||
rl, _ := New()
|
||||
model := "test-rpd"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 10, MaxTPM: 1000000, MaxRPD: 2}
|
||||
|
||||
rl.RecordUsage(model, 10, 10)
|
||||
rl.RecordUsage(model, 10, 10)
|
||||
|
||||
if rl.CanSend(model, 10) {
|
||||
t.Errorf("Expected CanSend to return false after exceeding RPD")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanSend_UnlimitedModel_Good(t *testing.T) {
|
||||
rl, _ := New()
|
||||
model := "test-unlimited"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 0, MaxTPM: 0, MaxRPD: 0}
|
||||
|
||||
// Should always be allowed
|
||||
for i := 0; i < 1000; i++ {
|
||||
rl.RecordUsage(model, 100, 100)
|
||||
}
|
||||
if !rl.CanSend(model, 999999) {
|
||||
t.Errorf("Expected unlimited model to always allow sends")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordUsage_PrunesOldEntries_Good(t *testing.T) {
|
||||
rl, _ := New()
|
||||
model := "test-prune"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 1000000, MaxRPD: 100}
|
||||
|
||||
// Manually inject old data
|
||||
oldTime := time.Now().Add(-2 * time.Minute)
|
||||
rl.State[model] = &UsageStats{
|
||||
Requests: []time.Time{oldTime, oldTime, oldTime},
|
||||
Tokens: []TokenEntry{
|
||||
{Time: oldTime, Count: 100},
|
||||
{Time: oldTime, Count: 100},
|
||||
},
|
||||
DayStart: time.Now(),
|
||||
}
|
||||
|
||||
// CanSend triggers prune
|
||||
if !rl.CanSend(model, 10) {
|
||||
t.Errorf("Expected CanSend to return true after pruning old entries")
|
||||
}
|
||||
|
||||
stats := rl.State[model]
|
||||
if len(stats.Requests) != 0 {
|
||||
t.Errorf("Expected 0 requests after pruning old entries, got %d", len(stats.Requests))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistAndLoad_Good(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "ratelimits.yaml")
|
||||
|
||||
rl1, _ := New()
|
||||
rl1.filePath = path
|
||||
model := "persist-test"
|
||||
rl1.Quotas[model] = ModelQuota{MaxRPM: 50, MaxTPM: 5000, MaxRPD: 500}
|
||||
rl1.RecordUsage(model, 100, 100)
|
||||
|
||||
if err := rl1.Persist(); err != nil {
|
||||
t.Fatalf("Persist failed: %v", err)
|
||||
}
|
||||
|
||||
rl2, _ := New()
|
||||
rl2.filePath = path
|
||||
if err := rl2.Load(); err != nil {
|
||||
t.Fatalf("Load failed: %v", err)
|
||||
}
|
||||
|
||||
stats := rl2.Stats(model)
|
||||
if stats.RPM != 1 {
|
||||
t.Errorf("Expected RPM 1 after load, got %d", stats.RPM)
|
||||
}
|
||||
if stats.TPM != 200 {
|
||||
t.Errorf("Expected TPM 200 after load, got %d", stats.TPM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitForCapacity_Ugly(t *testing.T) {
|
||||
rl, _ := New()
|
||||
model := "wait-test"
|
||||
rl.Quotas[model] = ModelQuota{MaxRPM: 1, MaxTPM: 1000000, MaxRPD: 100}
|
||||
|
||||
rl.RecordUsage(model, 10, 10) // Use up the 1 RPM
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := rl.WaitForCapacity(ctx, model, 10)
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("Expected DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultQuotas_Good(t *testing.T) {
|
||||
rl, _ := New()
|
||||
expected := []string{
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-2.0-flash",
|
||||
}
|
||||
for _, m := range expected {
|
||||
if _, ok := rl.Quotas[m]; !ok {
|
||||
t.Errorf("Expected default quota for %s", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllStats_Good(t *testing.T) {
|
||||
rl, _ := New()
|
||||
rl.RecordUsage("gemini-3-pro-preview", 1000, 500)
|
||||
|
||||
all := rl.AllStats()
|
||||
if len(all) < 5 {
|
||||
t.Errorf("Expected at least 5 models in AllStats, got %d", len(all))
|
||||
}
|
||||
|
||||
pro := all["gemini-3-pro-preview"]
|
||||
if pro.RPM != 1 {
|
||||
t.Errorf("Expected RPM 1 for pro, got %d", pro.RPM)
|
||||
}
|
||||
if pro.TPM != 1500 {
|
||||
t.Errorf("Expected TPM 1500 for pro, got %d", pro.TPM)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue