diff --git a/ratelimit.go b/ratelimit.go index 991b36b..1909642 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -249,7 +249,12 @@ func (rl *RateLimiter) Load() error { return err } - return yaml.Unmarshal([]byte(content), rl) + if err := yaml.Unmarshal([]byte(content), rl); err != nil { + return err + } + + ensureMaps(rl) + return nil } // loadSQLite reads quotas and state from the SQLite backend. @@ -596,6 +601,14 @@ func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision { now := time.Now() decision := Decision{} + if estimatedTokens < 0 { + decision.Allowed = false + decision.Code = DecisionInvalidTokens + decision.Reason = "estimated tokens must be non-negative" + decision.Stats = rl.snapshotLocked(model) + return decision + } + quota, ok := rl.Quotas[model] if !ok { decision.Allowed = true @@ -606,13 +619,6 @@ func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision { } if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 { - if estimatedTokens < 0 { - decision.Allowed = false - decision.Code = DecisionInvalidTokens - decision.Reason = "estimated tokens must be non-negative" - decision.Stats = rl.snapshotLocked(model) - return decision - } decision.Allowed = true decision.Code = DecisionUnlimited decision.Reason = "all limits are unlimited" @@ -627,14 +633,6 @@ func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision { rl.State[model] = stats } - if estimatedTokens < 0 { - decision.Allowed = false - decision.Code = DecisionInvalidTokens - decision.Reason = "estimated tokens must be non-negative" - decision.Stats = rl.snapshotLocked(model) - return decision - } - decision.Stats = rl.snapshotLocked(model) if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD { @@ -836,6 +834,15 @@ func newConfiguredRateLimiter(cfg Config) *RateLimiter { return rl } +func ensureMaps(rl *RateLimiter) { + if rl.Quotas == nil { + rl.Quotas = make(map[string]ModelQuota) + } + if rl.State == nil { + rl.State = make(map[string]*UsageStats) + } +} + func applyConfig(rl *RateLimiter, cfg Config) { profiles := DefaultProfiles() providers := cfg.Providers