diff --git a/allowance_redis.go b/allowance_redis.go new file mode 100644 index 0000000..6c84ddc --- /dev/null +++ b/allowance_redis.go @@ -0,0 +1,364 @@ +package agentic + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/redis/go-redis/v9" +) + +// RedisStore implements AllowanceStore using Redis as the backing store. +// It provides persistent, network-accessible storage suitable for multi-node +// deployments where agents share quota state. +type RedisStore struct { + client *redis.Client + prefix string +} + +// redisConfig holds the configuration for a RedisStore. +type redisConfig struct { + password string + db int + prefix string +} + +// RedisOption is a functional option for configuring a RedisStore. +type RedisOption func(*redisConfig) + +// WithRedisPassword sets the password for authenticating with Redis. +func WithRedisPassword(pw string) RedisOption { + return func(c *redisConfig) { + c.password = pw + } +} + +// WithRedisDB selects the Redis database number. +func WithRedisDB(db int) RedisOption { + return func(c *redisConfig) { + c.db = db + } +} + +// WithRedisPrefix sets the key prefix for all Redis keys. Default: "agentic". +func WithRedisPrefix(prefix string) RedisOption { + return func(c *redisConfig) { + c.prefix = prefix + } +} + +// NewRedisStore creates a new Redis-backed allowance store connecting to the +// given address (host:port). It pings the server to verify connectivity. +func NewRedisStore(addr string, opts ...RedisOption) (*RedisStore, error) { + cfg := &redisConfig{ + prefix: "agentic", + } + for _, opt := range opts { + opt(cfg) + } + + client := redis.NewClient(&redis.Options{ + Addr: addr, + Password: cfg.password, + DB: cfg.db, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + _ = client.Close() + return nil, &APIError{Code: 500, Message: "failed to connect to Redis: " + err.Error()} + } + + return &RedisStore{ + client: client, + prefix: cfg.prefix, + }, nil +} + +// Close releases the underlying Redis connection. +func (r *RedisStore) Close() error { + return r.client.Close() +} + +// --- key helpers --- + +func (r *RedisStore) allowanceKey(agentID string) string { + return r.prefix + ":allowance:" + agentID +} + +func (r *RedisStore) usageKey(agentID string) string { + return r.prefix + ":usage:" + agentID +} + +func (r *RedisStore) modelQuotaKey(model string) string { + return r.prefix + ":model_quota:" + model +} + +func (r *RedisStore) modelUsageKey(model string) string { + return r.prefix + ":model_usage:" + model +} + +// --- AllowanceStore interface --- + +// GetAllowance returns the quota limits for an agent. +func (r *RedisStore) GetAllowance(agentID string) (*AgentAllowance, error) { + ctx := context.Background() + val, err := r.client.Get(ctx, r.allowanceKey(agentID)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, &APIError{Code: 404, Message: "allowance not found for agent: " + agentID} + } + return nil, &APIError{Code: 500, Message: "failed to get allowance: " + err.Error()} + } + var aj allowanceJSON + if err := json.Unmarshal([]byte(val), &aj); err != nil { + return nil, &APIError{Code: 500, Message: "failed to unmarshal allowance: " + err.Error()} + } + return aj.toAgentAllowance(), nil +} + +// SetAllowance persists quota limits for an agent. +func (r *RedisStore) SetAllowance(a *AgentAllowance) error { + ctx := context.Background() + aj := newAllowanceJSON(a) + data, err := json.Marshal(aj) + if err != nil { + return &APIError{Code: 500, Message: "failed to marshal allowance: " + err.Error()} + } + if err := r.client.Set(ctx, r.allowanceKey(a.AgentID), data, 0).Err(); err != nil { + return &APIError{Code: 500, Message: "failed to set allowance: " + err.Error()} + } + return nil +} + +// GetUsage returns the current usage record for an agent. +func (r *RedisStore) GetUsage(agentID string) (*UsageRecord, error) { + ctx := context.Background() + val, err := r.client.Get(ctx, r.usageKey(agentID)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return &UsageRecord{ + AgentID: agentID, + PeriodStart: startOfDay(time.Now().UTC()), + }, nil + } + return nil, &APIError{Code: 500, Message: "failed to get usage: " + err.Error()} + } + var u UsageRecord + if err := json.Unmarshal([]byte(val), &u); err != nil { + return nil, &APIError{Code: 500, Message: "failed to unmarshal usage: " + err.Error()} + } + return &u, nil +} + +// incrementUsageLua atomically reads, increments, and writes back a usage record. +// KEYS[1] = usage key +// ARGV[1] = tokens to add (int64) +// ARGV[2] = jobs to add (int) +// ARGV[3] = agent ID (for creating a new record) +// ARGV[4] = period start ISO string (for creating a new record) +var incrementUsageLua = redis.NewScript(` +local val = redis.call('GET', KEYS[1]) +local u +if val then + u = cjson.decode(val) +else + u = { + agent_id = ARGV[3], + tokens_used = 0, + jobs_started = 0, + active_jobs = 0, + period_start = ARGV[4] + } +end +u.tokens_used = u.tokens_used + tonumber(ARGV[1]) +u.jobs_started = u.jobs_started + tonumber(ARGV[2]) +if tonumber(ARGV[2]) > 0 then + u.active_jobs = u.active_jobs + tonumber(ARGV[2]) +end +redis.call('SET', KEYS[1], cjson.encode(u)) +return 'OK' +`) + +// IncrementUsage atomically adds to an agent's usage counters. +func (r *RedisStore) IncrementUsage(agentID string, tokens int64, jobs int) error { + ctx := context.Background() + periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339) + err := incrementUsageLua.Run(ctx, r.client, + []string{r.usageKey(agentID)}, + tokens, jobs, agentID, periodStart, + ).Err() + if err != nil { + return &APIError{Code: 500, Message: "failed to increment usage: " + err.Error()} + } + return nil +} + +// decrementActiveJobsLua atomically decrements the active jobs counter, flooring at zero. +// KEYS[1] = usage key +// ARGV[1] = agent ID +// ARGV[2] = period start ISO string +var decrementActiveJobsLua = redis.NewScript(` +local val = redis.call('GET', KEYS[1]) +if not val then + return 'OK' +end +local u = cjson.decode(val) +if u.active_jobs and u.active_jobs > 0 then + u.active_jobs = u.active_jobs - 1 +end +redis.call('SET', KEYS[1], cjson.encode(u)) +return 'OK' +`) + +// DecrementActiveJobs reduces the active job count by 1. +func (r *RedisStore) DecrementActiveJobs(agentID string) error { + ctx := context.Background() + periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339) + err := decrementActiveJobsLua.Run(ctx, r.client, + []string{r.usageKey(agentID)}, + agentID, periodStart, + ).Err() + if err != nil { + return &APIError{Code: 500, Message: "failed to decrement active jobs: " + err.Error()} + } + return nil +} + +// returnTokensLua atomically subtracts tokens from usage, flooring at zero. +// KEYS[1] = usage key +// ARGV[1] = tokens to return (int64) +// ARGV[2] = agent ID +// ARGV[3] = period start ISO string +var returnTokensLua = redis.NewScript(` +local val = redis.call('GET', KEYS[1]) +if not val then + return 'OK' +end +local u = cjson.decode(val) +u.tokens_used = u.tokens_used - tonumber(ARGV[1]) +if u.tokens_used < 0 then + u.tokens_used = 0 +end +redis.call('SET', KEYS[1], cjson.encode(u)) +return 'OK' +`) + +// ReturnTokens adds tokens back to the agent's remaining quota. +func (r *RedisStore) ReturnTokens(agentID string, tokens int64) error { + ctx := context.Background() + periodStart := startOfDay(time.Now().UTC()).Format(time.RFC3339) + err := returnTokensLua.Run(ctx, r.client, + []string{r.usageKey(agentID)}, + tokens, agentID, periodStart, + ).Err() + if err != nil { + return &APIError{Code: 500, Message: "failed to return tokens: " + err.Error()} + } + return nil +} + +// ResetUsage clears usage counters for an agent (daily reset). +func (r *RedisStore) ResetUsage(agentID string) error { + ctx := context.Background() + u := &UsageRecord{ + AgentID: agentID, + PeriodStart: startOfDay(time.Now().UTC()), + } + data, err := json.Marshal(u) + if err != nil { + return &APIError{Code: 500, Message: "failed to marshal usage: " + err.Error()} + } + if err := r.client.Set(ctx, r.usageKey(agentID), data, 0).Err(); err != nil { + return &APIError{Code: 500, Message: "failed to reset usage: " + err.Error()} + } + return nil +} + +// GetModelQuota returns global limits for a model. +func (r *RedisStore) GetModelQuota(model string) (*ModelQuota, error) { + ctx := context.Background() + val, err := r.client.Get(ctx, r.modelQuotaKey(model)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, &APIError{Code: 404, Message: "model quota not found: " + model} + } + return nil, &APIError{Code: 500, Message: "failed to get model quota: " + err.Error()} + } + var q ModelQuota + if err := json.Unmarshal([]byte(val), &q); err != nil { + return nil, &APIError{Code: 500, Message: "failed to unmarshal model quota: " + err.Error()} + } + return &q, nil +} + +// GetModelUsage returns current token usage for a model. +func (r *RedisStore) GetModelUsage(model string) (int64, error) { + ctx := context.Background() + val, err := r.client.Get(ctx, r.modelUsageKey(model)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return 0, nil + } + return 0, &APIError{Code: 500, Message: "failed to get model usage: " + err.Error()} + } + var tokens int64 + if err := json.Unmarshal([]byte(val), &tokens); err != nil { + return 0, &APIError{Code: 500, Message: "failed to unmarshal model usage: " + err.Error()} + } + return tokens, nil +} + +// incrementModelUsageLua atomically increments the model usage counter. +// KEYS[1] = model usage key +// ARGV[1] = tokens to add +var incrementModelUsageLua = redis.NewScript(` +local val = redis.call('GET', KEYS[1]) +local current = 0 +if val then + current = tonumber(val) +end +current = current + tonumber(ARGV[1]) +redis.call('SET', KEYS[1], tostring(current)) +return current +`) + +// IncrementModelUsage atomically adds to a model's usage counter. +func (r *RedisStore) IncrementModelUsage(model string, tokens int64) error { + ctx := context.Background() + err := incrementModelUsageLua.Run(ctx, r.client, + []string{r.modelUsageKey(model)}, + tokens, + ).Err() + if err != nil { + return &APIError{Code: 500, Message: "failed to increment model usage: " + err.Error()} + } + return nil +} + +// SetModelQuota persists global limits for a model. +func (r *RedisStore) SetModelQuota(q *ModelQuota) error { + ctx := context.Background() + data, err := json.Marshal(q) + if err != nil { + return &APIError{Code: 500, Message: "failed to marshal model quota: " + err.Error()} + } + if err := r.client.Set(ctx, r.modelQuotaKey(q.Model), data, 0).Err(); err != nil { + return &APIError{Code: 500, Message: "failed to set model quota: " + err.Error()} + } + return nil +} + +// FlushPrefix deletes all keys matching the store's prefix. Useful for testing cleanup. +func (r *RedisStore) FlushPrefix(ctx context.Context) error { + iter := r.client.Scan(ctx, 0, r.prefix+":*", 100).Iterator() + for iter.Next(ctx) { + if err := r.client.Del(ctx, iter.Val()).Err(); err != nil { + return err + } + } + return iter.Err() +} diff --git a/allowance_redis_test.go b/allowance_redis_test.go new file mode 100644 index 0000000..339fd7c --- /dev/null +++ b/allowance_redis_test.go @@ -0,0 +1,454 @@ +package agentic + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testRedisAddr = "10.69.69.87:6379" + +// newTestRedisStore creates a RedisStore with a unique prefix for test isolation. +// Skips the test if Redis is unreachable. +func newTestRedisStore(t *testing.T) *RedisStore { + t.Helper() + prefix := fmt.Sprintf("test_%d", time.Now().UnixNano()) + s, err := NewRedisStore(testRedisAddr, WithRedisPrefix(prefix)) + if err != nil { + t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) + } + t.Cleanup(func() { + ctx := context.Background() + _ = s.FlushPrefix(ctx) + _ = s.Close() + }) + return s +} + +// --- SetAllowance / GetAllowance --- + +func TestRedisStore_SetGetAllowance_Good(t *testing.T) { + s := newTestRedisStore(t) + + a := &AgentAllowance{ + AgentID: "agent-1", + DailyTokenLimit: 100000, + DailyJobLimit: 10, + ConcurrentJobs: 2, + MaxJobDuration: 30 * time.Minute, + ModelAllowlist: []string{"claude-sonnet-4-5-20250929"}, + } + + err := s.SetAllowance(a) + require.NoError(t, err) + + got, err := s.GetAllowance("agent-1") + require.NoError(t, err) + assert.Equal(t, a.AgentID, got.AgentID) + assert.Equal(t, a.DailyTokenLimit, got.DailyTokenLimit) + assert.Equal(t, a.DailyJobLimit, got.DailyJobLimit) + assert.Equal(t, a.ConcurrentJobs, got.ConcurrentJobs) + assert.Equal(t, a.MaxJobDuration, got.MaxJobDuration) + assert.Equal(t, a.ModelAllowlist, got.ModelAllowlist) +} + +func TestRedisStore_GetAllowance_Bad_NotFound(t *testing.T) { + s := newTestRedisStore(t) + _, err := s.GetAllowance("nonexistent") + require.Error(t, err) + apiErr, ok := err.(*APIError) + require.True(t, ok, "expected *APIError") + assert.Equal(t, 404, apiErr.Code) + assert.Contains(t, err.Error(), "allowance not found") +} + +func TestRedisStore_SetAllowance_Good_Overwrite(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 100}) + _ = s.SetAllowance(&AgentAllowance{AgentID: "agent-1", DailyTokenLimit: 200}) + + got, err := s.GetAllowance("agent-1") + require.NoError(t, err) + assert.Equal(t, int64(200), got.DailyTokenLimit) +} + +// --- GetUsage / IncrementUsage --- + +func TestRedisStore_GetUsage_Good_Default(t *testing.T) { + s := newTestRedisStore(t) + + u, err := s.GetUsage("agent-1") + require.NoError(t, err) + assert.Equal(t, "agent-1", u.AgentID) + assert.Equal(t, int64(0), u.TokensUsed) + assert.Equal(t, 0, u.JobsStarted) + assert.Equal(t, 0, u.ActiveJobs) +} + +func TestRedisStore_IncrementUsage_Good(t *testing.T) { + s := newTestRedisStore(t) + + err := s.IncrementUsage("agent-1", 5000, 1) + require.NoError(t, err) + + u, err := s.GetUsage("agent-1") + require.NoError(t, err) + assert.Equal(t, int64(5000), u.TokensUsed) + assert.Equal(t, 1, u.JobsStarted) + assert.Equal(t, 1, u.ActiveJobs) +} + +func TestRedisStore_IncrementUsage_Good_Accumulates(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.IncrementUsage("agent-1", 1000, 1) + _ = s.IncrementUsage("agent-1", 2000, 1) + _ = s.IncrementUsage("agent-1", 3000, 0) + + u, err := s.GetUsage("agent-1") + require.NoError(t, err) + assert.Equal(t, int64(6000), u.TokensUsed) + assert.Equal(t, 2, u.JobsStarted) + assert.Equal(t, 2, u.ActiveJobs) +} + +// --- DecrementActiveJobs --- + +func TestRedisStore_DecrementActiveJobs_Good(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.IncrementUsage("agent-1", 0, 2) + _ = s.DecrementActiveJobs("agent-1") + + u, _ := s.GetUsage("agent-1") + assert.Equal(t, 1, u.ActiveJobs) +} + +func TestRedisStore_DecrementActiveJobs_Good_FloorAtZero(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.DecrementActiveJobs("agent-1") // no usage record yet + _ = s.IncrementUsage("agent-1", 0, 0) + _ = s.DecrementActiveJobs("agent-1") // should stay at 0 + + u, _ := s.GetUsage("agent-1") + assert.Equal(t, 0, u.ActiveJobs) +} + +// --- ReturnTokens --- + +func TestRedisStore_ReturnTokens_Good(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.IncrementUsage("agent-1", 10000, 0) + err := s.ReturnTokens("agent-1", 5000) + require.NoError(t, err) + + u, _ := s.GetUsage("agent-1") + assert.Equal(t, int64(5000), u.TokensUsed) +} + +func TestRedisStore_ReturnTokens_Good_FloorAtZero(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.IncrementUsage("agent-1", 1000, 0) + _ = s.ReturnTokens("agent-1", 5000) // more than used + + u, _ := s.GetUsage("agent-1") + assert.Equal(t, int64(0), u.TokensUsed) +} + +func TestRedisStore_ReturnTokens_Good_NoRecord(t *testing.T) { + s := newTestRedisStore(t) + + // Return tokens for agent with no usage record -- should be a no-op + err := s.ReturnTokens("agent-1", 500) + require.NoError(t, err) + + u, _ := s.GetUsage("agent-1") + assert.Equal(t, int64(0), u.TokensUsed) +} + +// --- ResetUsage --- + +func TestRedisStore_ResetUsage_Good(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.IncrementUsage("agent-1", 50000, 5) + + err := s.ResetUsage("agent-1") + require.NoError(t, err) + + u, _ := s.GetUsage("agent-1") + assert.Equal(t, int64(0), u.TokensUsed) + assert.Equal(t, 0, u.JobsStarted) + assert.Equal(t, 0, u.ActiveJobs) +} + +// --- ModelQuota --- + +func TestRedisStore_GetModelQuota_Bad_NotFound(t *testing.T) { + s := newTestRedisStore(t) + _, err := s.GetModelQuota("nonexistent") + require.Error(t, err) + apiErr, ok := err.(*APIError) + require.True(t, ok, "expected *APIError") + assert.Equal(t, 404, apiErr.Code) + assert.Contains(t, err.Error(), "model quota not found") +} + +func TestRedisStore_SetGetModelQuota_Good(t *testing.T) { + s := newTestRedisStore(t) + + q := &ModelQuota{ + Model: "claude-opus-4-6", + DailyTokenBudget: 500000, + HourlyRateLimit: 100, + CostCeiling: 10000, + } + err := s.SetModelQuota(q) + require.NoError(t, err) + + got, err := s.GetModelQuota("claude-opus-4-6") + require.NoError(t, err) + assert.Equal(t, q.Model, got.Model) + assert.Equal(t, q.DailyTokenBudget, got.DailyTokenBudget) + assert.Equal(t, q.HourlyRateLimit, got.HourlyRateLimit) + assert.Equal(t, q.CostCeiling, got.CostCeiling) +} + +// --- ModelUsage --- + +func TestRedisStore_ModelUsage_Good(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.IncrementModelUsage("claude-sonnet", 10000) + _ = s.IncrementModelUsage("claude-sonnet", 5000) + + usage, err := s.GetModelUsage("claude-sonnet") + require.NoError(t, err) + assert.Equal(t, int64(15000), usage) +} + +func TestRedisStore_GetModelUsage_Good_Default(t *testing.T) { + s := newTestRedisStore(t) + + usage, err := s.GetModelUsage("unknown-model") + require.NoError(t, err) + assert.Equal(t, int64(0), usage) +} + +// --- Persistence: set, get, verify --- + +func TestRedisStore_Persistence_Good(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.SetAllowance(&AgentAllowance{ + AgentID: "agent-1", + DailyTokenLimit: 100000, + MaxJobDuration: 15 * time.Minute, + }) + _ = s.IncrementUsage("agent-1", 25000, 3) + _ = s.SetModelQuota(&ModelQuota{Model: "claude-opus-4-6", DailyTokenBudget: 500000}) + _ = s.IncrementModelUsage("claude-opus-4-6", 42000) + + // Verify all data persists (same connection, but data is in Redis) + a, err := s.GetAllowance("agent-1") + require.NoError(t, err) + assert.Equal(t, int64(100000), a.DailyTokenLimit) + assert.Equal(t, 15*time.Minute, a.MaxJobDuration) + + u, err := s.GetUsage("agent-1") + require.NoError(t, err) + assert.Equal(t, int64(25000), u.TokensUsed) + assert.Equal(t, 3, u.JobsStarted) + assert.Equal(t, 3, u.ActiveJobs) + + q, err := s.GetModelQuota("claude-opus-4-6") + require.NoError(t, err) + assert.Equal(t, int64(500000), q.DailyTokenBudget) + + mu, err := s.GetModelUsage("claude-opus-4-6") + require.NoError(t, err) + assert.Equal(t, int64(42000), mu) +} + +// --- Concurrent access --- + +func TestRedisStore_ConcurrentIncrementUsage_Good(t *testing.T) { + s := newTestRedisStore(t) + + const goroutines = 10 + const tokensEach = 1000 + + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + err := s.IncrementUsage("agent-1", tokensEach, 1) + assert.NoError(t, err) + }() + } + wg.Wait() + + u, err := s.GetUsage("agent-1") + require.NoError(t, err) + assert.Equal(t, int64(goroutines*tokensEach), u.TokensUsed) + assert.Equal(t, goroutines, u.JobsStarted) + assert.Equal(t, goroutines, u.ActiveJobs) +} + +func TestRedisStore_ConcurrentModelUsage_Good(t *testing.T) { + s := newTestRedisStore(t) + + const goroutines = 10 + const tokensEach int64 = 500 + + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + err := s.IncrementModelUsage("claude-opus-4-6", tokensEach) + assert.NoError(t, err) + }() + } + wg.Wait() + + usage, err := s.GetModelUsage("claude-opus-4-6") + require.NoError(t, err) + assert.Equal(t, goroutines*tokensEach, usage) +} + +func TestRedisStore_ConcurrentMixed_Good(t *testing.T) { + s := newTestRedisStore(t) + + _ = s.SetAllowance(&AgentAllowance{ + AgentID: "agent-1", + DailyTokenLimit: 1000000, + DailyJobLimit: 100, + ConcurrentJobs: 50, + }) + + const goroutines = 10 + var wg sync.WaitGroup + wg.Add(goroutines * 3) + + // Increment usage + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + _ = s.IncrementUsage("agent-1", 100, 1) + }() + } + + // Decrement active jobs + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + _ = s.DecrementActiveJobs("agent-1") + }() + } + + // Return tokens + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + _ = s.ReturnTokens("agent-1", 10) + }() + } + + wg.Wait() + + // Verify no panics and data is consistent + u, err := s.GetUsage("agent-1") + require.NoError(t, err) + assert.GreaterOrEqual(t, u.TokensUsed, int64(0)) + assert.GreaterOrEqual(t, u.ActiveJobs, 0) +} + +// --- AllowanceService integration via RedisStore --- + +func TestRedisStore_AllowanceServiceCheck_Good(t *testing.T) { + s := newTestRedisStore(t) + svc := NewAllowanceService(s) + + _ = s.SetAllowance(&AgentAllowance{ + AgentID: "agent-1", + DailyTokenLimit: 100000, + DailyJobLimit: 10, + ConcurrentJobs: 2, + }) + + result, err := svc.Check("agent-1", "") + require.NoError(t, err) + assert.True(t, result.Allowed) + assert.Equal(t, AllowanceOK, result.Status) +} + +func TestRedisStore_AllowanceServiceRecordUsage_Good(t *testing.T) { + s := newTestRedisStore(t) + svc := NewAllowanceService(s) + + _ = s.SetAllowance(&AgentAllowance{ + AgentID: "agent-1", + DailyTokenLimit: 100000, + }) + + // Start job + err := svc.RecordUsage(UsageReport{ + AgentID: "agent-1", + JobID: "job-1", + Event: QuotaEventJobStarted, + }) + require.NoError(t, err) + + // Complete job + err = svc.RecordUsage(UsageReport{ + AgentID: "agent-1", + JobID: "job-1", + Model: "claude-sonnet", + TokensIn: 1000, + TokensOut: 500, + Event: QuotaEventJobCompleted, + }) + require.NoError(t, err) + + u, _ := s.GetUsage("agent-1") + assert.Equal(t, int64(1500), u.TokensUsed) + assert.Equal(t, 0, u.ActiveJobs) +} + +// --- Config-based factory with redis backend --- + +func TestNewAllowanceStoreFromConfig_Good_Redis(t *testing.T) { + cfg := AllowanceConfig{ + StoreBackend: "redis", + RedisAddr: testRedisAddr, + } + s, err := NewAllowanceStoreFromConfig(cfg) + if err != nil { + t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) + } + rs, ok := s.(*RedisStore) + assert.True(t, ok, "expected RedisStore") + _ = rs.Close() +} + +// --- Constructor error case --- + +func TestNewRedisStore_Bad_Unreachable(t *testing.T) { + _, err := NewRedisStore("127.0.0.1:1") // almost certainly unreachable + require.Error(t, err) + apiErr, ok := err.(*APIError) + require.True(t, ok, "expected *APIError") + assert.Equal(t, 500, apiErr.Code) + assert.Contains(t, err.Error(), "failed to connect to Redis") +} diff --git a/allowance_sqlite_test.go b/allowance_sqlite_test.go index dfdb16c..2f22ac8 100644 --- a/allowance_sqlite_test.go +++ b/allowance_sqlite_test.go @@ -451,7 +451,7 @@ func TestNewAllowanceStoreFromConfig_Good_SQLite(t *testing.T) { } func TestNewAllowanceStoreFromConfig_Bad_UnknownBackend(t *testing.T) { - cfg := AllowanceConfig{StoreBackend: "redis"} + cfg := AllowanceConfig{StoreBackend: "cassandra"} _, err := NewAllowanceStoreFromConfig(cfg) require.Error(t, err) assert.Contains(t, err.Error(), "unsupported store backend") diff --git a/config.go b/config.go index 7c785dc..9e83557 100644 --- a/config.go +++ b/config.go @@ -198,11 +198,13 @@ func ConfigPath() (string, error) { // AllowanceConfig controls allowance store backend selection. type AllowanceConfig struct { - // StoreBackend is the storage backend: "memory" or "sqlite". Default: "memory". + // StoreBackend is the storage backend: "memory", "sqlite", or "redis". Default: "memory". StoreBackend string `yaml:"store_backend" json:"store_backend"` // StorePath is the file path for the SQLite database. // Default: ~/.config/agentic/allowance.db (only used when StoreBackend is "sqlite"). StorePath string `yaml:"store_path" json:"store_path"` + // RedisAddr is the host:port for the Redis server (only used when StoreBackend is "redis"). + RedisAddr string `yaml:"redis_addr" json:"redis_addr"` } // DefaultAllowanceStorePath returns the default SQLite path for allowance data. @@ -230,6 +232,8 @@ func NewAllowanceStoreFromConfig(cfg AllowanceConfig) (AllowanceStore, error) { } } return NewSQLiteStore(dbPath) + case "redis": + return NewRedisStore(cfg.RedisAddr) default: return nil, &APIError{ Code: 400, diff --git a/go.mod b/go.mod index f1b0385..77c8430 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,15 @@ go 1.25.5 require ( forge.lthn.ai/core/go v0.0.0 forge.lthn.ai/core/go-store v0.0.0-00010101000000-000000000000 + github.com/redis/go-redis/v9 v9.18.0 github.com/stretchr/testify v1.11.1 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/kr/pretty v0.3.1 // indirect @@ -18,6 +21,7 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + go.uber.org/atomic v1.11.0 // indirect golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a // indirect golang.org/x/sys v0.41.0 // indirect modernc.org/libc v1.67.7 // indirect diff --git a/go.sum b/go.sum index c422646..5754184 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,14 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= @@ -9,6 +17,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -20,6 +30,8 @@ github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJm github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.18.0 h1:pMkxYPkEbMPwRdenAzUNyFNrDgHx9U+DrBabWNfSRQs= +github.com/redis/go-redis/v9 v9.18.0/go.mod h1:k3ufPphLU5YXwNTUcCRXGxUoF1fqxnhFQmscfkCoDA0= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= @@ -27,6 +39,10 @@ github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0t github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a h1:ovFr6Z0MNmU7nH8VaX5xqw+05ST2uO1exVfZPVqRC5o= golang.org/x/exp v0.0.0-20260212183809-81e46e3db34a/go.mod h1:K79w1Vqn7PoiZn+TkNpx3BUWUQksGO3JcVX6qIjytmA= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8=