diff --git a/ratelimit.go b/ratelimit.go index 9e7eee5..991b36b 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -590,15 +590,6 @@ func (rl *RateLimiter) AllStats() map[string]ModelStats { // Decide returns structured allow/deny information for an estimated request. // It never records usage; call RecordUsage after a successful decision. func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision { - if estimatedTokens < 0 { - return Decision{ - Allowed: false, - Code: DecisionInvalidTokens, - Reason: "estimated tokens must be non-negative", - Stats: rl.Stats(model), - } - } - rl.mu.Lock() defer rl.mu.Unlock() @@ -615,6 +606,13 @@ func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision { } if quota.MaxRPM == 0 && quota.MaxTPM == 0 && quota.MaxRPD == 0 { + if estimatedTokens < 0 { + decision.Allowed = false + decision.Code = DecisionInvalidTokens + decision.Reason = "estimated tokens must be non-negative" + decision.Stats = rl.snapshotLocked(model) + return decision + } decision.Allowed = true decision.Code = DecisionUnlimited decision.Reason = "all limits are unlimited" @@ -629,6 +627,14 @@ func (rl *RateLimiter) Decide(model string, estimatedTokens int) Decision { rl.State[model] = stats } + if estimatedTokens < 0 { + decision.Allowed = false + decision.Code = DecisionInvalidTokens + decision.Reason = "estimated tokens must be non-negative" + decision.Stats = rl.snapshotLocked(model) + return decision + } + decision.Stats = rl.snapshotLocked(model) if quota.MaxRPD > 0 && stats.DayCount >= quota.MaxRPD { diff --git a/ratelimit_test.go b/ratelimit_test.go index 6dbb65c..3fde6fd 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -363,12 +363,17 @@ func TestRatelimit_Decide_Good(t *testing.T) { t.Run("negative estimate returns invalid decision", func(t *testing.T) { rl := newTestLimiter(t) + model := "neg" + rl.Quotas[model] = ModelQuota{MaxRPM: 5, MaxTPM: 50, MaxRPD: 5} - decision := rl.Decide("neg", -5) + decision := rl.Decide(model, -5) assert.False(t, decision.Allowed) assert.Equal(t, DecisionInvalidTokens, decision.Code) assert.Zero(t, decision.RetryAfter) + require.Contains(t, rl.State, model) + require.NotNil(t, rl.State[model]) + assert.Equal(t, 0, rl.State[model].DayCount) }) }