2026-03-27 04:23:34 +00:00
|
|
|
// SPDX-License-Identifier: EUPL-1.2
|
|
|
|
|
|
2026-03-09 08:33:00 +00:00
|
|
|
package ratelimit
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"net/http"
|
|
|
|
|
"net/http/httptest"
|
|
|
|
|
"testing"
|
|
|
|
|
|
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
|
"github.com/stretchr/testify/require"
|
|
|
|
|
)
|
|
|
|
|
|
2026-03-26 18:51:54 +00:00
|
|
|
func TestIter_Iterators_Good(t *testing.T) {
|
2026-03-09 08:33:00 +00:00
|
|
|
rl, err := NewWithConfig(Config{
|
|
|
|
|
Quotas: map[string]ModelQuota{
|
|
|
|
|
"model-c": {MaxRPM: 10},
|
|
|
|
|
"model-a": {MaxRPM: 10},
|
|
|
|
|
},
|
|
|
|
|
})
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
rl.RecordUsage("model-b", 1, 1)
|
|
|
|
|
|
|
|
|
|
t.Run("Models iterator is sorted", func(t *testing.T) {
|
|
|
|
|
var models []string
|
|
|
|
|
for m := range rl.Models() {
|
|
|
|
|
models = append(models, m)
|
|
|
|
|
}
|
|
|
|
|
// Should include Gemini defaults (from NewWithConfig's default) + custom models
|
|
|
|
|
// and be sorted.
|
|
|
|
|
assert.Contains(t, models, "model-a")
|
|
|
|
|
assert.Contains(t, models, "model-b")
|
|
|
|
|
assert.Contains(t, models, "model-c")
|
2026-03-23 07:26:15 +00:00
|
|
|
|
2026-03-09 08:33:00 +00:00
|
|
|
// Check sorting of our specific models
|
|
|
|
|
foundA, foundB, foundC := -1, -1, -1
|
|
|
|
|
for i, m := range models {
|
2026-03-23 07:26:15 +00:00
|
|
|
if m == "model-a" {
|
|
|
|
|
foundA = i
|
|
|
|
|
}
|
|
|
|
|
if m == "model-b" {
|
|
|
|
|
foundB = i
|
|
|
|
|
}
|
|
|
|
|
if m == "model-c" {
|
|
|
|
|
foundC = i
|
|
|
|
|
}
|
2026-03-09 08:33:00 +00:00
|
|
|
}
|
|
|
|
|
assert.True(t, foundA < foundB && foundB < foundC, "models should be sorted: a < b < c")
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("Iter iterator is sorted", func(t *testing.T) {
|
|
|
|
|
var models []string
|
|
|
|
|
for m, stats := range rl.Iter() {
|
|
|
|
|
models = append(models, m)
|
|
|
|
|
if m == "model-a" {
|
|
|
|
|
assert.Equal(t, 10, stats.MaxRPM)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.Contains(t, models, "model-a")
|
|
|
|
|
assert.Contains(t, models, "model-b")
|
|
|
|
|
assert.Contains(t, models, "model-c")
|
|
|
|
|
|
|
|
|
|
// Check sorting
|
|
|
|
|
foundA, foundB, foundC := -1, -1, -1
|
|
|
|
|
for i, m := range models {
|
2026-03-23 07:26:15 +00:00
|
|
|
if m == "model-a" {
|
|
|
|
|
foundA = i
|
|
|
|
|
}
|
|
|
|
|
if m == "model-b" {
|
|
|
|
|
foundB = i
|
|
|
|
|
}
|
|
|
|
|
if m == "model-c" {
|
|
|
|
|
foundC = i
|
|
|
|
|
}
|
2026-03-09 08:33:00 +00:00
|
|
|
}
|
|
|
|
|
assert.True(t, foundA < foundB && foundB < foundC, "iter should be sorted: a < b < c")
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-26 18:51:54 +00:00
|
|
|
func TestIter_IterEarlyBreak_Good(t *testing.T) {
|
2026-03-17 08:56:53 +00:00
|
|
|
rl, err := NewWithConfig(Config{
|
|
|
|
|
Quotas: map[string]ModelQuota{
|
|
|
|
|
"model-a": {MaxRPM: 10},
|
|
|
|
|
"model-b": {MaxRPM: 20},
|
|
|
|
|
"model-c": {MaxRPM: 30},
|
|
|
|
|
},
|
|
|
|
|
})
|
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
|
|
t.Run("Iter breaks early", func(t *testing.T) {
|
|
|
|
|
var count int
|
|
|
|
|
for range rl.Iter() {
|
|
|
|
|
count++
|
|
|
|
|
if count == 1 {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.Equal(t, 1, count, "should stop after first iteration")
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("Models early break via manual iteration", func(t *testing.T) {
|
|
|
|
|
var count int
|
|
|
|
|
for range rl.Models() {
|
|
|
|
|
count++
|
|
|
|
|
if count == 2 {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
assert.Equal(t, 2, count, "should stop after two models")
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-26 18:51:54 +00:00
|
|
|
func TestIter_CountTokensFull_Ugly(t *testing.T) {
|
2026-03-23 07:26:15 +00:00
|
|
|
t.Run("empty model is rejected", func(t *testing.T) {
|
|
|
|
|
_, err := CountTokens(context.Background(), "key", "", "text")
|
2026-03-09 08:33:00 +00:00
|
|
|
assert.Error(t, err)
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("API error non-200", func(t *testing.T) {
|
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
|
|
|
w.Write([]byte("bad request"))
|
|
|
|
|
}))
|
|
|
|
|
defer server.Close()
|
|
|
|
|
|
2026-03-23 07:26:15 +00:00
|
|
|
_, err := countTokensWithClient(context.Background(), server.Client(), server.URL, "key", "model", "text")
|
|
|
|
|
require.Error(t, err)
|
|
|
|
|
assert.Contains(t, err.Error(), "status 400")
|
2026-03-09 08:33:00 +00:00
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("context cancelled", func(t *testing.T) {
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
|
cancel()
|
2026-03-23 07:26:15 +00:00
|
|
|
_, err := countTokensWithClient(ctx, http.DefaultClient, "https://generativelanguage.googleapis.com", "key", "model", "text")
|
2026-03-09 08:33:00 +00:00
|
|
|
assert.Error(t, err)
|
|
|
|
|
assert.Contains(t, err.Error(), "do request")
|
|
|
|
|
})
|
|
|
|
|
}
|