feat(allowance): add Redis backend for AllowanceStore
Implements RedisStore using github.com/redis/go-redis/v9 with Lua scripts for atomic increment/decrement operations. Adds functional options (WithRedisPassword, WithRedisDB, WithRedisPrefix), config wiring via "redis" store_backend, and comprehensive tests covering all 10 interface methods, concurrency, persistence, and error paths. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
46e3617f9c
commit
0be744e46a
6 changed files with 844 additions and 2 deletions
364
allowance_redis.go
Normal file
364
allowance_redis.go
Normal file
|
|
@ -0,0 +1,364 @@
|
|||
package agentic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisStore implements AllowanceStore using Redis as the backing store.
|
||||
// It provides persistent, network-accessible storage suitable for multi-node
|
||||
// deployments where agents share quota state.
|
||||
type RedisStore struct {
|
||||
client *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
// redisConfig holds the configuration for a RedisStore.
|
||||
type redisConfig struct {
|
||||
password string
|
||||
db int
|
||||
prefix string
|
||||
}
|
||||
|
||||
// RedisOption is a functional option for configuring a RedisStore.
|
||||
type RedisOption func(*redisConfig)
|
||||
|
||||
// WithRedisPassword sets the password for authenticating with Redis.
|
||||
func WithRedisPassword(pw string) RedisOption {
|
||||
return func(c *redisConfig) {
|
||||
c.password = pw
|
||||
}
|
||||
}
|
||||
|
||||
// WithRedisDB selects the Redis database number.
|
||||
func WithRedisDB(db int) RedisOption {
|
||||
return func(c *redisConfig) {
|
||||
c.db = db
|
||||
}
|
||||
}
|
||||
|
||||
// WithRedisPrefix sets the key prefix for all Redis keys. Default: "agentic".
|
||||
func WithRedisPrefix(prefix string) RedisOption {
|
||||
return func(c *redisConfig) {
|
||||
c.prefix = prefix
|
||||
}
|
||||
}
|
||||
|
||||
// NewRedisStore creates a new Redis-backed allowance store connecting to the
|
||||
// given address (host:port). It pings the server to verify connectivity.
|
||||
func NewRedisStore(addr string, opts ...RedisOption) (*RedisStore, error) {
|
||||
cfg := &redisConfig{
|
||||
prefix: "agentic",
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(cfg)
|
||||
}
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: addr,
|
||||
Password: cfg.password,
|
||||
DB: cfg.db,
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, &APIError{Code: 500, Message: "failed to connect to Redis: " + err.Error()}
|
||||
}
|
||||
|
||||
return &RedisStore{
|
||||
client: client,
|
||||
prefix: cfg.prefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close releases the underlying Redis connection.
|
||||
func (r *RedisStore) Close() error {
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// --- key helpers ---
|
||||
|
||||
func (r *RedisStore) allowanceKey(agentID string) string {
|
||||
return r.prefix + ":allowance:" + agentID
|
||||
}
|
||||
|
||||
func (r *RedisStore) usageKey(agentID string) string {
|
||||
return r.prefix + ":usage:" + agentID
|
||||
}
|
||||
|
||||
func (r *RedisStore) modelQuotaKey(model string) string {
|
||||
return r.prefix + ":model_quota:" + model
|
||||
}
|
||||
|
||||
func (r *RedisStore) modelUsageKey(model string) string {
|
||||
return r.prefix + ":model_usage:" + model
|
||||
}
|
||||
|
||||
// --- AllowanceStore interface ---
|
||||
|
||||
// GetAllowance returns the quota limits for an agent.
|
||||
func (r *RedisStore) GetAllowance(agentID string) (*AgentAllowance, error) {
|
||||
ctx := context.Background()
|
||||
val, err := r.client.Get(ctx, r.allowanceKey(agentID)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID}
|
||||
}
|
||||
return nil, &APIError{Code: 500, Message: "failed to get allowance: " + err.Error()}
|
||||
}
|
||||
var aj allowanceJSON
|
||||
if err := json.Unmarshal([]byte(val), &aj); err != nil {
|
||||
return nil, &APIError{Code: 500, Message: "failed to unmarshal allowance: " + err.Error()}
|
||||
}
|
||||
return aj.toAgentAllowance(), nil
|
||||
}
|
||||
|
||||
// SetAllowance persists quota limits for an agent.
|
||||
func (r *RedisStore) SetAllowance(a *AgentAllowance) error {
|
||||
ctx := context.Background()
|
||||
aj := newAllowanceJSON(a)
|
||||
data, err := json.Marshal(aj)
|
||||
if err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to marshal allowance: " + err.Error()}
|
||||
}
|
||||
if err := r.client.Set(ctx, r.allowanceKey(a.AgentID), data, 0).Err(); err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to set allowance: " + err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUsage returns the current usage record for an agent.
|
||||
func (r *RedisStore) GetUsage(agentID string) (*UsageRecord, error) {
|
||||
ctx := context.Background()
|
||||
val, err := r.client.Get(ctx, r.usageKey(agentID)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return &UsageRecord{
|
||||
AgentID: agentID,
|
||||
PeriodStart: startOfDay(time.Now().UTC()),
|
||||
}, nil
|
||||
}
|
||||
return nil, &APIError{Code: 500, Message: "failed to get usage: " + err.Error()}
|
||||
}
|
||||
var u UsageRecord
|
||||
if err := json.Unmarshal([]byte(val), &u); err != nil {
|
||||
return nil, &APIError{Code: 500, Message: "failed to unmarshal usage: " + err.Error()}
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// incrementUsageLua atomically reads, increments, and writes back a usage record.
|
||||
// KEYS[1] = usage key
|
||||
// ARGV[1] = tokens to add (int64)
|
||||
// ARGV[2] = jobs to add (int)
|
||||
// ARGV[3] = agent ID (for creating a new record)
|
||||
// ARGV[4] = period start ISO string (for creating a new record)
|
||||
var incrementUsageLua = redis.NewScript(`
|
||||
local val = redis.call('GET', KEYS[1])
|
||||
local u
|
||||
if val then
|
||||
u = cjson.decode(val)
|
||||
else
|
||||
u = {
|
||||
agent_id = ARGV[3],
|
||||
tokens_used = 0,
|
||||
jobs_started = 0,
|
||||
active_jobs = 0,
|
||||
period_start = ARGV[4]
|
||||
}
|
||||
end
|
||||
u.tokens_used = u.tokens_used + tonumber(ARGV[1])
|
||||
u.jobs_started = u.jobs_started + tonumber(ARGV[2])
|
||||
if tonumber(ARGV[2]) > 0 then
|
||||
u.active_jobs = u.active_jobs + tonumber(ARGV[2])
|
||||
end
|
||||
redis.call('SET', KEYS[1], cjson.encode(u))
|
||||
return 'OK'
|
||||
`)
|
||||
|
||||
// IncrementUsage atomically adds to an agent's usage counters.
|
||||
func (r *RedisStore) IncrementUsage(agentID string, tokens int64, jobs int) error {
|
||||
ctx := context.Background()
|
||||
periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339)
|
||||
err := incrementUsageLua.Run(ctx, r.client,
|
||||
[]string{r.usageKey(agentID)},
|
||||
tokens, jobs, agentID, periodStart,
|
||||
).Err()
|
||||
if err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to increment usage: " + err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decrementActiveJobsLua atomically decrements the active jobs counter, flooring at zero.
|
||||
// KEYS[1] = usage key
|
||||
// ARGV[1] = agent ID
|
||||
// ARGV[2] = period start ISO string
|
||||
var decrementActiveJobsLua = redis.NewScript(`
|
||||
local val = redis.call('GET', KEYS[1])
|
||||
if not val then
|
||||
return 'OK'
|
||||
end
|
||||
local u = cjson.decode(val)
|
||||
if u.active_jobs and u.active_jobs > 0 then
|
||||
u.active_jobs = u.active_jobs - 1
|
||||
end
|
||||
redis.call('SET', KEYS[1], cjson.encode(u))
|
||||
return 'OK'
|
||||
`)
|
||||
|
||||
// DecrementActiveJobs reduces the active job count by 1.
|
||||
func (r *RedisStore) DecrementActiveJobs(agentID string) error {
|
||||
ctx := context.Background()
|
||||
periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339)
|
||||
err := decrementActiveJobsLua.Run(ctx, r.client,
|
||||
[]string{r.usageKey(agentID)},
|
||||
agentID, periodStart,
|
||||
).Err()
|
||||
if err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to decrement active jobs: " + err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// returnTokensLua atomically subtracts tokens from usage, flooring at zero.
|
||||
// KEYS[1] = usage key
|
||||
// ARGV[1] = tokens to return (int64)
|
||||
// ARGV[2] = agent ID
|
||||
// ARGV[3] = period start ISO string
|
||||
var returnTokensLua = redis.NewScript(`
|
||||
local val = redis.call('GET', KEYS[1])
|
||||
if not val then
|
||||
return 'OK'
|
||||
end
|
||||
local u = cjson.decode(val)
|
||||
u.tokens_used = u.tokens_used - tonumber(ARGV[1])
|
||||
if u.tokens_used < 0 then
|
||||
u.tokens_used = 0
|
||||
end
|
||||
redis.call('SET', KEYS[1], cjson.encode(u))
|
||||
return 'OK'
|
||||
`)
|
||||
|
||||
// ReturnTokens adds tokens back to the agent's remaining quota.
|
||||
func (r *RedisStore) ReturnTokens(agentID string, tokens int64) error {
|
||||
ctx := context.Background()
|
||||
periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339)
|
||||
err := returnTokensLua.Run(ctx, r.client,
|
||||
[]string{r.usageKey(agentID)},
|
||||
tokens, agentID, periodStart,
|
||||
).Err()
|
||||
if err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to return tokens: " + err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetUsage clears usage counters for an agent (daily reset).
|
||||
func (r *RedisStore) ResetUsage(agentID string) error {
|
||||
ctx := context.Background()
|
||||
u := &UsageRecord{
|
||||
AgentID: agentID,
|
||||
PeriodStart: startOfDay(time.Now().UTC()),
|
||||
}
|
||||
data, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to marshal usage: " + err.Error()}
|
||||
}
|
||||
if err := r.client.Set(ctx, r.usageKey(agentID), data, 0).Err(); err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to reset usage: " + err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelQuota returns global limits for a model.
|
||||
func (r *RedisStore) GetModelQuota(model string) (*ModelQuota, error) {
|
||||
ctx := context.Background()
|
||||
val, err := r.client.Get(ctx, r.modelQuotaKey(model)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, &APIError{Code: 404, Message: "model quota not found: " + model}
|
||||
}
|
||||
return nil, &APIError{Code: 500, Message: "failed to get model quota: " + err.Error()}
|
||||
}
|
||||
var q ModelQuota
|
||||
if err := json.Unmarshal([]byte(val), &q); err != nil {
|
||||
return nil, &APIError{Code: 500, Message: "failed to unmarshal model quota: " + err.Error()}
|
||||
}
|
||||
return &q, nil
|
||||
}
|
||||
|
||||
// GetModelUsage returns current token usage for a model.
|
||||
func (r *RedisStore) GetModelUsage(model string) (int64, error) {
|
||||
ctx := context.Background()
|
||||
val, err := r.client.Get(ctx, r.modelUsageKey(model)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, &APIError{Code: 500, Message: "failed to get model usage: " + err.Error()}
|
||||
}
|
||||
var tokens int64
|
||||
if err := json.Unmarshal([]byte(val), &tokens); err != nil {
|
||||
return 0, &APIError{Code: 500, Message: "failed to unmarshal model usage: " + err.Error()}
|
||||
}
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
// incrementModelUsageLua atomically increments the model usage counter.
|
||||
// KEYS[1] = model usage key
|
||||
// ARGV[1] = tokens to add
|
||||
var incrementModelUsageLua = redis.NewScript(`
|
||||
local val = redis.call('GET', KEYS[1])
|
||||
local current = 0
|
||||
if val then
|
||||
current = tonumber(val)
|
||||
end
|
||||
current = current + tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], tostring(current))
|
||||
return current
|
||||
`)
|
||||
|
||||
// IncrementModelUsage atomically adds to a model's usage counter.
|
||||
func (r *RedisStore) IncrementModelUsage(model string, tokens int64) error {
|
||||
ctx := context.Background()
|
||||
err := incrementModelUsageLua.Run(ctx, r.client,
|
||||
[]string{r.modelUsageKey(model)},
|
||||
tokens,
|
||||
).Err()
|
||||
if err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to increment model usage: " + err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetModelQuota persists global limits for a model.
|
||||
func (r *RedisStore) SetModelQuota(q *ModelQuota) error {
|
||||
ctx := context.Background()
|
||||
data, err := json.Marshal(q)
|
||||
if err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to marshal model quota: " + err.Error()}
|
||||
}
|
||||
if err := r.client.Set(ctx, r.modelQuotaKey(q.Model), data, 0).Err(); err != nil {
|
||||
return &APIError{Code: 500, Message: "failed to set model quota: " + err.Error()}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlushPrefix deletes all keys matching the store's prefix. Useful for testing cleanup.
|
||||
func (r *RedisStore) FlushPrefix(ctx context.Context) error {
|
||||
iter := r.client.Scan(ctx, 0, r.prefix+":*", 100).Iterator()
|
||||
for iter.Next(ctx) {
|
||||
if err := r.client.Del(ctx, iter.Val()).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return iter.Err()
|
||||
}
|
||||
454
allowance_redis_test.go
Normal file
454
allowance_redis_test.go
Normal file
|
|
@ -0,0 +1,454 @@
|
|||
package agentic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testRedisAddr = "10.69.69.87:6379"
|
||||
|
||||
// newTestRedisStore creates a RedisStore with a unique prefix for test isolation.
|
||||
// Skips the test if Redis is unreachable.
|
||||
func newTestRedisStore(t *testing.T) *RedisStore {
|
||||
t.Helper()
|
||||
prefix := fmt.Sprintf("test_%d", time.Now().UnixNano())
|
||||
s, err := NewRedisStore(testRedisAddr, WithRedisPrefix(prefix))
|
||||
if err != nil {
|
||||
t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
ctx := context.Background()
|
||||
_ = s.FlushPrefix(ctx)
|
||||
_ = s.Close()
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// --- SetAllowance / GetAllowance ---
|
||||
|
||||
func TestRedisStore_SetGetAllowance_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
a := &AgentAllowance{
|
||||
AgentID: "agent-1",
|
||||
DailyTokenLimit: 100000,
|
||||
DailyJobLimit: 10,
|
||||
ConcurrentJobs: 2,
|
||||
MaxJobDuration: 30 * time.Minute,
|
||||
ModelAllowlist: []string{"claude-sonnet-4-5-20250929"},
|
||||
}
|
||||
|
||||
err := s.SetAllowance(a)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetAllowance("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, a.AgentID, got.AgentID)
|
||||
assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit)
|
||||
assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit)
|
||||
assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs)
|
||||
assert.Equal(t, a.MaxJobDuration, got.MaxJobDuration)
|
||||
assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist)
|
||||
}
|
||||
|
||||
func TestRedisStore_GetAllowance_Bad_NotFound(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
_, err := s.GetAllowance("nonexistent")
|
||||
require.Error(t, err)
|
||||
apiErr, ok := err.(*APIError)
|
||||
require.True(t, ok, "expected *APIError")
|
||||
assert.Equal(t, 404, apiErr.Code)
|
||||
assert.Contains(t, err.Error(), "allowance not found")
|
||||
}
|
||||
|
||||
func TestRedisStore_SetAllowance_Good_Overwrite(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100})
|
||||
_ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 200})
|
||||
|
||||
got, err := s.GetAllowance("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(200), got.DailyTokenLimit)
|
||||
}
|
||||
|
||||
// --- GetUsage / IncrementUsage ---
|
||||
|
||||
func TestRedisStore_GetUsage_Good_Default(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
u, err := s.GetUsage("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "agent-1", u.AgentID)
|
||||
assert.Equal(t, int64(0), u.TokensUsed)
|
||||
assert.Equal(t, 0, u.JobsStarted)
|
||||
assert.Equal(t, 0, u.ActiveJobs)
|
||||
}
|
||||
|
||||
func TestRedisStore_IncrementUsage_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
err := s.IncrementUsage("agent-1", 5000, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
u, err := s.GetUsage("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(5000), u.TokensUsed)
|
||||
assert.Equal(t, 1, u.JobsStarted)
|
||||
assert.Equal(t, 1, u.ActiveJobs)
|
||||
}
|
||||
|
||||
func TestRedisStore_IncrementUsage_Good_Accumulates(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.IncrementUsage("agent-1", 1000, 1)
|
||||
_ = s.IncrementUsage("agent-1", 2000, 1)
|
||||
_ = s.IncrementUsage("agent-1", 3000, 0)
|
||||
|
||||
u, err := s.GetUsage("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(6000), u.TokensUsed)
|
||||
assert.Equal(t, 2, u.JobsStarted)
|
||||
assert.Equal(t, 2, u.ActiveJobs)
|
||||
}
|
||||
|
||||
// --- DecrementActiveJobs ---
|
||||
|
||||
func TestRedisStore_DecrementActiveJobs_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.IncrementUsage("agent-1", 0, 2)
|
||||
_ = s.DecrementActiveJobs("agent-1")
|
||||
|
||||
u, _ := s.GetUsage("agent-1")
|
||||
assert.Equal(t, 1, u.ActiveJobs)
|
||||
}
|
||||
|
||||
func TestRedisStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.DecrementActiveJobs("agent-1") // no usage record yet
|
||||
_ = s.IncrementUsage("agent-1", 0, 0)
|
||||
_ = s.DecrementActiveJobs("agent-1") // should stay at 0
|
||||
|
||||
u, _ := s.GetUsage("agent-1")
|
||||
assert.Equal(t, 0, u.ActiveJobs)
|
||||
}
|
||||
|
||||
// --- ReturnTokens ---
|
||||
|
||||
func TestRedisStore_ReturnTokens_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.IncrementUsage("agent-1", 10000, 0)
|
||||
err := s.ReturnTokens("agent-1", 5000)
|
||||
require.NoError(t, err)
|
||||
|
||||
u, _ := s.GetUsage("agent-1")
|
||||
assert.Equal(t, int64(5000), u.TokensUsed)
|
||||
}
|
||||
|
||||
func TestRedisStore_ReturnTokens_Good_FloorAtZero(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.IncrementUsage("agent-1", 1000, 0)
|
||||
_ = s.ReturnTokens("agent-1", 5000) // more than used
|
||||
|
||||
u, _ := s.GetUsage("agent-1")
|
||||
assert.Equal(t, int64(0), u.TokensUsed)
|
||||
}
|
||||
|
||||
func TestRedisStore_ReturnTokens_Good_NoRecord(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
// Return tokens for agent with no usage record -- should be a no-op
|
||||
err := s.ReturnTokens("agent-1", 500)
|
||||
require.NoError(t, err)
|
||||
|
||||
u, _ := s.GetUsage("agent-1")
|
||||
assert.Equal(t, int64(0), u.TokensUsed)
|
||||
}
|
||||
|
||||
// --- ResetUsage ---
|
||||
|
||||
func TestRedisStore_ResetUsage_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.IncrementUsage("agent-1", 50000, 5)
|
||||
|
||||
err := s.ResetUsage("agent-1")
|
||||
require.NoError(t, err)
|
||||
|
||||
u, _ := s.GetUsage("agent-1")
|
||||
assert.Equal(t, int64(0), u.TokensUsed)
|
||||
assert.Equal(t, 0, u.JobsStarted)
|
||||
assert.Equal(t, 0, u.ActiveJobs)
|
||||
}
|
||||
|
||||
// --- ModelQuota ---
|
||||
|
||||
func TestRedisStore_GetModelQuota_Bad_NotFound(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
_, err := s.GetModelQuota("nonexistent")
|
||||
require.Error(t, err)
|
||||
apiErr, ok := err.(*APIError)
|
||||
require.True(t, ok, "expected *APIError")
|
||||
assert.Equal(t, 404, apiErr.Code)
|
||||
assert.Contains(t, err.Error(), "model quota not found")
|
||||
}
|
||||
|
||||
func TestRedisStore_SetGetModelQuota_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
q := &ModelQuota{
|
||||
Model: "claude-opus-4-6",
|
||||
DailyTokenBudget: 500000,
|
||||
HourlyRateLimit: 100,
|
||||
CostCeiling: 10000,
|
||||
}
|
||||
err := s.SetModelQuota(q)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.GetModelQuota("claude-opus-4-6")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, q.Model, got.Model)
|
||||
assert.Equal(t, q.DailyTokenBudget, got.DailyTokenBudget)
|
||||
assert.Equal(t, q.HourlyRateLimit, got.HourlyRateLimit)
|
||||
assert.Equal(t, q.CostCeiling, got.CostCeiling)
|
||||
}
|
||||
|
||||
// --- ModelUsage ---
|
||||
|
||||
func TestRedisStore_ModelUsage_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.IncrementModelUsage("claude-sonnet", 10000)
|
||||
_ = s.IncrementModelUsage("claude-sonnet", 5000)
|
||||
|
||||
usage, err := s.GetModelUsage("claude-sonnet")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(15000), usage)
|
||||
}
|
||||
|
||||
func TestRedisStore_GetModelUsage_Good_Default(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
usage, err := s.GetModelUsage("unknown-model")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), usage)
|
||||
}
|
||||
|
||||
// --- Persistence: set, get, verify ---
|
||||
|
||||
func TestRedisStore_Persistence_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.SetAllowance(&AgentAllowance{
|
||||
AgentID: "agent-1",
|
||||
DailyTokenLimit: 100000,
|
||||
MaxJobDuration: 15 * time.Minute,
|
||||
})
|
||||
_ = s.IncrementUsage("agent-1", 25000, 3)
|
||||
_ = s.SetModelQuota(&ModelQuota{Model: "claude-opus-4-6", DailyTokenBudget: 500000})
|
||||
_ = s.IncrementModelUsage("claude-opus-4-6", 42000)
|
||||
|
||||
// Verify all data persists (same connection, but data is in Redis)
|
||||
a, err := s.GetAllowance("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(100000), a.DailyTokenLimit)
|
||||
assert.Equal(t, 15*time.Minute, a.MaxJobDuration)
|
||||
|
||||
u, err := s.GetUsage("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(25000), u.TokensUsed)
|
||||
assert.Equal(t, 3, u.JobsStarted)
|
||||
assert.Equal(t, 3, u.ActiveJobs)
|
||||
|
||||
q, err := s.GetModelQuota("claude-opus-4-6")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(500000), q.DailyTokenBudget)
|
||||
|
||||
mu, err := s.GetModelUsage("claude-opus-4-6")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(42000), mu)
|
||||
}
|
||||
|
||||
// --- Concurrent access ---
|
||||
|
||||
func TestRedisStore_ConcurrentIncrementUsage_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
const goroutines = 10
|
||||
const tokensEach = 1000
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := s.IncrementUsage("agent-1", tokensEach, 1)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
u, err := s.GetUsage("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(goroutines*tokensEach), u.TokensUsed)
|
||||
assert.Equal(t, goroutines, u.JobsStarted)
|
||||
assert.Equal(t, goroutines, u.ActiveJobs)
|
||||
}
|
||||
|
||||
func TestRedisStore_ConcurrentModelUsage_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
const goroutines = 10
|
||||
const tokensEach int64 = 500
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines)
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := s.IncrementModelUsage("claude-opus-4-6", tokensEach)
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
usage, err := s.GetModelUsage("claude-opus-4-6")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, goroutines*tokensEach, usage)
|
||||
}
|
||||
|
||||
func TestRedisStore_ConcurrentMixed_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
|
||||
_ = s.SetAllowance(&AgentAllowance{
|
||||
AgentID: "agent-1",
|
||||
DailyTokenLimit: 1000000,
|
||||
DailyJobLimit: 100,
|
||||
ConcurrentJobs: 50,
|
||||
})
|
||||
|
||||
const goroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(goroutines * 3)
|
||||
|
||||
// Increment usage
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = s.IncrementUsage("agent-1", 100, 1)
|
||||
}()
|
||||
}
|
||||
|
||||
// Decrement active jobs
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = s.DecrementActiveJobs("agent-1")
|
||||
}()
|
||||
}
|
||||
|
||||
// Return tokens
|
||||
for i := 0; i < goroutines; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = s.ReturnTokens("agent-1", 10)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify no panics and data is consistent
|
||||
u, err := s.GetUsage("agent-1")
|
||||
require.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, u.TokensUsed, int64(0))
|
||||
assert.GreaterOrEqual(t, u.ActiveJobs, 0)
|
||||
}
|
||||
|
||||
// --- AllowanceService integration via RedisStore ---
|
||||
|
||||
func TestRedisStore_AllowanceServiceCheck_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
svc := NewAllowanceService(s)
|
||||
|
||||
_ = s.SetAllowance(&AgentAllowance{
|
||||
AgentID: "agent-1",
|
||||
DailyTokenLimit: 100000,
|
||||
DailyJobLimit: 10,
|
||||
ConcurrentJobs: 2,
|
||||
})
|
||||
|
||||
result, err := svc.Check("agent-1", "")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, result.Allowed)
|
||||
assert.Equal(t, AllowanceOK, result.Status)
|
||||
}
|
||||
|
||||
func TestRedisStore_AllowanceServiceRecordUsage_Good(t *testing.T) {
|
||||
s := newTestRedisStore(t)
|
||||
svc := NewAllowanceService(s)
|
||||
|
||||
_ = s.SetAllowance(&AgentAllowance{
|
||||
AgentID: "agent-1",
|
||||
DailyTokenLimit: 100000,
|
||||
})
|
||||
|
||||
// Start job
|
||||
err := svc.RecordUsage(UsageReport{
|
||||
AgentID: "agent-1",
|
||||
JobID: "job-1",
|
||||
Event: QuotaEventJobStarted,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Complete job
|
||||
err = svc.RecordUsage(UsageReport{
|
||||
AgentID: "agent-1",
|
||||
JobID: "job-1",
|
||||
Model: "claude-sonnet",
|
||||
TokensIn: 1000,
|
||||
TokensOut: 500,
|
||||
Event: QuotaEventJobCompleted,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
u, _ := s.GetUsage("agent-1")
|
||||
assert.Equal(t, int64(1500), u.TokensUsed)
|
||||
assert.Equal(t, 0, u.ActiveJobs)
|
||||
}
|
||||
|
||||
// --- Config-based factory with redis backend ---
|
||||
|
||||
func TestNewAllowanceStoreFromConfig_Good_Redis(t *testing.T) {
|
||||
cfg := AllowanceConfig{
|
||||
StoreBackend: "redis",
|
||||
RedisAddr: testRedisAddr,
|
||||
}
|
||||
s, err := NewAllowanceStoreFromConfig(cfg)
|
||||
if err != nil {
|
||||
t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err)
|
||||
}
|
||||
rs, ok := s.(*RedisStore)
|
||||
assert.True(t, ok, "expected RedisStore")
|
||||
_ = rs.Close()
|
||||
}
|
||||
|
||||
// --- Constructor error case ---
|
||||
|
||||
func TestNewRedisStore_Bad_Unreachable(t *testing.T) {
|
||||
_, err := NewRedisStore("127.0.0.1:1") // almost certainly unreachable
|
||||
require.Error(t, err)
|
||||
apiErr, ok := err.(*APIError)
|
||||
require.True(t, ok, "expected *APIError")
|
||||
assert.Equal(t, 500, apiErr.Code)
|
||||
assert.Contains(t, err.Error(), "failed to connect to Redis")
|
||||
}
|
||||
|
|
@ -451,7 +451,7 @@ func TestNewAllowanceStoreFromConfig_Good_SQLite(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewAllowanceStoreFromConfig_Bad_UnknownBackend(t *testing.T) {
|
||||
cfg := AllowanceConfig{StoreBackend: "redis"}
|
||||
cfg := AllowanceConfig{StoreBackend: "cassandra"}
|
||||
_, err := NewAllowanceStoreFromConfig(cfg)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "unsupported store backend")
|
||||
|
|
|
|||
|
|
@ -198,11 +198,13 @@ func ConfigPath() (string, error) {
|
|||
|
||||
// AllowanceConfig controls allowance store backend selection.
|
||||
type AllowanceConfig struct {
|
||||
// StoreBackend is the storage backend: "memory" or "sqlite". Default: "memory".
|
||||
// StoreBackend is the storage backend: "memory", "sqlite", or "redis". Default: "memory".
|
||||
StoreBackend string `yaml:"store_backend" json:"store_backend"`
|
||||
// StorePath is the file path for the SQLite database.
|
||||
// Default: ~/.config/agentic/allowance.db (only used when StoreBackend is "sqlite").
|
||||
StorePath string `yaml:"store_path" json:"store_path"`
|
||||
// RedisAddr is the host:port for the Redis server (only used when StoreBackend is "redis").
|
||||
RedisAddr string `yaml:"redis_addr" json:"redis_addr"`
|
||||
}
|
||||
|
||||
// DefaultAllowanceStorePath returns the default SQLite path for allowance data.
|
||||
|
|
@ -230,6 +232,8 @@ func NewAllowanceStoreFromConfig(cfg AllowanceConfig) (AllowanceStore, error) {
|
|||
}
|
||||
}
|
||||
return NewSQLiteStore(dbPath)
|
||||
case "redis":
|
||||
return NewRedisStore(cfg.RedisAddr)
|
||||
default:
|
||||
return nil, &APIError{
|
||||
Code: 400,
|
||||
|
|
|
|||
4
go.mod
4
go.mod
|
|
@ -5,12 +5,15 @@ go 1.25.5
|
|||
require (
|
||||
forge.lthn.ai/core/go v0.0.0
|
||||
forge.lthn.ai/core/go-store v0.0.0-00010101000000-000000000000
|
||||
github.com/redis/go-redis/v9 v9.18.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
|
|
@ -18,6 +21,7 @@ require (
|
|||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
modernc.org/libc v1.67.7 // indirect
|
||||
|
|
|
|||
16
go.sum
16
go.sum
|
|
@ -1,6 +1,14 @@
|
|||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
|
|
@ -9,6 +17,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
|
|
@ -20,6 +30,8 @@ github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJm
|
|||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs=
|
||||
github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||
|
|
@ -27,6 +39,10 @@ github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0t
|
|||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
|
||||
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
|
||||
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||
golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a h1:ovFr6Z0MNmU7nH8VaX5xqw+05ST2uO1exVfZPVqRC5o=
|
||||
golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA=
|
||||
golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue