diff --git a/TODO.md b/TODO.md index 540b4f1..8eed18d 100644 --- a/TODO.md +++ b/TODO.md @@ -60,31 +60,31 @@ Phase 4 provides the data-fetching and formatting functions that `core agent` CL - [x] **Create `logs.go`** — `StreamLogs(ctx, client, taskID, writer) error` — polls task updates and writes progress to io.Writer - [x] **Tests** — mock client with progress updates, context cancellation -## Phase 5: Persistent Agent Registry +## Phase 5: Persistent Agent Registry — `04a30df` The `AgentRegistry` interface only has `MemoryRegistry` — a restart drops all agent registrations. This mirrors the AllowanceStore pattern: memory → SQLite → Redis. ### 5.1 SQLite Registry -- [ ] **Create `registry_sqlite.go`** — `SQLiteRegistry` implementing `AgentRegistry` interface -- [ ] Schema: `agents` table (id TEXT PK, name TEXT, capabilities TEXT JSON, status INT, last_heartbeat DATETIME, current_load INT, max_load INT, registered_at DATETIME) -- [ ] Use `modernc.org/sqlite` (already a transitive dep via go-store) with WAL mode -- [ ] `Register` → UPSERT, `Deregister` → DELETE, `Get` → SELECT, `List` → SELECT all, `Heartbeat` → UPDATE last_heartbeat, `Reap(ttl)` → UPDATE status=Offline WHERE last_heartbeat < now-ttl RETURNING id -- [ ] **Tests** — full parity with `registry_test.go` using `:memory:` SQLite, concurrent access under `-race` +- [x] **Create `registry_sqlite.go`** — `SQLiteRegistry` implementing `AgentRegistry` interface +- [x] Schema: `agents` table (id TEXT PK, name TEXT, capabilities TEXT JSON, status TEXT, last_heartbeat DATETIME, current_load INT, max_load INT, registered_at DATETIME) +- [x] Use `modernc.org/sqlite` (already a transitive dep via go-store) with WAL mode +- [x] `Register` → UPSERT, `Deregister` → DELETE, `Get` → SELECT, `List` → SELECT all, `Heartbeat` → UPDATE last_heartbeat, `Reap(ttl)` → UPDATE status=Offline WHERE last_heartbeat < now-ttl +- [x] **Tests** — full parity with `registry_test.go` using `:memory:` SQLite, concurrent access under `-race` ### 5.2 Redis Registry -- [ ] **Create `registry_redis.go`** — `RedisRegistry` implementing `AgentRegistry` with TTL-based reaping -- [ ] Key pattern: `{prefix}:agent:{id}` → JSON AgentInfo, with TTL = heartbeat interval * 3 -- [ ] `Heartbeat` → re-SET with TTL refresh (natural expiry = auto-reap) -- [ ] `List` → SCAN `{prefix}:agent:*`, `Reap` → explicit scan for expired (backup to natural TTL) -- [ ] **Tests** — skip-if-no-Redis pattern, unique prefix per test +- [x] **Create `registry_redis.go`** — `RedisRegistry` implementing `AgentRegistry` with TTL-based reaping +- [x] Key pattern: `{prefix}:agent:{id}` → JSON AgentInfo, with TTL = heartbeat interval * 3 +- [x] `Heartbeat` → re-SET with TTL refresh (natural expiry = auto-reap) +- [x] `List` → SCAN `{prefix}:agent:*`, `Reap` → explicit scan for expired (backup to natural TTL) +- [x] **Tests** — skip-if-no-Redis pattern, unique prefix per test ### 5.3 Config Factory -- [ ] **Add `RegistryConfig`** to `config.go` — `RegistryBackend string` (memory/sqlite/redis), `RegistryPath string`, `RegistryRedisAddr string` -- [ ] **`NewAgentRegistryFromConfig(cfg) (AgentRegistry, error)`** — factory mirroring `NewAllowanceStoreFromConfig` -- [ ] **Tests** — all backends, unknown backend error +- [x] **Add `RegistryConfig`** to `config.go` — `RegistryBackend string` (memory/sqlite/redis), `RegistryPath string`, `RegistryRedisAddr string` +- [x] **`NewAgentRegistryFromConfig(cfg) (AgentRegistry, error)`** — factory mirroring `NewAllowanceStoreFromConfig` +- [x] **Tests** — all backends, unknown backend error ## Phase 6: Dead Code Cleanup + Rate Enforcement diff --git a/config.go b/config.go index e0e91e0..a201e21 100644 --- a/config.go +++ b/config.go @@ -242,3 +242,50 @@ func NewAllowanceStoreFromConfig(cfg AllowanceConfig) (AllowanceStore, error) { } } } + +// RegistryConfig controls agent registry backend selection. +type RegistryConfig struct { + // RegistryBackend is the storage backend: "memory", "sqlite", or "redis". Default: "memory". + RegistryBackend string `yaml:"registry_backend" json:"registry_backend"` + // RegistryPath is the file path for the SQLite database. + // Default: ~/.config/agentic/registry.db (only used when RegistryBackend is "sqlite"). + RegistryPath string `yaml:"registry_path" json:"registry_path"` + // RegistryRedisAddr is the host:port for the Redis server (only used when RegistryBackend is "redis"). + RegistryRedisAddr string `yaml:"registry_redis_addr" json:"registry_redis_addr"` +} + +// DefaultRegistryPath returns the default SQLite path for registry data. +func DefaultRegistryPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", errors.E("agentic.DefaultRegistryPath", "failed to get home directory", err) + } + return filepath.Join(homeDir, ".config", "agentic", "registry.db"), nil +} + +// NewAgentRegistryFromConfig creates an AgentRegistry based on the given config. +// It returns a MemoryRegistry for "memory" (or empty) backend, a SQLiteRegistry +// for "sqlite", and a RedisRegistry for "redis". +func NewAgentRegistryFromConfig(cfg RegistryConfig) (AgentRegistry, error) { + switch cfg.RegistryBackend { + case "", "memory": + return NewMemoryRegistry(), nil + case "sqlite": + dbPath := cfg.RegistryPath + if dbPath == "" { + var err error + dbPath, err = DefaultRegistryPath() + if err != nil { + return nil, err + } + } + return NewSQLiteRegistry(dbPath) + case "redis": + return NewRedisRegistry(cfg.RegistryRedisAddr) + default: + return nil, &APIError{ + Code: 400, + Message: "unsupported registry backend: " + cfg.RegistryBackend, + } + } +} diff --git a/registry_redis.go b/registry_redis.go new file mode 100644 index 0000000..20b533e --- /dev/null +++ b/registry_redis.go @@ -0,0 +1,270 @@ +package agentic + +import ( + "context" + "encoding/json" + "errors" + "time" + + "github.com/redis/go-redis/v9" +) + +// RedisRegistry implements AgentRegistry using Redis as the backing store. +// It provides persistent, network-accessible agent registration suitable for +// multi-node deployments. Heartbeat refreshes key TTL for natural reaping via +// expiry. +type RedisRegistry struct { + client *redis.Client + prefix string + defaultTTL time.Duration +} + +// redisRegistryConfig holds the configuration for a RedisRegistry. +type redisRegistryConfig struct { + password string + db int + prefix string + ttl time.Duration +} + +// RedisRegistryOption is a functional option for configuring a RedisRegistry. +type RedisRegistryOption func(*redisRegistryConfig) + +// WithRegistryRedisPassword sets the password for authenticating with Redis. +func WithRegistryRedisPassword(pw string) RedisRegistryOption { + return func(c *redisRegistryConfig) { + c.password = pw + } +} + +// WithRegistryRedisDB selects the Redis database number. +func WithRegistryRedisDB(db int) RedisRegistryOption { + return func(c *redisRegistryConfig) { + c.db = db + } +} + +// WithRegistryRedisPrefix sets the key prefix for all Redis keys. +// Default: "agentic". +func WithRegistryRedisPrefix(prefix string) RedisRegistryOption { + return func(c *redisRegistryConfig) { + c.prefix = prefix + } +} + +// WithRegistryTTL sets the default TTL for agent keys. Default: 5 minutes. +// Heartbeat refreshes this TTL. Agents whose keys expire are naturally reaped. +func WithRegistryTTL(ttl time.Duration) RedisRegistryOption { + return func(c *redisRegistryConfig) { + c.ttl = ttl + } +} + +// NewRedisRegistry creates a new Redis-backed agent registry connecting to the +// given address (host:port). It pings the server to verify connectivity. +func NewRedisRegistry(addr string, opts ...RedisRegistryOption) (*RedisRegistry, error) { + cfg := &redisRegistryConfig{ + prefix: "agentic", + ttl: 5 * time.Minute, + } + for _, opt := range opts { + opt(cfg) + } + + client := redis.NewClient(&redis.Options{ + Addr: addr, + Password: cfg.password, + DB: cfg.db, + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + _ = client.Close() + return nil, &APIError{Code: 500, Message: "failed to connect to Redis: " + err.Error()} + } + + return &RedisRegistry{ + client: client, + prefix: cfg.prefix, + defaultTTL: cfg.ttl, + }, nil +} + +// Close releases the underlying Redis connection. +func (r *RedisRegistry) Close() error { + return r.client.Close() +} + +// --- key helpers --- + +func (r *RedisRegistry) agentKey(id string) string { + return r.prefix + ":agent:" + id +} + +func (r *RedisRegistry) agentPattern() string { + return r.prefix + ":agent:*" +} + +// --- AgentRegistry interface --- + +// Register adds or updates an agent in the registry. +func (r *RedisRegistry) Register(agent AgentInfo) error { + if agent.ID == "" { + return &APIError{Code: 400, Message: "agent ID is required"} + } + ctx := context.Background() + data, err := json.Marshal(agent) + if err != nil { + return &APIError{Code: 500, Message: "failed to marshal agent: " + err.Error()} + } + if err := r.client.Set(ctx, r.agentKey(agent.ID), data, r.defaultTTL).Err(); err != nil { + return &APIError{Code: 500, Message: "failed to register agent: " + err.Error()} + } + return nil +} + +// Deregister removes an agent from the registry. Returns an error if the agent +// is not found. +func (r *RedisRegistry) Deregister(id string) error { + ctx := context.Background() + n, err := r.client.Del(ctx, r.agentKey(id)).Result() + if err != nil { + return &APIError{Code: 500, Message: "failed to deregister agent: " + err.Error()} + } + if n == 0 { + return &APIError{Code: 404, Message: "agent not found: " + id} + } + return nil +} + +// Get returns a copy of the agent info for the given ID. Returns an error if +// the agent is not found. +func (r *RedisRegistry) Get(id string) (AgentInfo, error) { + ctx := context.Background() + val, err := r.client.Get(ctx, r.agentKey(id)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + id} + } + return AgentInfo{}, &APIError{Code: 500, Message: "failed to get agent: " + err.Error()} + } + var a AgentInfo + if err := json.Unmarshal([]byte(val), &a); err != nil { + return AgentInfo{}, &APIError{Code: 500, Message: "failed to unmarshal agent: " + err.Error()} + } + return a, nil +} + +// List returns a copy of all registered agents by scanning all agent keys. +func (r *RedisRegistry) List() []AgentInfo { + ctx := context.Background() + var result []AgentInfo + + iter := r.client.Scan(ctx, 0, r.agentPattern(), 100).Iterator() + for iter.Next(ctx) { + val, err := r.client.Get(ctx, iter.Val()).Result() + if err != nil { + continue + } + var a AgentInfo + if err := json.Unmarshal([]byte(val), &a); err != nil { + continue + } + result = append(result, a) + } + + if result == nil { + return []AgentInfo{} + } + return result +} + +// Heartbeat updates the agent's LastHeartbeat timestamp and refreshes the key +// TTL. If the agent was Offline, it transitions to Available. +func (r *RedisRegistry) Heartbeat(id string) error { + ctx := context.Background() + key := r.agentKey(id) + + val, err := r.client.Get(ctx, key).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return &APIError{Code: 404, Message: "agent not found: " + id} + } + return &APIError{Code: 500, Message: "failed to get agent for heartbeat: " + err.Error()} + } + + var a AgentInfo + if err := json.Unmarshal([]byte(val), &a); err != nil { + return &APIError{Code: 500, Message: "failed to unmarshal agent: " + err.Error()} + } + + a.LastHeartbeat = time.Now().UTC() + if a.Status == AgentOffline { + a.Status = AgentAvailable + } + + data, err := json.Marshal(a) + if err != nil { + return &APIError{Code: 500, Message: "failed to marshal agent: " + err.Error()} + } + + if err := r.client.Set(ctx, key, data, r.defaultTTL).Err(); err != nil { + return &APIError{Code: 500, Message: "failed to update agent heartbeat: " + err.Error()} + } + return nil +} + +// Reap scans all agent keys and marks agents as Offline if their last heartbeat +// is older than ttl. This is a backup to natural TTL expiry. Returns the IDs +// of agents that were reaped. +func (r *RedisRegistry) Reap(ttl time.Duration) []string { + ctx := context.Background() + cutoff := time.Now().UTC().Add(-ttl) + var reaped []string + + iter := r.client.Scan(ctx, 0, r.agentPattern(), 100).Iterator() + for iter.Next(ctx) { + key := iter.Val() + val, err := r.client.Get(ctx, key).Result() + if err != nil { + continue + } + var a AgentInfo + if err := json.Unmarshal([]byte(val), &a); err != nil { + continue + } + + if a.Status != AgentOffline && a.LastHeartbeat.Before(cutoff) { + a.Status = AgentOffline + data, err := json.Marshal(a) + if err != nil { + continue + } + // Preserve remaining TTL (or use default if none). + remainingTTL, err := r.client.TTL(ctx, key).Result() + if err != nil || remainingTTL <= 0 { + remainingTTL = r.defaultTTL + } + if err := r.client.Set(ctx, key, data, remainingTTL).Err(); err != nil { + continue + } + reaped = append(reaped, a.ID) + } + } + + return reaped +} + +// FlushPrefix deletes all keys matching the registry's prefix. Useful for +// testing cleanup. +func (r *RedisRegistry) FlushPrefix(ctx context.Context) error { + iter := r.client.Scan(ctx, 0, r.prefix+":*", 100).Iterator() + for iter.Next(ctx) { + if err := r.client.Del(ctx, iter.Val()).Err(); err != nil { + return err + } + } + return iter.Err() +} diff --git a/registry_redis_test.go b/registry_redis_test.go new file mode 100644 index 0000000..928321e --- /dev/null +++ b/registry_redis_test.go @@ -0,0 +1,327 @@ +package agentic + +import ( + "context" + "fmt" + "sort" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestRedisRegistry creates a RedisRegistry with a unique prefix for test isolation. +// Skips the test if Redis is unreachable. +func newTestRedisRegistry(t *testing.T) *RedisRegistry { + t.Helper() + prefix := fmt.Sprintf("test_reg_%d", time.Now().UnixNano()) + reg, err := NewRedisRegistry(testRedisAddr, + WithRegistryRedisPrefix(prefix), + WithRegistryTTL(5*time.Minute), + ) + if err != nil { + t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) + } + t.Cleanup(func() { + ctx := context.Background() + _ = reg.FlushPrefix(ctx) + _ = reg.Close() + }) + return reg +} + +// --- Register tests --- + +func TestRedisRegistry_Register_Good(t *testing.T) { + reg := newTestRedisRegistry(t) + err := reg.Register(AgentInfo{ + ID: "agent-1", + Name: "Test Agent", + Capabilities: []string{"go", "testing"}, + Status: AgentAvailable, + MaxLoad: 5, + }) + require.NoError(t, err) + + got, err := reg.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, "agent-1", got.ID) + assert.Equal(t, "Test Agent", got.Name) + assert.Equal(t, []string{"go", "testing"}, got.Capabilities) + assert.Equal(t, AgentAvailable, got.Status) + assert.Equal(t, 5, got.MaxLoad) +} + +func TestRedisRegistry_Register_Good_Overwrite(t *testing.T) { + reg := newTestRedisRegistry(t) + _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3}) + err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10}) + require.NoError(t, err) + + got, err := reg.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, "Updated", got.Name) + assert.Equal(t, 10, got.MaxLoad) +} + +func TestRedisRegistry_Register_Bad_EmptyID(t *testing.T) { + reg := newTestRedisRegistry(t) + err := reg.Register(AgentInfo{ID: "", Name: "No ID"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "agent ID is required") +} + +// --- Deregister tests --- + +func TestRedisRegistry_Deregister_Good(t *testing.T) { + reg := newTestRedisRegistry(t) + _ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"}) + + err := reg.Deregister("agent-1") + require.NoError(t, err) + + _, err = reg.Get("agent-1") + require.Error(t, err) +} + +func TestRedisRegistry_Deregister_Bad_NotFound(t *testing.T) { + reg := newTestRedisRegistry(t) + err := reg.Deregister("nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} + +// --- Get tests --- + +func TestRedisRegistry_Get_Good(t *testing.T) { + reg := newTestRedisRegistry(t) + now := time.Now().UTC().Truncate(time.Millisecond) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Name: "Getter", + Status: AgentBusy, + CurrentLoad: 2, + MaxLoad: 5, + LastHeartbeat: now, + }) + + got, err := reg.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, AgentBusy, got.Status) + assert.Equal(t, 2, got.CurrentLoad) + assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond) +} + +func TestRedisRegistry_Get_Bad_NotFound(t *testing.T) { + reg := newTestRedisRegistry(t) + _, err := reg.Get("nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} + +func TestRedisRegistry_Get_Good_ReturnsCopy(t *testing.T) { + reg := newTestRedisRegistry(t) + _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1}) + + got, _ := reg.Get("agent-1") + got.CurrentLoad = 99 + got.Name = "Tampered" + + // Re-read — should be unchanged (deserialized from Redis). + again, _ := reg.Get("agent-1") + assert.Equal(t, "Original", again.Name) + assert.Equal(t, 1, again.CurrentLoad) +} + +// --- List tests --- + +func TestRedisRegistry_List_Good_Empty(t *testing.T) { + reg := newTestRedisRegistry(t) + agents := reg.List() + assert.Empty(t, agents) +} + +func TestRedisRegistry_List_Good_Multiple(t *testing.T) { + reg := newTestRedisRegistry(t) + _ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"}) + _ = reg.Register(AgentInfo{ID: "b", Name: "Beta"}) + _ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"}) + + agents := reg.List() + assert.Len(t, agents, 3) + + // Sort by ID for deterministic assertion. + sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID }) + assert.Equal(t, "a", agents[0].ID) + assert.Equal(t, "b", agents[1].ID) + assert.Equal(t, "c", agents[2].ID) +} + +// --- Heartbeat tests --- + +func TestRedisRegistry_Heartbeat_Good(t *testing.T) { + reg := newTestRedisRegistry(t) + past := time.Now().UTC().Add(-5 * time.Minute) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Status: AgentAvailable, + LastHeartbeat: past, + }) + + err := reg.Heartbeat("agent-1") + require.NoError(t, err) + + got, _ := reg.Get("agent-1") + assert.True(t, got.LastHeartbeat.After(past)) + assert.Equal(t, AgentAvailable, got.Status) +} + +func TestRedisRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) { + reg := newTestRedisRegistry(t) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Status: AgentOffline, + }) + + err := reg.Heartbeat("agent-1") + require.NoError(t, err) + + got, _ := reg.Get("agent-1") + assert.Equal(t, AgentAvailable, got.Status) +} + +func TestRedisRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) { + reg := newTestRedisRegistry(t) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Status: AgentBusy, + }) + + err := reg.Heartbeat("agent-1") + require.NoError(t, err) + + got, _ := reg.Get("agent-1") + assert.Equal(t, AgentBusy, got.Status) +} + +func TestRedisRegistry_Heartbeat_Bad_NotFound(t *testing.T) { + reg := newTestRedisRegistry(t) + err := reg.Heartbeat("nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} + +// --- Reap tests --- + +func TestRedisRegistry_Reap_Good_StaleAgent(t *testing.T) { + reg := newTestRedisRegistry(t) + stale := time.Now().UTC().Add(-10 * time.Minute) + fresh := time.Now().UTC() + + _ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale}) + _ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh}) + + reaped := reg.Reap(5 * time.Minute) + assert.Len(t, reaped, 1) + assert.Contains(t, reaped, "stale-1") + + got, _ := reg.Get("stale-1") + assert.Equal(t, AgentOffline, got.Status) + + got, _ = reg.Get("fresh-1") + assert.Equal(t, AgentAvailable, got.Status) +} + +func TestRedisRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) { + reg := newTestRedisRegistry(t) + stale := time.Now().UTC().Add(-10 * time.Minute) + + _ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale}) + + reaped := reg.Reap(5 * time.Minute) + assert.Empty(t, reaped) +} + +func TestRedisRegistry_Reap_Good_NoStaleAgents(t *testing.T) { + reg := newTestRedisRegistry(t) + now := time.Now().UTC() + + _ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now}) + _ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now}) + + reaped := reg.Reap(5 * time.Minute) + assert.Empty(t, reaped) +} + +func TestRedisRegistry_Reap_Good_BusyAgentReaped(t *testing.T) { + reg := newTestRedisRegistry(t) + stale := time.Now().UTC().Add(-10 * time.Minute) + + _ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale}) + + reaped := reg.Reap(5 * time.Minute) + assert.Len(t, reaped, 1) + assert.Contains(t, reaped, "busy-stale") + + got, _ := reg.Get("busy-stale") + assert.Equal(t, AgentOffline, got.Status) +} + +// --- Concurrent access --- + +func TestRedisRegistry_Concurrent_Good(t *testing.T) { + reg := newTestRedisRegistry(t) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + id := "agent-" + string(rune('a'+n%5)) + _ = reg.Register(AgentInfo{ + ID: id, + Name: "Concurrent", + Status: AgentAvailable, + LastHeartbeat: time.Now().UTC(), + }) + _, _ = reg.Get(id) + _ = reg.Heartbeat(id) + _ = reg.List() + _ = reg.Reap(1 * time.Minute) + }(i) + } + wg.Wait() + + // No race conditions — test passes under -race. + agents := reg.List() + assert.True(t, len(agents) > 0) +} + +// --- Constructor error case --- + +func TestNewRedisRegistry_Bad_Unreachable(t *testing.T) { + _, err := NewRedisRegistry("127.0.0.1:1") // almost certainly unreachable + require.Error(t, err) + apiErr, ok := err.(*APIError) + require.True(t, ok, "expected *APIError") + assert.Equal(t, 500, apiErr.Code) + assert.Contains(t, err.Error(), "failed to connect to Redis") +} + +// --- Config-based factory with redis backend --- + +func TestNewAgentRegistryFromConfig_Good_Redis(t *testing.T) { + cfg := RegistryConfig{ + RegistryBackend: "redis", + RegistryRedisAddr: testRedisAddr, + } + reg, err := NewAgentRegistryFromConfig(cfg) + if err != nil { + t.Skipf("Redis unavailable at %s: %v", testRedisAddr, err) + } + rr, ok := reg.(*RedisRegistry) + assert.True(t, ok, "expected RedisRegistry") + _ = rr.Close() +} diff --git a/registry_sqlite.go b/registry_sqlite.go new file mode 100644 index 0000000..9f89758 --- /dev/null +++ b/registry_sqlite.go @@ -0,0 +1,250 @@ +package agentic + +import ( + "database/sql" + "encoding/json" + "strings" + "sync" + "time" + + _ "modernc.org/sqlite" +) + +// SQLiteRegistry implements AgentRegistry using a SQLite database. +// It provides persistent storage that survives process restarts. +type SQLiteRegistry struct { + db *sql.DB + mu sync.Mutex // serialises read-modify-write operations +} + +// NewSQLiteRegistry creates a new SQLite-backed agent registry at the given path. +// Use ":memory:" for tests that do not need persistence. +func NewSQLiteRegistry(dbPath string) (*SQLiteRegistry, error) { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, &APIError{Code: 500, Message: "failed to open SQLite registry: " + err.Error()} + } + db.SetMaxOpenConns(1) + if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { + db.Close() + return nil, &APIError{Code: 500, Message: "failed to set WAL mode: " + err.Error()} + } + if _, err := db.Exec("PRAGMA busy_timeout=5000"); err != nil { + db.Close() + return nil, &APIError{Code: 500, Message: "failed to set busy_timeout: " + err.Error()} + } + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS agents ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL DEFAULT '', + capabilities TEXT NOT NULL DEFAULT '[]', + status TEXT NOT NULL DEFAULT 'available', + last_heartbeat DATETIME NOT NULL DEFAULT (datetime('now')), + current_load INTEGER NOT NULL DEFAULT 0, + max_load INTEGER NOT NULL DEFAULT 0, + registered_at DATETIME NOT NULL DEFAULT (datetime('now')) + )`); err != nil { + db.Close() + return nil, &APIError{Code: 500, Message: "failed to create agents table: " + err.Error()} + } + return &SQLiteRegistry{db: db}, nil +} + +// Close releases the underlying SQLite database. +func (r *SQLiteRegistry) Close() error { + return r.db.Close() +} + +// Register adds or updates an agent in the registry. Returns an error if the +// agent ID is empty. +func (r *SQLiteRegistry) Register(agent AgentInfo) error { + if agent.ID == "" { + return &APIError{Code: 400, Message: "agent ID is required"} + } + caps, err := json.Marshal(agent.Capabilities) + if err != nil { + return &APIError{Code: 500, Message: "failed to marshal capabilities: " + err.Error()} + } + hb := agent.LastHeartbeat + if hb.IsZero() { + hb = time.Now().UTC() + } + r.mu.Lock() + defer r.mu.Unlock() + _, err = r.db.Exec(`INSERT INTO agents (id, name, capabilities, status, last_heartbeat, current_load, max_load, registered_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + name = excluded.name, + capabilities = excluded.capabilities, + status = excluded.status, + last_heartbeat = excluded.last_heartbeat, + current_load = excluded.current_load, + max_load = excluded.max_load`, + agent.ID, agent.Name, string(caps), string(agent.Status), hb.Format(time.RFC3339Nano), + agent.CurrentLoad, agent.MaxLoad, hb.Format(time.RFC3339Nano)) + if err != nil { + return &APIError{Code: 500, Message: "failed to register agent: " + err.Error()} + } + return nil +} + +// Deregister removes an agent from the registry. Returns an error if the agent +// is not found. +func (r *SQLiteRegistry) Deregister(id string) error { + r.mu.Lock() + defer r.mu.Unlock() + res, err := r.db.Exec("DELETE FROM agents WHERE id = ?", id) + if err != nil { + return &APIError{Code: 500, Message: "failed to deregister agent: " + err.Error()} + } + n, err := res.RowsAffected() + if err != nil { + return &APIError{Code: 500, Message: "failed to check delete result: " + err.Error()} + } + if n == 0 { + return &APIError{Code: 404, Message: "agent not found: " + id} + } + return nil +} + +// Get returns a copy of the agent info for the given ID. Returns an error if +// the agent is not found. +func (r *SQLiteRegistry) Get(id string) (AgentInfo, error) { + return r.scanAgent("SELECT id, name, capabilities, status, last_heartbeat, current_load, max_load FROM agents WHERE id = ?", id) +} + +// List returns a copy of all registered agents. +func (r *SQLiteRegistry) List() []AgentInfo { + rows, err := r.db.Query("SELECT id, name, capabilities, status, last_heartbeat, current_load, max_load FROM agents") + if err != nil { + return nil + } + defer rows.Close() + + var result []AgentInfo + for rows.Next() { + a, err := r.scanAgentRow(rows) + if err != nil { + continue + } + result = append(result, a) + } + if result == nil { + return []AgentInfo{} + } + return result +} + +// Heartbeat updates the agent's LastHeartbeat timestamp. If the agent was +// Offline, it transitions to Available. +func (r *SQLiteRegistry) Heartbeat(id string) error { + r.mu.Lock() + defer r.mu.Unlock() + + now := time.Now().UTC().Format(time.RFC3339Nano) + + // Update heartbeat for all agents, and transition offline agents to available. + res, err := r.db.Exec(`UPDATE agents SET + last_heartbeat = ?, + status = CASE WHEN status = ? THEN ? ELSE status END + WHERE id = ?`, + now, string(AgentOffline), string(AgentAvailable), id) + if err != nil { + return &APIError{Code: 500, Message: "failed to heartbeat agent: " + err.Error()} + } + n, err := res.RowsAffected() + if err != nil { + return &APIError{Code: 500, Message: "failed to check heartbeat result: " + err.Error()} + } + if n == 0 { + return &APIError{Code: 404, Message: "agent not found: " + id} + } + return nil +} + +// Reap marks agents as Offline if their last heartbeat is older than ttl. +// Returns the IDs of agents that were reaped. +func (r *SQLiteRegistry) Reap(ttl time.Duration) []string { + r.mu.Lock() + defer r.mu.Unlock() + + cutoff := time.Now().UTC().Add(-ttl).Format(time.RFC3339Nano) + + // Select agents that will be reaped before updating. + rows, err := r.db.Query( + "SELECT id FROM agents WHERE status != ? AND last_heartbeat < ?", + string(AgentOffline), cutoff) + if err != nil { + return nil + } + defer rows.Close() + + var reaped []string + for rows.Next() { + var id string + if err := rows.Scan(&id); err != nil { + continue + } + reaped = append(reaped, id) + } + if err := rows.Err(); err != nil { + return nil + } + rows.Close() + + if len(reaped) > 0 { + // Build placeholders for IN clause. + placeholders := make([]string, len(reaped)) + args := make([]any, len(reaped)) + for i, id := range reaped { + placeholders[i] = "?" + args[i] = id + } + query := "UPDATE agents SET status = ? WHERE id IN (" + strings.Join(placeholders, ",") + ")" + allArgs := append([]any{string(AgentOffline)}, args...) + _, _ = r.db.Exec(query, allArgs...) + } + + return reaped +} + +// --- internal helpers --- + +// scanAgent executes a query that returns a single agent row. +func (r *SQLiteRegistry) scanAgent(query string, args ...any) (AgentInfo, error) { + row := r.db.QueryRow(query, args...) + var a AgentInfo + var capsJSON string + var statusStr string + var hbStr string + err := row.Scan(&a.ID, &a.Name, &capsJSON, &statusStr, &hbStr, &a.CurrentLoad, &a.MaxLoad) + if err == sql.ErrNoRows { + return AgentInfo{}, &APIError{Code: 404, Message: "agent not found: " + args[0].(string)} + } + if err != nil { + return AgentInfo{}, &APIError{Code: 500, Message: "failed to scan agent: " + err.Error()} + } + if err := json.Unmarshal([]byte(capsJSON), &a.Capabilities); err != nil { + return AgentInfo{}, &APIError{Code: 500, Message: "failed to unmarshal capabilities: " + err.Error()} + } + a.Status = AgentStatus(statusStr) + a.LastHeartbeat, _ = time.Parse(time.RFC3339Nano, hbStr) + return a, nil +} + +// scanAgentRow scans a single row from a rows iterator. +func (r *SQLiteRegistry) scanAgentRow(rows *sql.Rows) (AgentInfo, error) { + var a AgentInfo + var capsJSON string + var statusStr string + var hbStr string + err := rows.Scan(&a.ID, &a.Name, &capsJSON, &statusStr, &hbStr, &a.CurrentLoad, &a.MaxLoad) + if err != nil { + return AgentInfo{}, err + } + if err := json.Unmarshal([]byte(capsJSON), &a.Capabilities); err != nil { + return AgentInfo{}, err + } + a.Status = AgentStatus(statusStr) + a.LastHeartbeat, _ = time.Parse(time.RFC3339Nano, hbStr) + return a, nil +} diff --git a/registry_sqlite_test.go b/registry_sqlite_test.go new file mode 100644 index 0000000..c3b8957 --- /dev/null +++ b/registry_sqlite_test.go @@ -0,0 +1,386 @@ +package agentic + +import ( + "path/filepath" + "sort" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// newTestSQLiteRegistry creates a SQLiteRegistry backed by :memory: for testing. +func newTestSQLiteRegistry(t *testing.T) *SQLiteRegistry { + t.Helper() + reg, err := NewSQLiteRegistry(":memory:") + require.NoError(t, err) + t.Cleanup(func() { _ = reg.Close() }) + return reg +} + +// --- Register tests --- + +func TestSQLiteRegistry_Register_Good(t *testing.T) { + reg := newTestSQLiteRegistry(t) + err := reg.Register(AgentInfo{ + ID: "agent-1", + Name: "Test Agent", + Capabilities: []string{"go", "testing"}, + Status: AgentAvailable, + MaxLoad: 5, + }) + require.NoError(t, err) + + got, err := reg.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, "agent-1", got.ID) + assert.Equal(t, "Test Agent", got.Name) + assert.Equal(t, []string{"go", "testing"}, got.Capabilities) + assert.Equal(t, AgentAvailable, got.Status) + assert.Equal(t, 5, got.MaxLoad) +} + +func TestSQLiteRegistry_Register_Good_Overwrite(t *testing.T) { + reg := newTestSQLiteRegistry(t) + _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", MaxLoad: 3}) + err := reg.Register(AgentInfo{ID: "agent-1", Name: "Updated", MaxLoad: 10}) + require.NoError(t, err) + + got, err := reg.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, "Updated", got.Name) + assert.Equal(t, 10, got.MaxLoad) +} + +func TestSQLiteRegistry_Register_Bad_EmptyID(t *testing.T) { + reg := newTestSQLiteRegistry(t) + err := reg.Register(AgentInfo{ID: "", Name: "No ID"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "agent ID is required") +} + +func TestSQLiteRegistry_Register_Good_NilCapabilities(t *testing.T) { + reg := newTestSQLiteRegistry(t) + err := reg.Register(AgentInfo{ + ID: "agent-1", + Name: "No Caps", + Capabilities: nil, + Status: AgentAvailable, + }) + require.NoError(t, err) + + got, err := reg.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, "No Caps", got.Name) + // nil capabilities serialised as JSON null, deserialised back to nil. +} + +// --- Deregister tests --- + +func TestSQLiteRegistry_Deregister_Good(t *testing.T) { + reg := newTestSQLiteRegistry(t) + _ = reg.Register(AgentInfo{ID: "agent-1", Name: "To Remove"}) + + err := reg.Deregister("agent-1") + require.NoError(t, err) + + _, err = reg.Get("agent-1") + require.Error(t, err) +} + +func TestSQLiteRegistry_Deregister_Bad_NotFound(t *testing.T) { + reg := newTestSQLiteRegistry(t) + err := reg.Deregister("nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} + +// --- Get tests --- + +func TestSQLiteRegistry_Get_Good(t *testing.T) { + reg := newTestSQLiteRegistry(t) + now := time.Now().UTC().Truncate(time.Microsecond) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Name: "Getter", + Status: AgentBusy, + CurrentLoad: 2, + MaxLoad: 5, + LastHeartbeat: now, + }) + + got, err := reg.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, AgentBusy, got.Status) + assert.Equal(t, 2, got.CurrentLoad) + // Heartbeat stored via RFC3339Nano — allow small time difference from serialisation. + assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond) +} + +func TestSQLiteRegistry_Get_Bad_NotFound(t *testing.T) { + reg := newTestSQLiteRegistry(t) + _, err := reg.Get("nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} + +func TestSQLiteRegistry_Get_Good_ReturnsCopy(t *testing.T) { + reg := newTestSQLiteRegistry(t) + _ = reg.Register(AgentInfo{ID: "agent-1", Name: "Original", CurrentLoad: 1}) + + got, _ := reg.Get("agent-1") + got.CurrentLoad = 99 + got.Name = "Tampered" + + // Re-read — should be unchanged. + again, _ := reg.Get("agent-1") + assert.Equal(t, "Original", again.Name) + assert.Equal(t, 1, again.CurrentLoad) +} + +// --- List tests --- + +func TestSQLiteRegistry_List_Good_Empty(t *testing.T) { + reg := newTestSQLiteRegistry(t) + agents := reg.List() + assert.Empty(t, agents) +} + +func TestSQLiteRegistry_List_Good_Multiple(t *testing.T) { + reg := newTestSQLiteRegistry(t) + _ = reg.Register(AgentInfo{ID: "a", Name: "Alpha"}) + _ = reg.Register(AgentInfo{ID: "b", Name: "Beta"}) + _ = reg.Register(AgentInfo{ID: "c", Name: "Charlie"}) + + agents := reg.List() + assert.Len(t, agents, 3) + + // Sort by ID for deterministic assertion. + sort.Slice(agents, func(i, j int) bool { return agents[i].ID < agents[j].ID }) + assert.Equal(t, "a", agents[0].ID) + assert.Equal(t, "b", agents[1].ID) + assert.Equal(t, "c", agents[2].ID) +} + +// --- Heartbeat tests --- + +func TestSQLiteRegistry_Heartbeat_Good(t *testing.T) { + reg := newTestSQLiteRegistry(t) + past := time.Now().UTC().Add(-5 * time.Minute) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Status: AgentAvailable, + LastHeartbeat: past, + }) + + err := reg.Heartbeat("agent-1") + require.NoError(t, err) + + got, _ := reg.Get("agent-1") + assert.True(t, got.LastHeartbeat.After(past)) + assert.Equal(t, AgentAvailable, got.Status) +} + +func TestSQLiteRegistry_Heartbeat_Good_RecoverFromOffline(t *testing.T) { + reg := newTestSQLiteRegistry(t) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Status: AgentOffline, + }) + + err := reg.Heartbeat("agent-1") + require.NoError(t, err) + + got, _ := reg.Get("agent-1") + assert.Equal(t, AgentAvailable, got.Status) +} + +func TestSQLiteRegistry_Heartbeat_Good_BusyStaysBusy(t *testing.T) { + reg := newTestSQLiteRegistry(t) + _ = reg.Register(AgentInfo{ + ID: "agent-1", + Status: AgentBusy, + }) + + err := reg.Heartbeat("agent-1") + require.NoError(t, err) + + got, _ := reg.Get("agent-1") + assert.Equal(t, AgentBusy, got.Status) +} + +func TestSQLiteRegistry_Heartbeat_Bad_NotFound(t *testing.T) { + reg := newTestSQLiteRegistry(t) + err := reg.Heartbeat("nonexistent") + require.Error(t, err) + assert.Contains(t, err.Error(), "agent not found") +} + +// --- Reap tests --- + +func TestSQLiteRegistry_Reap_Good_StaleAgent(t *testing.T) { + reg := newTestSQLiteRegistry(t) + stale := time.Now().UTC().Add(-10 * time.Minute) + fresh := time.Now().UTC() + + _ = reg.Register(AgentInfo{ID: "stale-1", Status: AgentAvailable, LastHeartbeat: stale}) + _ = reg.Register(AgentInfo{ID: "fresh-1", Status: AgentAvailable, LastHeartbeat: fresh}) + + reaped := reg.Reap(5 * time.Minute) + assert.Len(t, reaped, 1) + assert.Contains(t, reaped, "stale-1") + + got, _ := reg.Get("stale-1") + assert.Equal(t, AgentOffline, got.Status) + + got, _ = reg.Get("fresh-1") + assert.Equal(t, AgentAvailable, got.Status) +} + +func TestSQLiteRegistry_Reap_Good_AlreadyOfflineSkipped(t *testing.T) { + reg := newTestSQLiteRegistry(t) + stale := time.Now().UTC().Add(-10 * time.Minute) + + _ = reg.Register(AgentInfo{ID: "already-off", Status: AgentOffline, LastHeartbeat: stale}) + + reaped := reg.Reap(5 * time.Minute) + assert.Empty(t, reaped) +} + +func TestSQLiteRegistry_Reap_Good_NoStaleAgents(t *testing.T) { + reg := newTestSQLiteRegistry(t) + now := time.Now().UTC() + + _ = reg.Register(AgentInfo{ID: "a", Status: AgentAvailable, LastHeartbeat: now}) + _ = reg.Register(AgentInfo{ID: "b", Status: AgentBusy, LastHeartbeat: now}) + + reaped := reg.Reap(5 * time.Minute) + assert.Empty(t, reaped) +} + +func TestSQLiteRegistry_Reap_Good_BusyAgentReaped(t *testing.T) { + reg := newTestSQLiteRegistry(t) + stale := time.Now().UTC().Add(-10 * time.Minute) + + _ = reg.Register(AgentInfo{ID: "busy-stale", Status: AgentBusy, LastHeartbeat: stale}) + + reaped := reg.Reap(5 * time.Minute) + assert.Len(t, reaped, 1) + assert.Contains(t, reaped, "busy-stale") + + got, _ := reg.Get("busy-stale") + assert.Equal(t, AgentOffline, got.Status) +} + +// --- Concurrent access --- + +func TestSQLiteRegistry_Concurrent_Good(t *testing.T) { + reg := newTestSQLiteRegistry(t) + + var wg sync.WaitGroup + for i := 0; i < 20; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + id := "agent-" + string(rune('a'+n%5)) + _ = reg.Register(AgentInfo{ + ID: id, + Name: "Concurrent", + Status: AgentAvailable, + LastHeartbeat: time.Now().UTC(), + }) + _, _ = reg.Get(id) + _ = reg.Heartbeat(id) + _ = reg.List() + _ = reg.Reap(1 * time.Minute) + }(i) + } + wg.Wait() + + // No race conditions — test passes under -race. + agents := reg.List() + assert.True(t, len(agents) > 0) +} + +// --- Persistence: close and reopen --- + +func TestSQLiteRegistry_Persistence_Good(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "registry.db") + + // Phase 1: write data + r1, err := NewSQLiteRegistry(dbPath) + require.NoError(t, err) + + now := time.Now().UTC().Truncate(time.Microsecond) + _ = r1.Register(AgentInfo{ + ID: "agent-1", + Name: "Persistent", + Capabilities: []string{"go", "rust"}, + Status: AgentBusy, + LastHeartbeat: now, + CurrentLoad: 3, + MaxLoad: 10, + }) + require.NoError(t, r1.Close()) + + // Phase 2: reopen and verify + r2, err := NewSQLiteRegistry(dbPath) + require.NoError(t, err) + defer func() { _ = r2.Close() }() + + got, err := r2.Get("agent-1") + require.NoError(t, err) + assert.Equal(t, "Persistent", got.Name) + assert.Equal(t, []string{"go", "rust"}, got.Capabilities) + assert.Equal(t, AgentBusy, got.Status) + assert.Equal(t, 3, got.CurrentLoad) + assert.Equal(t, 10, got.MaxLoad) + assert.WithinDuration(t, now, got.LastHeartbeat, time.Millisecond) +} + +// --- Constructor error case --- + +func TestNewSQLiteRegistry_Bad_InvalidPath(t *testing.T) { + _, err := NewSQLiteRegistry("/nonexistent/deeply/nested/dir/registry.db") + require.Error(t, err) +} + +// --- Config-based factory --- + +func TestNewAgentRegistryFromConfig_Good_Memory(t *testing.T) { + cfg := RegistryConfig{RegistryBackend: "memory"} + reg, err := NewAgentRegistryFromConfig(cfg) + require.NoError(t, err) + _, ok := reg.(*MemoryRegistry) + assert.True(t, ok, "expected MemoryRegistry") +} + +func TestNewAgentRegistryFromConfig_Good_Default(t *testing.T) { + cfg := RegistryConfig{} // empty defaults to memory + reg, err := NewAgentRegistryFromConfig(cfg) + require.NoError(t, err) + _, ok := reg.(*MemoryRegistry) + assert.True(t, ok, "expected MemoryRegistry for empty config") +} + +func TestNewAgentRegistryFromConfig_Good_SQLite(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "factory-registry.db") + cfg := RegistryConfig{ + RegistryBackend: "sqlite", + RegistryPath: dbPath, + } + reg, err := NewAgentRegistryFromConfig(cfg) + require.NoError(t, err) + sr, ok := reg.(*SQLiteRegistry) + assert.True(t, ok, "expected SQLiteRegistry") + _ = sr.Close() +} + +func TestNewAgentRegistryFromConfig_Bad_UnknownBackend(t *testing.T) { + cfg := RegistryConfig{RegistryBackend: "cassandra"} + _, err := NewAgentRegistryFromConfig(cfg) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported registry backend") +}