refactor(ratelimit): upgrade to core v0.8.0-alpha.1

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-03-26 15:41:11 +00:00
parent bd6c6e5136
commit 36cc0a4750
4 changed files with 113 additions and 82 deletions

8
go.mod
View file

@ -3,22 +3,20 @@ module forge.lthn.ai/core/go-ratelimit
go 1.26.0 go 1.26.0
require ( require (
dappco.re/go/core/io v0.2.0 dappco.re/go/core v0.8.0-alpha.1
dappco.re/go/core/log v0.1.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v1.47.0 modernc.org/sqlite v1.47.0
) )
require ( require (
forge.lthn.ai/core/go-log v0.0.4 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
golang.org/x/tools v0.43.0 // indirect
modernc.org/libc v1.70.0 // indirect modernc.org/libc v1.70.0 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect modernc.org/memory v1.11.0 // indirect

9
go.sum
View file

@ -1,9 +1,6 @@
dappco.re/go/core/io v0.2.0 h1:zuudgIiTsQQ5ipVt97saWdGLROovbEB/zdVyy9/l+I4= dappco.re/go/core v0.8.0-alpha.1 h1:gj7+Scv+L63Z7wMxbJYHhaRFkHJo2u4MMPuUSv/Dhtk=
dappco.re/go/core/io v0.2.0/go.mod h1:1QnQV6X9LNgFKfm8SkOtR9LLaj3bDcsOIeJOOyjbL5E= dappco.re/go/core v0.8.0-alpha.1/go.mod h1:f2/tBZ3+3IqDrg2F5F598llv0nmb/4gJVCFzM5geE4A=
dappco.re/go/core/log v0.1.0 h1:pa71Vq2TD2aoEUQWFKwNcaJ3GBY8HbaNGqtE688Unyc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
dappco.re/go/core/log v0.1.0/go.mod h1:Nkqb8gsXhZAO8VLpx7B8i1iAmohhzqA20b9Zr8VUcJs=
forge.lthn.ai/core/go-log v0.0.4 h1:KTuCEPgFmuM8KJfnyQ8vPOU1Jg654W74h8IJvfQMfv0=
forge.lthn.ai/core/go-log v0.0.4/go.mod h1:r14MXKOD3LF/sI8XUJQhRk/SZHBE7jAFVuCfgkXoZPw=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=

View file

@ -1,24 +1,18 @@
package ratelimit package ratelimit
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt"
"io" "io"
"iter" "iter"
"maps" "maps"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path/filepath"
"slices" "slices"
"strings"
"sync" "sync"
"time" "time"
coreio "dappco.re/go/core/io" core "dappco.re/go/core"
coreerr "dappco.re/go/core/log"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -170,8 +164,8 @@ func NewWithConfig(cfg Config) (*RateLimiter, error) {
if backend == backendSQLite { if backend == backendSQLite {
if cfg.FilePath == "" { if cfg.FilePath == "" {
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { if err := ensureDir(core.PathDir(filePath)); err != nil {
return nil, coreerr.E("ratelimit.NewWithConfig", "mkdir", err) return nil, core.E("ratelimit.NewWithConfig", "mkdir", err)
} }
} }
return NewWithSQLiteConfig(filePath, cfg) return NewWithSQLiteConfig(filePath, cfg)
@ -210,7 +204,7 @@ func (rl *RateLimiter) Load() error {
return rl.loadSQLite() return rl.loadSQLite()
} }
content, err := coreio.Local.Read(rl.filePath) content, err := readLocalFile(rl.filePath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
return nil return nil
} }
@ -265,7 +259,7 @@ func (rl *RateLimiter) Persist() error {
if sqlite != nil { if sqlite != nil {
if err := sqlite.saveSnapshot(quotas, state); err != nil { if err := sqlite.saveSnapshot(quotas, state); err != nil {
return coreerr.E("ratelimit.Persist", "sqlite snapshot", err) return core.E("ratelimit.Persist", "sqlite snapshot", err)
} }
return nil return nil
} }
@ -280,11 +274,11 @@ func (rl *RateLimiter) Persist() error {
State: state, State: state,
}) })
if err != nil { if err != nil {
return coreerr.E("ratelimit.Persist", "marshal", err) return core.E("ratelimit.Persist", "marshal", err)
} }
if err := coreio.Local.Write(filePath, string(data)); err != nil { if err := writeLocalFile(filePath, string(data)); err != nil {
return coreerr.E("ratelimit.Persist", "write", err) return core.E("ratelimit.Persist", "write", err)
} }
return nil return nil
} }
@ -428,7 +422,7 @@ func (rl *RateLimiter) RecordUsage(model string, promptTokens, outputTokens int)
// WaitForCapacity blocks until capacity is available or context is cancelled. // WaitForCapacity blocks until capacity is available or context is cancelled.
func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error { func (rl *RateLimiter) WaitForCapacity(ctx context.Context, model string, tokens int) error {
if tokens < 0 { if tokens < 0 {
return coreerr.E("ratelimit.WaitForCapacity", "negative tokens", nil) return core.E("ratelimit.WaitForCapacity", "negative tokens", nil)
} }
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
@ -600,14 +594,14 @@ func (rl *RateLimiter) Close() error {
// database is created if it does not exist. // database is created if it does not exist.
func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error { func MigrateYAMLToSQLite(yamlPath, sqlitePath string) error {
// Load from YAML. // Load from YAML.
content, err := coreio.Local.Read(yamlPath) content, err := readLocalFile(yamlPath)
if err != nil { if err != nil {
return coreerr.E("ratelimit.MigrateYAMLToSQLite", "read", err) return core.E("ratelimit.MigrateYAMLToSQLite", "read", err)
} }
var rl RateLimiter var rl RateLimiter
if err := yaml.Unmarshal([]byte(content), &rl); err != nil { if err := yaml.Unmarshal([]byte(content), &rl); err != nil {
return coreerr.E("ratelimit.MigrateYAMLToSQLite", "unmarshal", err) return core.E("ratelimit.MigrateYAMLToSQLite", "unmarshal", err)
} }
// Write to SQLite. // Write to SQLite.
@ -631,7 +625,7 @@ func CountTokens(ctx context.Context, apiKey, model, text string) (int, error) {
func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) { func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, apiKey, model, text string) (int, error) {
requestURL, err := countTokensURL(baseURL, model) requestURL, err := countTokensURL(baseURL, model)
if err != nil { if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "build url", err) return 0, core.E("ratelimit.CountTokens", "build url", err)
} }
reqBody := map[string]any{ reqBody := map[string]any{
@ -644,14 +638,14 @@ func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, ap
}, },
} }
jsonBody, err := json.Marshal(reqBody) jsonBody := core.JSONMarshal(reqBody)
if err != nil { if !jsonBody.OK {
return 0, coreerr.E("ratelimit.CountTokens", "marshal request", err) return 0, core.E("ratelimit.CountTokens", "marshal request", resultError(jsonBody))
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(jsonBody)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, core.NewReader(string(jsonBody.Value.([]byte))))
if err != nil { if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "new request", err) return 0, core.E("ratelimit.CountTokens", "new request", err)
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("x-goog-api-key", apiKey) req.Header.Set("x-goog-api-key", apiKey)
@ -662,23 +656,29 @@ func countTokensWithClient(ctx context.Context, client *http.Client, baseURL, ap
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "do request", err) return 0, core.E("ratelimit.CountTokens", "do request", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit) body, err := readLimitedBody(resp.Body, countTokensErrorBodyLimit)
if err != nil { if err != nil {
return 0, coreerr.E("ratelimit.CountTokens", "read error body", err) return 0, core.E("ratelimit.CountTokens", "read error body", err)
} }
return 0, coreerr.E("ratelimit.CountTokens", fmt.Sprintf("api error status %d: %s", resp.StatusCode, body), nil) return 0, core.E("ratelimit.CountTokens", core.Sprintf("api error status %d: %s", resp.StatusCode, body), nil)
}
body, err := readLimitedBody(resp.Body, countTokensSuccessBodyLimit)
if err != nil {
return 0, core.E("ratelimit.CountTokens", "decode response", err)
} }
var result struct { var result struct {
TotalTokens int `json:"totalTokens"` TotalTokens int `json:"totalTokens"`
} }
if err := json.NewDecoder(io.LimitReader(resp.Body, countTokensSuccessBodyLimit)).Decode(&result); err != nil { decode := core.JSONUnmarshalString(body, &result)
return 0, coreerr.E("ratelimit.CountTokens", "decode response", err) if !decode.OK {
return 0, core.E("ratelimit.CountTokens", "decode response", resultError(decode))
} }
return result.TotalTokens, nil return result.TotalTokens, nil
@ -711,13 +711,13 @@ func applyConfig(rl *RateLimiter, cfg Config) {
} }
func normaliseBackend(backend string) (string, error) { func normaliseBackend(backend string) (string, error) {
switch strings.ToLower(strings.TrimSpace(backend)) { switch core.Lower(core.Trim(backend)) {
case "", backendYAML: case "", backendYAML:
return backendYAML, nil return backendYAML, nil
case backendSQLite: case backendSQLite:
return backendSQLite, nil return backendSQLite, nil
default: default:
return "", coreerr.E("ratelimit.NewWithConfig", fmt.Sprintf("unknown backend %q", backend), nil) return "", core.E("ratelimit.NewWithConfig", core.Sprintf("unknown backend %q", backend), nil)
} }
} }
@ -732,7 +732,7 @@ func defaultStatePath(backend string) (string, error) {
fileName = defaultSQLiteStateFile fileName = defaultSQLiteStateFile
} }
return filepath.Join(home, defaultStateDirName, fileName), nil return core.Path(home, defaultStateDirName, fileName), nil
} }
func safeTokenSum(a, b int) int { func safeTokenSum(a, b int) int {
@ -761,8 +761,8 @@ func safeTokenTotal(tokens []TokenEntry) int {
} }
func countTokensURL(baseURL, model string) (string, error) { func countTokensURL(baseURL, model string) (string, error) {
if strings.TrimSpace(model) == "" { if core.Trim(model) == "" {
return "", fmt.Errorf("empty model") return "", core.NewError("empty model")
} }
parsed, err := url.Parse(baseURL) parsed, err := url.Parse(baseURL)
@ -770,10 +770,10 @@ func countTokensURL(baseURL, model string) (string, error) {
return "", err return "", err
} }
if parsed.Scheme == "" || parsed.Host == "" { if parsed.Scheme == "" || parsed.Host == "" {
return "", fmt.Errorf("invalid base url") return "", core.NewError("invalid base url")
} }
return strings.TrimRight(parsed.String(), "/") + "/v1beta/models/" + url.PathEscape(model) + ":countTokens", nil return core.Concat(core.TrimSuffix(parsed.String(), "/"), "/v1beta/models/", url.PathEscape(model), ":countTokens"), nil
} }
func readLimitedBody(r io.Reader, limit int64) (string, error) { func readLimitedBody(r io.Reader, limit int64) (string, error) {
@ -793,3 +793,40 @@ func readLimitedBody(r io.Reader, limit int64) (string, error) {
} }
return result, nil return result, nil
} }
func readLocalFile(path string) (string, error) {
var fs core.Fs
result := fs.Read(path)
if !result.OK {
return "", resultError(result)
}
content, ok := result.Value.(string)
if !ok {
return "", core.NewError("read returned non-string")
}
return content, nil
}
func writeLocalFile(path, content string) error {
var fs core.Fs
return resultError(fs.Write(path, content))
}
func ensureDir(path string) error {
var fs core.Fs
return resultError(fs.EnsureDir(path))
}
func resultError(result core.Result) error {
if result.OK {
return nil
}
if err, ok := result.Value.(error); ok {
return err
}
if result.Value == nil {
return nil
}
return core.NewError(core.Sprint(result.Value))
}

