270 lines
7.8 KiB
Go
270 lines
7.8 KiB
Go
|
|
package ratelimit
|
||
|
|
|
||
|
|
import (
|
||
|
|
"database/sql"
|
||
|
|
"fmt"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
_ "modernc.org/sqlite"
|
||
|
|
)
|
||
|
|
|
||
|
|
// sqliteStore persists rate-limiter quotas and usage history in a SQLite
// database. It is unexported: code in this package obtains one from
// newSQLiteStore and releases it with close.
type sqliteStore struct {
	db *sql.DB // database handle; newSQLiteStore caps it at one open connection
}
|
||
|
|
|
||
|
|
// newSQLiteStore opens (or creates) a SQLite database at dbPath and initialises
|
||
|
|
// the schema. It follows the go-store pattern: single connection, WAL journal
|
||
|
|
// mode, and a 5-second busy timeout for contention handling.
|
||
|
|
func newSQLiteStore(dbPath string) (*sqliteStore, error) {
|
||
|
|
db, err := sql.Open("sqlite", dbPath)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.newSQLiteStore: open: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Single connection for PRAGMA consistency.
|
||
|
|
db.SetMaxOpenConns(1)
|
||
|
|
|
||
|
|
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
|
||
|
|
db.Close()
|
||
|
|
return nil, fmt.Errorf("ratelimit.newSQLiteStore: WAL: %w", err)
|
||
|
|
}
|
||
|
|
if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil {
|
||
|
|
db.Close()
|
||
|
|
return nil, fmt.Errorf("ratelimit.newSQLiteStore: busy_timeout: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := createSchema(db); err != nil {
|
||
|
|
db.Close()
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
return &sqliteStore{db: db}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// createSchema makes sure every table and index the store relies on exists.
// All statements use IF NOT EXISTS, so running it against an already
// initialised database leaves it unchanged. Statements run one by one, in
// order, without a wrapping transaction. The first failure stops the run and
// is returned with a "ratelimit.createSchema" prefix.
//
// Layout:
//   - quotas:   one row per model (primary key) with max_rpm, max_tpm, max_rpd.
//   - requests: (model, ts) rows, indexed on (model, ts).
//   - tokens:   (model, ts, count) rows, indexed on (model, ts).
//   - daily:    one row per model (primary key) with day_start and day_count.
func createSchema(db *sql.DB) error {
	ddl := [...]string{
		`CREATE TABLE IF NOT EXISTS quotas (
			model TEXT PRIMARY KEY,
			max_rpm INTEGER NOT NULL DEFAULT 0,
			max_tpm INTEGER NOT NULL DEFAULT 0,
			max_rpd INTEGER NOT NULL DEFAULT 0
		)`,
		`CREATE TABLE IF NOT EXISTS requests (
			model TEXT NOT NULL,
			ts INTEGER NOT NULL
		)`,
		`CREATE TABLE IF NOT EXISTS tokens (
			model TEXT NOT NULL,
			ts INTEGER NOT NULL,
			count INTEGER NOT NULL
		)`,
		`CREATE TABLE IF NOT EXISTS daily (
			model TEXT PRIMARY KEY,
			day_start INTEGER NOT NULL,
			day_count INTEGER NOT NULL DEFAULT 0
		)`,
		`CREATE INDEX IF NOT EXISTS idx_requests_model_ts ON requests(model, ts)`,
		`CREATE INDEX IF NOT EXISTS idx_tokens_model_ts ON tokens(model, ts)`,
	}

	for i := range ddl {
		if _, err := db.Exec(ddl[i]); err != nil {
			return fmt.Errorf("ratelimit.createSchema: %w", err)
		}
	}
	return nil
}
|
||
|
|
|
||
|
|
// saveQuotas upserts all quotas into the quotas table.
|
||
|
|
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
|
||
|
|
tx, err := s.db.Begin()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveQuotas: begin: %w", err)
|
||
|
|
}
|
||
|
|
defer tx.Rollback()
|
||
|
|
|
||
|
|
stmt, err := tx.Prepare(`INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd)
|
||
|
|
VALUES (?, ?, ?, ?)
|
||
|
|
ON CONFLICT(model) DO UPDATE SET
|
||
|
|
max_rpm = excluded.max_rpm,
|
||
|
|
max_tpm = excluded.max_tpm,
|
||
|
|
max_rpd = excluded.max_rpd`)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveQuotas: prepare: %w", err)
|
||
|
|
}
|
||
|
|
defer stmt.Close()
|
||
|
|
|
||
|
|
for model, q := range quotas {
|
||
|
|
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveQuotas: exec %s: %w", model, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return tx.Commit()
|
||
|
|
}
|
||
|
|
|
||
|
|
// loadQuotas reads all rows from the quotas table.
|
||
|
|
func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
|
||
|
|
rows, err := s.db.Query("SELECT model, max_rpm, max_tpm, max_rpd FROM quotas")
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadQuotas: query: %w", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
result := make(map[string]ModelQuota)
|
||
|
|
for rows.Next() {
|
||
|
|
var model string
|
||
|
|
var q ModelQuota
|
||
|
|
if err := rows.Scan(&model, &q.MaxRPM, &q.MaxTPM, &q.MaxRPD); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadQuotas: scan: %w", err)
|
||
|
|
}
|
||
|
|
result[model] = q
|
||
|
|
}
|
||
|
|
if err := rows.Err(); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadQuotas: rows: %w", err)
|
||
|
|
}
|
||
|
|
return result, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// saveState writes all usage state to SQLite in a single transaction.
|
||
|
|
// It deletes existing rows and inserts fresh data for each model.
|
||
|
|
func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
|
||
|
|
tx, err := s.db.Begin()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: begin: %w", err)
|
||
|
|
}
|
||
|
|
defer tx.Rollback()
|
||
|
|
|
||
|
|
// Clear existing state.
|
||
|
|
if _, err := tx.Exec("DELETE FROM requests"); err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: clear requests: %w", err)
|
||
|
|
}
|
||
|
|
if _, err := tx.Exec("DELETE FROM tokens"); err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: clear tokens: %w", err)
|
||
|
|
}
|
||
|
|
if _, err := tx.Exec("DELETE FROM daily"); err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: clear daily: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: prepare requests: %w", err)
|
||
|
|
}
|
||
|
|
defer reqStmt.Close()
|
||
|
|
|
||
|
|
tokStmt, err := tx.Prepare("INSERT INTO tokens (model, ts, count) VALUES (?, ?, ?)")
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: prepare tokens: %w", err)
|
||
|
|
}
|
||
|
|
defer tokStmt.Close()
|
||
|
|
|
||
|
|
dayStmt, err := tx.Prepare("INSERT INTO daily (model, day_start, day_count) VALUES (?, ?, ?)")
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: prepare daily: %w", err)
|
||
|
|
}
|
||
|
|
defer dayStmt.Close()
|
||
|
|
|
||
|
|
for model, stats := range state {
|
||
|
|
for _, t := range stats.Requests {
|
||
|
|
if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: insert request %s: %w", model, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
for _, te := range stats.Tokens {
|
||
|
|
if _, err := tokStmt.Exec(model, te.Time.UnixNano(), te.Count); err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: insert token %s: %w", model, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if _, err := dayStmt.Exec(model, stats.DayStart.UnixNano(), stats.DayCount); err != nil {
|
||
|
|
return fmt.Errorf("ratelimit.saveState: insert daily %s: %w", model, err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return tx.Commit()
|
||
|
|
}
|
||
|
|
|
||
|
|
// loadState reconstructs the UsageStats map from SQLite tables.
|
||
|
|
func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
|
||
|
|
result := make(map[string]*UsageStats)
|
||
|
|
|
||
|
|
// Load daily counters first (these define which models have state).
|
||
|
|
rows, err := s.db.Query("SELECT model, day_start, day_count FROM daily")
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: query daily: %w", err)
|
||
|
|
}
|
||
|
|
defer rows.Close()
|
||
|
|
|
||
|
|
for rows.Next() {
|
||
|
|
var model string
|
||
|
|
var dayStartNano int64
|
||
|
|
var dayCount int
|
||
|
|
if err := rows.Scan(&model, &dayStartNano, &dayCount); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: scan daily: %w", err)
|
||
|
|
}
|
||
|
|
result[model] = &UsageStats{
|
||
|
|
DayStart: time.Unix(0, dayStartNano),
|
||
|
|
DayCount: dayCount,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if err := rows.Err(); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: daily rows: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Load requests.
|
||
|
|
reqRows, err := s.db.Query("SELECT model, ts FROM requests ORDER BY ts")
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: query requests: %w", err)
|
||
|
|
}
|
||
|
|
defer reqRows.Close()
|
||
|
|
|
||
|
|
for reqRows.Next() {
|
||
|
|
var model string
|
||
|
|
var tsNano int64
|
||
|
|
if err := reqRows.Scan(&model, &tsNano); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: scan requests: %w", err)
|
||
|
|
}
|
||
|
|
if _, ok := result[model]; !ok {
|
||
|
|
result[model] = &UsageStats{}
|
||
|
|
}
|
||
|
|
result[model].Requests = append(result[model].Requests, time.Unix(0, tsNano))
|
||
|
|
}
|
||
|
|
if err := reqRows.Err(); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: request rows: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Load tokens.
|
||
|
|
tokRows, err := s.db.Query("SELECT model, ts, count FROM tokens ORDER BY ts")
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: query tokens: %w", err)
|
||
|
|
}
|
||
|
|
defer tokRows.Close()
|
||
|
|
|
||
|
|
for tokRows.Next() {
|
||
|
|
var model string
|
||
|
|
var tsNano int64
|
||
|
|
var count int
|
||
|
|
if err := tokRows.Scan(&model, &tsNano, &count); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: scan tokens: %w", err)
|
||
|
|
}
|
||
|
|
if _, ok := result[model]; !ok {
|
||
|
|
result[model] = &UsageStats{}
|
||
|
|
}
|
||
|
|
result[model].Tokens = append(result[model].Tokens, TokenEntry{
|
||
|
|
Time: time.Unix(0, tsNano),
|
||
|
|
Count: count,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
if err := tokRows.Err(); err != nil {
|
||
|
|
return nil, fmt.Errorf("ratelimit.loadState: token rows: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return result, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// close closes the underlying database connection.
|
||
|
|
func (s *sqliteStore) close() error {
|
||
|
|
return s.db.Close()
|
||
|
|
}
|