View file

@ -2,10 +2,9 @@ package ratelimit
import ( import (
"database/sql" "database/sql"
"fmt"
"time" "time"
coreerr "dappco.re/go/core/log" core "dappco.re/go/core"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
@ -20,7 +19,7 @@ type sqliteStore struct {
func newSQLiteStore(dbPath string) (*sqliteStore, error) { func newSQLiteStore(dbPath string) (*sqliteStore, error) {
db, err := sql.Open("sqlite", dbPath) db, err := sql.Open("sqlite", dbPath)
if err != nil { if err != nil {
return nil, coreerr.E("ratelimit.newSQLiteStore", "open", err) return nil, core.E("ratelimit.newSQLiteStore", "open", err)
} }
// Single connection for PRAGMA consistency. // Single connection for PRAGMA consistency.
@ -28,11 +27,11 @@ func newSQLiteStore(dbPath string) (*sqliteStore, error) {
if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
db.Close() db.Close()
return nil, coreerr.E("ratelimit.newSQLiteStore", "WAL", err) return nil, core.E("ratelimit.newSQLiteStore", "WAL", err)
} }
if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil { if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil {
db.Close() db.Close()
return nil, coreerr.E("ratelimit.newSQLiteStore", "busy_timeout", err) return nil, core.E("ratelimit.newSQLiteStore", "busy_timeout", err)
} }
if err := createSchema(db); err != nil { if err := createSchema(db); err != nil {
@ -72,7 +71,7 @@ func createSchema(db *sql.DB) error {
for _, stmt := range stmts { for _, stmt := range stmts {
if _, err := db.Exec(stmt); err != nil { if _, err := db.Exec(stmt); err != nil {
return coreerr.E("ratelimit.createSchema", "exec", err) return core.E("ratelimit.createSchema", "exec", err)
} }
} }
return nil return nil
@ -82,12 +81,12 @@ func createSchema(db *sql.DB) error {
func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error { func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return coreerr.E("ratelimit.saveQuotas", "begin", err) return core.E("ratelimit.saveQuotas", "begin", err)
} }
defer tx.Rollback() defer tx.Rollback()
if _, err := tx.Exec("DELETE FROM quotas"); err != nil { if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
return coreerr.E("ratelimit.saveQuotas", "clear", err) return core.E("ratelimit.saveQuotas", "clear", err)
} }
if err := insertQuotas(tx, quotas); err != nil { if err := insertQuotas(tx, quotas); err != nil {
@ -101,7 +100,7 @@ func (s *sqliteStore) saveQuotas(quotas map[string]ModelQuota) error {
func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) { func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
rows, err := s.db.Query("SELECT model, max_rpm, max_tpm, max_rpd FROM quotas") rows, err := s.db.Query("SELECT model, max_rpm, max_tpm, max_rpd FROM quotas")
if err != nil { if err != nil {
return nil, coreerr.E("ratelimit.loadQuotas", "query", err) return nil, core.E("ratelimit.loadQuotas", "query", err)
} }
defer rows.Close() defer rows.Close()
@ -110,12 +109,12 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
var model string var model string
var q ModelQuota var q ModelQuota
if err := rows.Scan(&model, &q.MaxRPM, &q.MaxTPM, &q.MaxRPD); err != nil { if err := rows.Scan(&model, &q.MaxRPM, &q.MaxTPM, &q.MaxRPD); err != nil {
return nil, coreerr.E("ratelimit.loadQuotas", "scan", err) return nil, core.E("ratelimit.loadQuotas", "scan", err)
} }
result[model] = q result[model] = q
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, coreerr.E("ratelimit.loadQuotas", "rows", err) return nil, core.E("ratelimit.loadQuotas", "rows", err)
} }
return result, nil return result, nil
} }
@ -124,7 +123,7 @@ func (s *sqliteStore) loadQuotas() (map[string]ModelQuota, error) {
func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error { func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[string]*UsageStats) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return coreerr.E("ratelimit.saveSnapshot", "begin", err) return core.E("ratelimit.saveSnapshot", "begin", err)
} }
defer tx.Rollback() defer tx.Rollback()
@ -148,7 +147,7 @@ func (s *sqliteStore) saveSnapshot(quotas map[string]ModelQuota, state map[strin
func (s *sqliteStore) saveState(state map[string]*UsageStats) error { func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return coreerr.E("ratelimit.saveState", "begin", err) return core.E("ratelimit.saveState", "begin", err)
} }
defer tx.Rollback() defer tx.Rollback()
@ -166,17 +165,17 @@ func (s *sqliteStore) saveState(state map[string]*UsageStats) error {
func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error { func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
if includeQuotas { if includeQuotas {
if _, err := tx.Exec("DELETE FROM quotas"); err != nil { if _, err := tx.Exec("DELETE FROM quotas"); err != nil {
return coreerr.E("ratelimit.saveSnapshot", "clear quotas", err) return core.E("ratelimit.saveSnapshot", "clear quotas", err)
} }
} }
if _, err := tx.Exec("DELETE FROM requests"); err != nil { if _, err := tx.Exec("DELETE FROM requests"); err != nil {
return coreerr.E("ratelimit.saveState", "clear requests", err) return core.E("ratelimit.saveState", "clear requests", err)
} }
if _, err := tx.Exec("DELETE FROM tokens"); err != nil { if _, err := tx.Exec("DELETE FROM tokens"); err != nil {
return coreerr.E("ratelimit.saveState", "clear tokens", err) return core.E("ratelimit.saveState", "clear tokens", err)
} }
if _, err := tx.Exec("DELETE FROM daily"); err != nil { if _, err := tx.Exec("DELETE FROM daily"); err != nil {
return coreerr.E("ratelimit.saveState", "clear daily", err) return core.E("ratelimit.saveState", "clear daily", err)
} }
return nil return nil
} }
@ -184,13 +183,13 @@ func clearSnapshotTables(tx *sql.Tx, includeQuotas bool) error {
func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error { func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)") stmt, err := tx.Prepare("INSERT INTO quotas (model, max_rpm, max_tpm, max_rpd) VALUES (?, ?, ?, ?)")
if err != nil { if err != nil {
return coreerr.E("ratelimit.saveQuotas", "prepare", err) return core.E("ratelimit.saveQuotas", "prepare", err)
} }
defer stmt.Close() defer stmt.Close()
for model, q := range quotas { for model, q := range quotas {
if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil { if _, err := stmt.Exec(model, q.MaxRPM, q.MaxTPM, q.MaxRPD); err != nil {
return coreerr.E("ratelimit.saveQuotas", fmt.Sprintf("exec %s", model), err) return core.E("ratelimit.saveQuotas", core.Concat("exec ", model), err)
} }
} }
return nil return nil
@ -199,19 +198,19 @@ func insertQuotas(tx *sql.Tx, quotas map[string]ModelQuota) error {
func insertState(tx *sql.Tx, state map[string]*UsageStats) error { func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)") reqStmt, err := tx.Prepare("INSERT INTO requests (model, ts) VALUES (?, ?)")
if err != nil { if err != nil {
return coreerr.E("ratelimit.saveState", "prepare requests", err) return core.E("ratelimit.saveState", "prepare requests", err)
} }
defer reqStmt.Close() defer reqStmt.Close()
tokStmt, err := tx.Prepare("INSERT INTO tokens (model, ts, count) VALUES (?, ?, ?)") tokStmt, err := tx.Prepare("INSERT INTO tokens (model, ts, count) VALUES (?, ?, ?)")
if err != nil { if err != nil {
return coreerr.E("ratelimit.saveState", "prepare tokens", err) return core.E("ratelimit.saveState", "prepare tokens", err)
} }
defer tokStmt.Close() defer tokStmt.Close()
dayStmt, err := tx.Prepare("INSERT INTO daily (model, day_start, day_count) VALUES (?, ?, ?)") dayStmt, err := tx.Prepare("INSERT INTO daily (model, day_start, day_count) VALUES (?, ?, ?)")
if err != nil { if err != nil {
return coreerr.E("ratelimit.saveState", "prepare daily", err) return core.E("ratelimit.saveState", "prepare daily", err)
} }
defer dayStmt.Close() defer dayStmt.Close()
@ -221,16 +220,16 @@ func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
} }
for _, t := range stats.Requests { for _, t := range stats.Requests {
if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil { if _, err := reqStmt.Exec(model, t.UnixNano()); err != nil {
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert request %s", model), err) return core.E("ratelimit.saveState", core.Concat("insert request ", model), err)
} }
} }
for _, te := range stats.Tokens { for _, te := range stats.Tokens {
if _, err := tokStmt.Exec(model, te.Time.UnixNano(), te.Count); err != nil { if _, err := tokStmt.Exec(model, te.Time.UnixNano(), te.Count); err != nil {
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert token %s", model), err) return core.E("ratelimit.saveState", core.Concat("insert token ", model), err)
} }
} }
if _, err := dayStmt.Exec(model, stats.DayStart.UnixNano(), stats.DayCount); err != nil { if _, err := dayStmt.Exec(model, stats.DayStart.UnixNano(), stats.DayCount); err != nil {
return coreerr.E("ratelimit.saveState", fmt.Sprintf("insert daily %s", model), err) return core.E("ratelimit.saveState", core.Concat("insert daily ", model), err)
} }
} }
return nil return nil
@ -238,7 +237,7 @@ func insertState(tx *sql.Tx, state map[string]*UsageStats) error {
func commitTx(tx *sql.Tx, scope string) error { func commitTx(tx *sql.Tx, scope string) error {
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return coreerr.E(scope, "commit", err) return core.E(scope, "commit", err)
} }
return nil return nil
} }
@ -250,7 +249,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
// Load daily counters first (these define which models have state). // Load daily counters first (these define which models have state).
rows, err := s.db.Query("SELECT model, day_start, day_count FROM daily") rows, err := s.db.Query("SELECT model, day_start, day_count FROM daily")
if err != nil { if err != nil {
return nil, coreerr.E("ratelimit.loadState", "query daily", err) return nil, core.E("ratelimit.loadState", "query daily", err)
} }
defer rows.Close() defer rows.Close()
@ -259,7 +258,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
var dayStartNano int64 var dayStartNano int64
var dayCount int var dayCount int
if err := rows.Scan(&model, &dayStartNano, &dayCount); err != nil { if err := rows.Scan(&model, &dayStartNano, &dayCount); err != nil {
return nil, coreerr.E("ratelimit.loadState", "scan daily", err) return nil, core.E("ratelimit.loadState", "scan daily", err)
} }
result[model] = &UsageStats{ result[model] = &UsageStats{
DayStart: time.Unix(0, dayStartNano), DayStart: time.Unix(0, dayStartNano),
@ -267,13 +266,13 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
} }
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, coreerr.E("ratelimit.loadState", "daily rows", err) return nil, core.E("ratelimit.loadState", "daily rows", err)
} }
// Load requests. // Load requests.
reqRows, err := s.db.Query("SELECT model, ts FROM requests ORDER BY ts") reqRows, err := s.db.Query("SELECT model, ts FROM requests ORDER BY ts")
if err != nil { if err != nil {
return nil, coreerr.E("ratelimit.loadState", "query requests", err) return nil, core.E("ratelimit.loadState", "query requests", err)
} }
defer reqRows.Close() defer reqRows.Close()
@ -281,7 +280,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
var model string var model string
var tsNano int64 var tsNano int64
if err := reqRows.Scan(&model, &tsNano); err != nil { if err := reqRows.Scan(&model, &tsNano); err != nil {
return nil, coreerr.E("ratelimit.loadState", "scan requests", err) return nil, core.E("ratelimit.loadState", "scan requests", err)
} }
if _, ok := result[model]; !ok { if _, ok := result[model]; !ok {
result[model] = &UsageStats{} result[model] = &UsageStats{}
@ -289,13 +288,13 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
result[model].Requests = append(result[model].Requests, time.Unix(0, tsNano)) result[model].Requests = append(result[model].Requests, time.Unix(0, tsNano))
} }
if err := reqRows.Err(); err != nil { if err := reqRows.Err(); err != nil {
return nil, coreerr.E("ratelimit.loadState", "request rows", err) return nil, core.E("ratelimit.loadState", "request rows", err)
} }
// Load tokens. // Load tokens.
tokRows, err := s.db.Query("SELECT model, ts, count FROM tokens ORDER BY ts") tokRows, err := s.db.Query("SELECT model, ts, count FROM tokens ORDER BY ts")
if err != nil { if err != nil {
return nil, coreerr.E("ratelimit.loadState", "query tokens", err) return nil, core.E("ratelimit.loadState", "query tokens", err)
} }
defer tokRows.Close() defer tokRows.Close()
@ -304,7 +303,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
var tsNano int64 var tsNano int64
var count int var count int
if err := tokRows.Scan(&model, &tsNano, &count); err != nil { if err := tokRows.Scan(&model, &tsNano, &count); err != nil {
return nil, coreerr.E("ratelimit.loadState", "scan tokens", err) return nil, core.E("ratelimit.loadState", "scan tokens", err)
} }
if _, ok := result[model]; !ok { if _, ok := result[model]; !ok {
result[model] = &UsageStats{} result[model] = &UsageStats{}
@ -315,7 +314,7 @@ func (s *sqliteStore) loadState() (map[string]*UsageStats, error) {
}) })
} }
if err := tokRows.Err(); err != nil { if err := tokRows.Err(); err != nil {
return nil, coreerr.E("ratelimit.loadState", "token rows", err) return nil, core.E("ratelimit.loadState", "token rows", err)
} }
return result, nil return result, nil