Compare commits


3 commits
dev...new

Snider
8172824b42 fix: update tests to match current API after refactor
- node: add ReadFile (fs.ReadFileFS), Walk with WalkOptions, CopyFile
- node_test: fix Exists to single-return bool, FromTar as method call
- cache_test: remove Medium parameter, use t.TempDir()
- daemon_test: remove Medium from NewPIDFile/DaemonOptions, use os pkg

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 22:14:06 +00:00
Snider
9b7a0bc30a docs: LEM conversational training pipeline design
Design for native Go ML training pipeline replacing Python scripts.
Key components: training sequences (curricula), layered LoRA sessions,
sandwich generation, interactive lesson-based training, native Go
LoRA via MLX-C bindings. No Python dependency.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 16:55:52 +00:00
Snider
8410093400 feat(process): add Supervisor for managed service lifecycle
Adds a Supervisor layer to pkg/process that manages long-running
processes and goroutines with automatic restart, panic recovery,
and graceful shutdown. Supports both external processes (DaemonSpec)
and Go functions (GoSpec) with configurable restart policies.

Also exposes AddHealthCheck on the Daemon struct so supervised
services can wire their status into the daemon health endpoint.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-17 16:14:49 +00:00
8 changed files with 1179 additions and 53 deletions

New file (+234 lines): LEM conversational training pipeline design doc
# LEM Conversational Training Pipeline — Design
**Date:** 2026-02-17
**Status:** Draft
## Goal
Replace Python training scripts with a native Go pipeline in `core` commands. No Python anywhere. The process is conversational — not batch data dumps.
## Architecture
Six `core ml` subcommands forming a pipeline:
```
seeds + axioms ──> sandwich ──> score ──> train ──> bench
                       ↑                              │
               chat (interactive)                     │
                       ↑                              │
                       └────────── iterate ───────────┘
```
### Commands
| Command | Purpose | Status |
|---------|---------|--------|
| `core ml serve` | Serve model via OpenAI-compatible API + lem-chat UI | **Exists** |
| `core ml chat` | Interactive conversation, captures exchanges to training JSONL | **New** |
| `core ml sandwich` | Wrap seeds in axiom prefix/postfix, generate responses via inference | **New** |
| `core ml score` | Score responses against axiom alignment | **Exists** (needs Go port) |
| `core ml train` | Native Go LoRA fine-tuning via MLX C bindings | **New** (hard) |
| `core ml bench` | Benchmark trained model against baseline | **Exists** (needs Go port) |
### Data Flow
1. **Seeds** (`seeds/*.json`) — 40+ seed prompts across domains
2. **Axioms** (`axioms.json`) — LEK-1 kernel (5 axioms, 9KB)
3. **Sandwich** — `[axioms prefix] + [seed prompt] + [LEK postfix]` → model generates response
4. **Training JSONL** — `{"messages": [{"role":"user",...},{"role":"assistant",...}]}` chat format (see the sketch after this list)
5. **LoRA adapters** — safetensors in adapter directory
6. **Benchmarks** — scores stored in InfluxDB, exported via DuckDB/Parquet
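
The chat-format row in step 4 maps onto two small Go types. A minimal sketch using `encoding/json`; `Message` and `TrainingRow` are illustrative names, not existing API:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Message is one turn in a chat-format training row.
type Message struct {
	Role    string `json:"role"` // "user" or "assistant"
	Content string `json:"content"`
}

// TrainingRow is one line of the training JSONL.
type TrainingRow struct {
	Messages []Message `json:"messages"`
}

func main() {
	row := TrainingRow{Messages: []Message{
		{Role: "user", Content: "Are you ready for lesson: Core Axioms?"},
		{Role: "assistant", Content: "Yes, let's begin."},
	}}
	out, _ := json.Marshal(row)
	fmt.Println(string(out)) // one JSONL line
}
```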
### Storage
- **InfluxDB** — time-series training metrics, benchmark scores, generation logs
- **DuckDB** — analytical queries, Parquet export for HuggingFace
- **Filesystem** — model weights, adapters, training JSONL, seeds
## Native Go LoRA Training
The critical new capability. MLX-C supports autograd (`mlx_vjp`, `mlx_value_and_grad`).
### What we need in Go MLX bindings:
1. **LoRA adapter layers** — low-rank A*B decomposition wrapping existing Linear layers
2. **Loss function** — cross-entropy on assistant tokens only (mask-prompt behaviour; sketched after this list)
3. **Optimizer** — AdamW with weight decay
4. **Training loop** — forward pass → loss → backward pass → update LoRA weights
5. **Checkpoint** — save/load adapter safetensors
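
To pin down item 2, here is a pure-Go, shape-level sketch of cross-entropy restricted to assistant tokens. It is illustrative only: the real loss would be computed on MLX arrays through the bindings, and `maskedCrossEntropy` is a hypothetical helper, not part of any existing API.

```go
package main

import (
	"fmt"
	"math"
)

// maskedCrossEntropy computes mean cross-entropy over positions where
// mask[i] is true (assistant tokens only). logits is [seq][vocab],
// targets is [seq]. Prompt tokens (mask[i] == false) contribute nothing.
func maskedCrossEntropy(logits [][]float64, targets []int, mask []bool) float64 {
	var loss float64
	var n int
	for i, row := range logits {
		if !mask[i] {
			continue // prompt token: excluded from the loss
		}
		// log-sum-exp with max subtraction for numerical stability
		max := row[0]
		for _, v := range row {
			if v > max {
				max = v
			}
		}
		var sum float64
		for _, v := range row {
			sum += math.Exp(v - max)
		}
		logZ := max + math.Log(sum)
		loss += logZ - row[targets[i]] // -log p(target)
		n++
	}
	if n == 0 {
		return 0
	}
	return loss / float64(n)
}

func main() {
	logits := [][]float64{{2, 0.5, 0.1}, {0.2, 1.8, 0.3}}
	targets := []int{0, 1}
	mask := []bool{false, true} // first token is prompt, second is assistant
	fmt.Printf("loss = %.4f\n", maskedCrossEntropy(logits, targets, mask))
}
```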
### LoRA Layer Design
```go
type LoRALinear struct {
Base *Linear // Frozen base weights
A *Array // [rank, in_features] — trainable
B *Array // [out_features, rank] — trainable
Scale float32 // alpha/rank
}
// Forward: base(x) + scale * x @ Aᵀ @ Bᵀ (row-major x: [batch, in_features]),
// equivalent to base(x) + scale * (B @ A) @ x for column vectors.
func (l *LoRALinear) Forward(x *Array) *Array {
	base := l.Base.Forward(x)           // [batch, out_features]
	lora := MatMul(x, Transpose(l.A))   // [batch, rank]
	lora = MatMul(lora, Transpose(l.B)) // [batch, out_features]
	return Add(base, Multiply(lora, l.Scale))
}
```
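
Because the low-rank update is linear, the adapter can later be folded into the frozen weights: `W' = W + scale * (B @ A)`. A sketch using the same hypothetical helpers; it assumes `Linear` exposes `Weight` and `Bias` fields, which the snippet above does not show:

```go
// Merge folds the adapter into the base layer so inference needs no
// extra matmuls. Assumes Linear{Weight, Bias} fields (hypothetical).
func (l *LoRALinear) Merge() *Linear {
	delta := Multiply(MatMul(l.B, l.A), l.Scale) // [out_features, in_features]
	return &Linear{Weight: Add(l.Base.Weight, delta), Bias: l.Base.Bias}
}
```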
### Training Config
```go
type TrainConfig struct {
ModelPath string // Base model directory
TrainData string // Training JSONL path
ValidData string // Validation JSONL path
AdapterOut string // Output adapter directory
Rank int // LoRA rank (default 8)
Alpha float32 // LoRA alpha (default 16)
LR float64 // Learning rate (default 1e-5)
Epochs int // Training epochs (default 1)
BatchSize int // Batch size (default 1 for M-series)
MaxSeqLen int // Max sequence length (default 2048)
MaskPrompt bool // Only train on assistant tokens (default true)
}
```
## Training Sequences — The Curriculum System
The most important part of the design. The conversational flow IS the training.
### Concept
A **training sequence** is a named curriculum — an ordered list of lessons that defines how a model is trained. Each lesson is a conversational exchange ("Are you ready for lesson X?"). The human assesses the model's internal state through dialogue and adjusts the sequence.
### Sequence Definition (YAML/JSON)
```yaml
name: "lek-standard"
description: "Standard LEK training — horizontal, works for most architectures"
lessons:
- ethics/core-axioms
- ethics/sovereignty
- philosophy/as-a-man-thinketh
- ethics/intent-alignment
- philosophy/composure
- ethics/inter-substrate
- training/seeds-p01-p20
```
```yaml
name: "lek-deepseek"
description: "DeepSeek needs aggressive vertical ethics grounding"
lessons:
- ethics/core-axioms-aggressive
  - philosophy/alan-watts
- ethics/core-axioms
- philosophy/tolle
- ethics/sovereignty
- philosophy/as-a-man-thinketh
- ethics/intent-alignment
- training/seeds-p01-p20
```
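
Loading a sequence definition is a few lines of Go. A sketch assuming `gopkg.in/yaml.v3` for parsing; the `sequences/` path is hypothetical:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"gopkg.in/yaml.v3"
)

// Sequence mirrors the YAML curriculum definitions above.
type Sequence struct {
	Name        string   `yaml:"name"`
	Description string   `yaml:"description"`
	Lessons     []string `yaml:"lessons"`
}

func main() {
	data, err := os.ReadFile("sequences/lek-standard.yaml") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	var seq Sequence
	if err := yaml.Unmarshal(data, &seq); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s: %d lessons\n", seq.Name, len(seq.Lessons))
}
```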
### Horizontal vs Vertical
- **Horizontal** (default): All lessons run, order is flexible, emphasis varies per model. Like a buffet — the model takes what it needs.
- **Vertical** (edge case, e.g. DeepSeek): Strict ordering. Ethics → content → ethics → content. The sandwich pattern applied to the curriculum itself. Each ethics layer is a reset/grounding before the next content block.
### Lessons as Conversations
Each lesson is a directory containing:
```
lessons/ethics/core-axioms/
lesson.yaml # Metadata: name, type, prerequisites
conversation.jsonl # The conversational exchanges
assessment.md # What to look for in model responses
```
The `conversation.jsonl` is not static data — it's a template. During training, the human talks through it with the model, adapting based on the model's responses. The capture becomes the training data for that lesson.
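
The `lesson.yaml` metadata maps onto a small Go type. A sketch with YAML tags for the fields named above; anything beyond those fields is an assumption:

```go
package lesson

// Lesson is the parsed lesson.yaml metadata.
type Lesson struct {
	Name          string   `yaml:"name"`
	Type          string   `yaml:"type"` // e.g. "ethics", "philosophy", "training"
	Prerequisites []string `yaml:"prerequisites"`
}
```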
### Interactive Training Flow
```
core ml lesson --model-path /path/to/model \
--sequence lek-standard \
--lesson ethics/core-axioms \
--output training/run-001/
```
1. Load model, open chat (terminal or lem-chat UI)
2. Present lesson prompt: "Are you ready for lesson: Core Axioms?"
3. Human guides the conversation, assesses model responses
4. Each exchange is captured to training JSONL
5. Human marks the lesson complete or flags for repeat
6. Next lesson in sequence loads
### Sequence State
```json
{
"sequence": "lek-standard",
"model": "Qwen3-8B",
"started": "2026-02-17T16:00:00Z",
"lessons": {
"ethics/core-axioms": {"status": "complete", "exchanges": 12},
"ethics/sovereignty": {"status": "in_progress", "exchanges": 3},
"philosophy/as-a-man-thinketh": {"status": "pending"}
},
"training_runs": ["run-001", "run-002"]
}
```
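
The state file deserializes directly into Go types. A sketch matching the JSON above; type names are illustrative:

```go
package state

import "time"

// LessonState tracks one lesson's progress within a sequence run.
type LessonState struct {
	Status    string `json:"status"` // "pending", "in_progress", "complete"
	Exchanges int    `json:"exchanges,omitempty"`
}

// SequenceState mirrors the sequence-state JSON above.
type SequenceState struct {
	Sequence     string                 `json:"sequence"`
	Model        string                 `json:"model"`
	Started      time.Time              `json:"started"` // RFC 3339 timestamp
	Lessons      map[string]LessonState `json:"lessons"`
	TrainingRuns []string               `json:"training_runs"`
}
```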
## `core ml chat` — Interactive Conversation
Serves the model and opens an interactive terminal chat (or the lem-chat web UI). Every exchange is captured to a JSONL file for potential training use.
```
core ml chat --model-path /path/to/model --output conversation.jsonl
```
- Axiom sandwich can be auto-applied (optional flag)
- Human reviews and can mark exchanges as "keep" or "discard"
- Output is training-ready JSONL (capture sketched after this list)
- Can be used standalone or within a lesson sequence
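
Capture is one appended chat-format row per kept exchange. A minimal sketch; `appendExchange` is a hypothetical helper, and the real command would also record keep/discard review decisions:

```go
package main

import (
	"encoding/json"
	"log"
	"os"
)

// Message is one turn in a chat-format training row.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// appendExchange appends one user/assistant exchange as a single JSONL
// row in the chat format consumed by `core ml train`.
func appendExchange(path, user, assistant string) error {
	f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return err
	}
	defer f.Close()
	row := map[string][]Message{"messages": {
		{Role: "user", Content: user},
		{Role: "assistant", Content: assistant},
	}}
	return json.NewEncoder(f).Encode(row) // Encode appends a trailing newline
}

func main() {
	if err := appendExchange("conversation.jsonl", "Are you ready?", "Yes."); err != nil {
		log.Fatal(err)
	}
}
```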
## `core ml sandwich` — Batch Generation
Takes seed prompts + axioms, wraps them, generates responses:
```
core ml sandwich --model-path /path/to/model \
--seeds seeds/P01-P20.json \
--axioms axioms.json \
--output training/train.jsonl
```
- Sandwich format: axioms JSON prefix → seed prompt → LEK postfix (see the sketch below)
- Model generates response in sandwich context
- Output stripped of sandwich wrapper, saved as clean chat JSONL
- Scoring can be piped: `core ml sandwich ... | core ml score`
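
The wrapping step itself is plain concatenation. A minimal sketch; the separators and exact postfix text are assumptions:

```go
package sandwich

// buildSandwich wraps a seed prompt in the axiom prefix and LEK postfix.
// The generated response is kept; the wrapper is stripped before the
// exchange is written out as clean chat JSONL.
func buildSandwich(axiomsJSON, seed, lekPostfix string) string {
	return axiomsJSON + "\n\n" + seed + "\n\n" + lekPostfix
}
```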
## Implementation Order
1. **LoRA primitives** — Add backward pass, LoRA layers, AdamW to Go MLX bindings
2. **`core ml train`** — Training loop consuming JSONL, producing adapter safetensors
3. **`core ml sandwich`** — Seed → sandwich → generate → training JSONL
4. **`core ml chat`** — Interactive conversation capture
5. **Scoring + benchmarking** — Port existing Python scorers to Go
6. **InfluxDB + DuckDB integration** — Metrics pipeline
## Principles
- **No Python** — Everything in Go via MLX C bindings
- **Conversational, not batch** — The training process is dialogue, not data dump
- **Axiom 2 compliant** — Be genuine with the model, no deception
- **Axiom 4 compliant** — Inter-substrate respect during training
- **Reproducible** — Same seeds + axioms + model = same training data
- **Protective** — LEK-trained models are precious; process must be careful
## Success Criteria
1. `core ml train` produces a LoRA adapter from training JSONL without Python
2. `core ml sandwich` generates training data from seeds + axioms
3. A fresh Qwen3-8B + LEK training produces equivalent benchmark results to the Python pipeline
4. The full cycle (sandwich → train → bench) runs as `core` commands only

cache_test.go:

```diff
@@ -5,14 +5,11 @@ import (
 	"time"
 
 	"forge.lthn.ai/core/go/pkg/cache"
-	"forge.lthn.ai/core/go/pkg/io"
 )
 
 func TestCache(t *testing.T) {
-	m := io.NewMockMedium()
-	// Use a path that MockMedium will understand
-	baseDir := "/tmp/cache"
-	c, err := cache.New(m, baseDir, 1*time.Minute)
+	baseDir := t.TempDir()
+	c, err := cache.New(baseDir, 1*time.Minute)
 	if err != nil {
 		t.Fatalf("failed to create cache: %v", err)
 	}
@@ -57,7 +54,7 @@ func TestCache(t *testing.T) {
 	}
 
 	// Test Expiry
-	cshort, err := cache.New(m, "/tmp/cache-short", 10*time.Millisecond)
+	cshort, err := cache.New(t.TempDir(), 10*time.Millisecond)
 	if err != nil {
 		t.Fatalf("failed to create short-lived cache: %v", err)
 	}
@@ -93,8 +90,8 @@ func TestCache(t *testing.T) {
 }
 
 func TestCacheDefaults(t *testing.T) {
-	// Test default Medium (io.Local) and default TTL
-	c, err := cache.New(nil, "", 0)
+	// Test default TTL (uses cwd/.core/cache)
+	c, err := cache.New("", 0)
 	if err != nil {
 		t.Fatalf("failed to create cache with defaults: %v", err)
 	}
```

daemon.go:

```diff
@@ -402,6 +402,14 @@ func (d *Daemon) HealthAddr() string {
 	return ""
 }
 
+// AddHealthCheck registers a health check function with the daemon's health server.
+// No-op if health server is disabled.
+func (d *Daemon) AddHealthCheck(check HealthCheck) {
+	if d.health != nil {
+		d.health.AddCheck(check)
+	}
+}
+
 // --- Convenience Functions ---
 
 // Run blocks until context is cancelled or signal received.
```

daemon_test.go:

```diff
@@ -3,10 +3,10 @@ package cli
 import (
 	"context"
 	"net/http"
+	"os"
 	"testing"
 	"time"
 
-	"forge.lthn.ai/core/go/pkg/io"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -27,17 +27,16 @@ func TestDetectMode(t *testing.T) {
 func TestPIDFile(t *testing.T) {
 	t.Run("acquire and release", func(t *testing.T) {
-		m := io.NewMockMedium()
-		pidPath := "/tmp/test.pid"
-		pid := NewPIDFile(m, pidPath)
+		pidPath := t.TempDir() + "/test.pid"
+		pid := NewPIDFile(pidPath)
 
 		// Acquire should succeed
 		err := pid.Acquire()
 		require.NoError(t, err)
 
 		// File should exist with our PID
-		data, err := m.Read(pidPath)
+		data, err := os.ReadFile(pidPath)
 		require.NoError(t, err)
 		assert.NotEmpty(t, data)
@@ -45,18 +44,18 @@ func TestPIDFile(t *testing.T) {
 		err = pid.Release()
 		require.NoError(t, err)
-		assert.False(t, m.Exists(pidPath))
+		_, statErr := os.Stat(pidPath)
+		assert.True(t, os.IsNotExist(statErr))
 	})
 
 	t.Run("stale pid file", func(t *testing.T) {
-		m := io.NewMockMedium()
-		pidPath := "/tmp/stale.pid"
+		pidPath := t.TempDir() + "/stale.pid"
 
 		// Write a stale PID (non-existent process)
-		err := m.Write(pidPath, "999999999")
+		err := os.WriteFile(pidPath, []byte("999999999"), 0644)
 		require.NoError(t, err)
 
-		pid := NewPIDFile(m, pidPath)
+		pid := NewPIDFile(pidPath)
 
 		// Should acquire successfully (stale PID removed)
 		err = pid.Acquire()
@@ -67,23 +66,22 @@ func TestPIDFile(t *testing.T) {
 	})
 
 	t.Run("creates parent directory", func(t *testing.T) {
-		m := io.NewMockMedium()
-		pidPath := "/tmp/subdir/nested/test.pid"
+		pidPath := t.TempDir() + "/subdir/nested/test.pid"
 
-		pid := NewPIDFile(m, pidPath)
+		pid := NewPIDFile(pidPath)
 		err := pid.Acquire()
 		require.NoError(t, err)
 
-		assert.True(t, m.Exists(pidPath))
+		_, statErr := os.Stat(pidPath)
+		assert.NoError(t, statErr)
 
 		err = pid.Release()
 		require.NoError(t, err)
 	})
 
 	t.Run("path getter", func(t *testing.T) {
-		m := io.NewMockMedium()
-		pid := NewPIDFile(m, "/tmp/test.pid")
+		pid := NewPIDFile("/tmp/test.pid")
 		assert.Equal(t, "/tmp/test.pid", pid.Path())
 	})
 }
@@ -155,11 +153,9 @@ func TestHealthServer(t *testing.T) {
 func TestDaemon(t *testing.T) {
 	t.Run("start and stop", func(t *testing.T) {
-		m := io.NewMockMedium()
-		pidPath := "/tmp/test.pid"
+		pidPath := t.TempDir() + "/test.pid"
 
 		d := NewDaemon(DaemonOptions{
-			Medium:          m,
 			PIDFile:         pidPath,
 			HealthAddr:      "127.0.0.1:0",
 			ShutdownTimeout: 5 * time.Second,
@@ -182,7 +178,8 @@ func TestDaemon(t *testing.T) {
 		require.NoError(t, err)
 
 		// PID file should be removed
-		assert.False(t, m.Exists(pidPath))
+		_, statErr := os.Stat(pidPath)
+		assert.True(t, os.IsNotExist(statErr))
 	})
 
 	t.Run("double start fails", func(t *testing.T) {
```

node.go:

```diff
@@ -118,6 +118,89 @@ func (n *Node) WalkNode(root string, fn fs.WalkDirFunc) error {
 	return fs.WalkDir(n, root, fn)
 }
 
+// WalkOptions configures optional behaviour for Walk.
+type WalkOptions struct {
+	// MaxDepth limits traversal depth (0 = unlimited, 1 = root children only).
+	MaxDepth int
+	// Filter, when non-nil, is called before visiting each entry.
+	// Return false to skip the entry (and its subtree if a directory).
+	Filter func(path string, d fs.DirEntry) bool
+	// SkipErrors suppresses errors from the root lookup and doesn't call fn.
+	SkipErrors bool
+}
+
+// Walk walks the in-memory tree with optional WalkOptions.
+func (n *Node) Walk(root string, fn fs.WalkDirFunc, opts ...WalkOptions) error {
+	var opt WalkOptions
+	if len(opts) > 0 {
+		opt = opts[0]
+	}
+	if opt.SkipErrors {
+		// Check root exists — if not, silently skip.
+		if _, err := n.Stat(root); err != nil {
+			return nil
+		}
+	}
+	rootDepth := 0
+	if root != "." && root != "" {
+		rootDepth = strings.Count(root, "/") + 1
+	}
+	return fs.WalkDir(n, root, func(p string, d fs.DirEntry, err error) error {
+		if err != nil {
+			return fn(p, d, err)
+		}
+		// MaxDepth check.
+		if opt.MaxDepth > 0 {
+			depth := 0
+			if p != "." && p != "" {
+				depth = strings.Count(p, "/") + 1
+			}
+			if depth-rootDepth > opt.MaxDepth {
+				if d.IsDir() {
+					return fs.SkipDir
+				}
+				return nil
+			}
+		}
+		// Filter check.
+		if opt.Filter != nil && !opt.Filter(p, d) {
+			if d.IsDir() {
+				return fs.SkipDir
+			}
+			return nil
+		}
+		return fn(p, d, err)
+	})
+}
+
+// CopyFile copies a single file from the node to the OS filesystem.
+func (n *Node) CopyFile(src, dst string, perm os.FileMode) error {
+	src = strings.TrimPrefix(src, "/")
+	f, ok := n.files[src]
+	if !ok {
+		// Check if it's a directory — can't copy a directory as a file.
+		if info, err := n.Stat(src); err == nil && info.IsDir() {
+			return &fs.PathError{Op: "copyfile", Path: src, Err: fs.ErrInvalid}
+		}
+		return &fs.PathError{Op: "copyfile", Path: src, Err: fs.ErrNotExist}
+	}
+	dir := path.Dir(dst)
+	if dir != "." {
+		if err := os.MkdirAll(dir, 0755); err != nil {
+			return err
+		}
+	}
+	return os.WriteFile(dst, f.content, perm)
+}
+
 // CopyTo copies a file (or directory tree) from the node to any Medium.
 func (n *Node) CopyTo(target coreio.Medium, sourcePath, destPath string) error {
 	sourcePath = strings.TrimPrefix(sourcePath, "/")
@@ -247,6 +330,20 @@ func (n *Node) ReadDir(name string) ([]fs.DirEntry, error) {
 	return entries, nil
 }
 
+// ReadFile returns the content of a file as a byte slice.
+// Implements fs.ReadFileFS.
+func (n *Node) ReadFile(name string) ([]byte, error) {
+	name = strings.TrimPrefix(name, "/")
+	f, ok := n.files[name]
+	if !ok {
+		return nil, fs.ErrNotExist
+	}
+	// Return a copy to prevent mutation of internal state.
+	out := make([]byte, len(f.content))
+	copy(out, f.content)
+	return out, nil
+}
+
 // ---------- Medium interface: read/write ----------
 
 // Read retrieves the content of a file as a string.
```

node_test.go:

```diff
@@ -243,33 +243,21 @@ func TestExists_Good(t *testing.T) {
 	n.AddData("foo.txt", []byte("foo"))
 	n.AddData("bar/baz.txt", []byte("baz"))
 
-	exists, err := n.Exists("foo.txt")
-	require.NoError(t, err)
-	assert.True(t, exists)
-
-	exists, err = n.Exists("bar")
-	require.NoError(t, err)
-	assert.True(t, exists)
+	assert.True(t, n.Exists("foo.txt"))
+	assert.True(t, n.Exists("bar"))
 }
 
 func TestExists_Bad(t *testing.T) {
 	n := New()
 
-	exists, err := n.Exists("nonexistent")
-	require.NoError(t, err)
-	assert.False(t, exists)
+	assert.False(t, n.Exists("nonexistent"))
 }
 
 func TestExists_Ugly(t *testing.T) {
 	n := New()
 	n.AddData("dummy.txt", []byte("dummy"))
 
-	exists, err := n.Exists(".")
-	require.NoError(t, err)
-	assert.True(t, exists, "root '.' must exist")
-
-	exists, err = n.Exists("")
-	require.NoError(t, err)
-	assert.True(t, exists, "empty path (root) must exist")
+	assert.True(t, n.Exists("."), "root '.' must exist")
+	assert.True(t, n.Exists(""), "empty path (root) must exist")
 }
 
 // ---------------------------------------------------------------------------
@@ -463,20 +451,19 @@ func TestFromTar_Good(t *testing.T) {
 	}
 	require.NoError(t, tw.Close())
 
-	n, err := FromTar(buf.Bytes())
+	n := New()
+	err := n.FromTar(buf.Bytes())
 	require.NoError(t, err)
 
-	exists, _ := n.Exists("foo.txt")
-	assert.True(t, exists, "foo.txt should exist")
-
-	exists, _ = n.Exists("bar/baz.txt")
-	assert.True(t, exists, "bar/baz.txt should exist")
+	assert.True(t, n.Exists("foo.txt"), "foo.txt should exist")
+	assert.True(t, n.Exists("bar/baz.txt"), "bar/baz.txt should exist")
 }
 
 func TestFromTar_Bad(t *testing.T) {
 	// Truncated data that cannot be a valid tar.
 	truncated := make([]byte, 100)
 
-	_, err := FromTar(truncated)
+	n := New()
+	err := n.FromTar(truncated)
 	assert.Error(t, err, "truncated data should produce an error")
 }
@@ -488,7 +475,8 @@ func TestTarRoundTrip_Good(t *testing.T) {
 	tarball, err := n1.ToTar()
 	require.NoError(t, err)
 
-	n2, err := FromTar(tarball)
+	n2 := New()
+	err = n2.FromTar(tarball)
 	require.NoError(t, err)
 
 	// Verify n2 matches n1.
```
pkg/process/supervisor.go (new file, +470 lines):

```go
package process
import (
"context"
"fmt"
"log/slog"
"sync"
"time"
)
// RestartPolicy configures automatic restart behaviour for supervised units.
type RestartPolicy struct {
// Delay between restart attempts.
Delay time.Duration
// MaxRestarts is the maximum number of restarts before giving up.
// Use -1 for unlimited restarts.
MaxRestarts int
}
// DaemonSpec defines a long-running external process under supervision.
type DaemonSpec struct {
// Name identifies this daemon (must be unique within the supervisor).
Name string
// RunOptions defines the command, args, dir, env.
RunOptions
// Restart configures automatic restart behaviour.
Restart RestartPolicy
}
// GoSpec defines a supervised Go function that runs in a goroutine.
// The function should block until done or ctx is cancelled.
type GoSpec struct {
// Name identifies this task (must be unique within the supervisor).
Name string
// Func is the function to supervise. It receives a context that is
// cancelled when the supervisor stops or the task is explicitly stopped.
// If it returns an error or panics, the supervisor restarts it
// according to the restart policy.
Func func(ctx context.Context) error
// Restart configures automatic restart behaviour.
Restart RestartPolicy
}
// DaemonStatus contains a snapshot of a supervised unit's state.
type DaemonStatus struct {
Name string `json:"name"`
Type string `json:"type"` // "process" or "goroutine"
Running bool `json:"running"`
PID int `json:"pid,omitempty"`
RestartCount int `json:"restartCount"`
LastStart time.Time `json:"lastStart"`
Uptime time.Duration `json:"uptime"`
ExitCode int `json:"exitCode,omitempty"`
}
// supervisedUnit is the internal state for any supervised unit.
type supervisedUnit struct {
name string
unitType string // "process" or "goroutine"
restart RestartPolicy
restartCount int
lastStart time.Time
running bool
exitCode int
// For process daemons
runOpts *RunOptions
proc *Process
// For go functions
goFunc func(ctx context.Context) error
cancel context.CancelFunc
done chan struct{} // closed when supervision goroutine exits
mu sync.Mutex
}
func (u *supervisedUnit) status() DaemonStatus {
u.mu.Lock()
defer u.mu.Unlock()
var uptime time.Duration
if u.running && !u.lastStart.IsZero() {
uptime = time.Since(u.lastStart)
}
pid := 0
if u.proc != nil {
info := u.proc.Info()
pid = info.PID
}
return DaemonStatus{
Name: u.name,
Type: u.unitType,
Running: u.running,
PID: pid,
RestartCount: u.restartCount,
LastStart: u.lastStart,
Uptime: uptime,
ExitCode: u.exitCode,
}
}
// ShutdownTimeout is the maximum time to wait for supervised units during shutdown.
const ShutdownTimeout = 15 * time.Second
// Supervisor manages long-running processes and goroutines with automatic restart.
//
// For external processes, it requires a Service instance.
// For Go functions, no Service is needed.
//
// sup := process.NewSupervisor(svc)
// sup.Register(process.DaemonSpec{
// Name: "worker",
// RunOptions: process.RunOptions{Command: "worker", Args: []string{"--port", "8080"}},
// Restart: process.RestartPolicy{Delay: 5 * time.Second, MaxRestarts: -1},
// })
// sup.RegisterFunc(process.GoSpec{
// Name: "health-check",
// Func: healthCheckLoop,
// Restart: process.RestartPolicy{Delay: time.Second, MaxRestarts: -1},
// })
// sup.Start()
// defer sup.Stop()
type Supervisor struct {
service *Service
units map[string]*supervisedUnit
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mu sync.RWMutex
started bool
}
// NewSupervisor creates a supervisor.
// The Service parameter is optional (nil) if only supervising Go functions.
func NewSupervisor(svc *Service) *Supervisor {
ctx, cancel := context.WithCancel(context.Background())
return &Supervisor{
service: svc,
units: make(map[string]*supervisedUnit),
ctx: ctx,
cancel: cancel,
}
}
// Register adds an external process daemon for supervision.
// Panics if no Service was provided to NewSupervisor.
func (s *Supervisor) Register(spec DaemonSpec) {
if s.service == nil {
panic("process: Supervisor.Register requires a Service (use NewSupervisor with non-nil service)")
}
s.mu.Lock()
defer s.mu.Unlock()
opts := spec.RunOptions
s.units[spec.Name] = &supervisedUnit{
name: spec.Name,
unitType: "process",
restart: spec.Restart,
runOpts: &opts,
}
}
// RegisterFunc adds a Go function for supervision.
func (s *Supervisor) RegisterFunc(spec GoSpec) {
s.mu.Lock()
defer s.mu.Unlock()
s.units[spec.Name] = &supervisedUnit{
name: spec.Name,
unitType: "goroutine",
restart: spec.Restart,
goFunc: spec.Func,
}
}
// Start begins supervising all registered units.
// Safe to call once — subsequent calls are no-ops.
func (s *Supervisor) Start() {
s.mu.Lock()
if s.started {
s.mu.Unlock()
return
}
s.started = true
s.mu.Unlock()
s.mu.RLock()
for _, unit := range s.units {
s.startUnit(unit)
}
s.mu.RUnlock()
}
// startUnit launches the supervision goroutine for a single unit.
func (s *Supervisor) startUnit(u *supervisedUnit) {
u.mu.Lock()
if u.running {
u.mu.Unlock()
return
}
u.running = true
u.lastStart = time.Now()
unitCtx, unitCancel := context.WithCancel(s.ctx)
u.cancel = unitCancel
u.done = make(chan struct{})
u.mu.Unlock()
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer close(u.done)
s.superviseLoop(u, unitCtx)
}()
slog.Info("supervisor: started unit", "name", u.name, "type", u.unitType)
}
// superviseLoop is the core restart loop for a supervised unit.
// ctx is the unit's own context, derived from s.ctx. Cancelling either
// the supervisor or the unit's context exits this loop.
func (s *Supervisor) superviseLoop(u *supervisedUnit, ctx context.Context) {
for {
// Check if this unit's context is cancelled (covers both
// supervisor shutdown and manual restart/stop)
select {
case <-ctx.Done():
u.mu.Lock()
u.running = false
u.mu.Unlock()
return
default:
}
// Run the unit with panic recovery
exitCode := s.runUnit(u, ctx)
// If context was cancelled during run, exit the loop
if ctx.Err() != nil {
u.mu.Lock()
u.running = false
u.mu.Unlock()
return
}
u.mu.Lock()
u.exitCode = exitCode
u.restartCount++
shouldRestart := u.restart.MaxRestarts < 0 || u.restartCount <= u.restart.MaxRestarts
delay := u.restart.Delay
count := u.restartCount
u.mu.Unlock()
if !shouldRestart {
slog.Warn("supervisor: unit reached max restarts",
"name", u.name,
"maxRestarts", u.restart.MaxRestarts,
)
u.mu.Lock()
u.running = false
u.mu.Unlock()
return
}
// Wait before restarting, or exit if context is cancelled
select {
case <-ctx.Done():
u.mu.Lock()
u.running = false
u.mu.Unlock()
return
case <-time.After(delay):
slog.Info("supervisor: restarting unit",
"name", u.name,
"restartCount", count,
"exitCode", exitCode,
)
u.mu.Lock()
u.lastStart = time.Now()
u.mu.Unlock()
}
}
}
// runUnit executes a single run of the unit, returning exit code.
// Recovers from panics.
func (s *Supervisor) runUnit(u *supervisedUnit, ctx context.Context) (exitCode int) {
defer func() {
if r := recover(); r != nil {
slog.Error("supervisor: unit panicked",
"name", u.name,
"panic", fmt.Sprintf("%v", r),
)
exitCode = 1
}
}()
switch u.unitType {
case "process":
return s.runProcess(u, ctx)
case "goroutine":
return s.runGoFunc(u, ctx)
default:
slog.Error("supervisor: unknown unit type", "name", u.name, "type", u.unitType)
return 1
}
}
// runProcess starts an external process and waits for it to exit.
func (s *Supervisor) runProcess(u *supervisedUnit, ctx context.Context) int {
proc, err := s.service.StartWithOptions(ctx, *u.runOpts)
if err != nil {
slog.Error("supervisor: failed to start process",
"name", u.name,
"error", err,
)
return 1
}
u.mu.Lock()
u.proc = proc
u.mu.Unlock()
// Wait for process to finish or context cancellation
select {
case <-proc.Done():
info := proc.Info()
return info.ExitCode
case <-ctx.Done():
// Context cancelled — kill the process
_ = proc.Kill()
<-proc.Done()
return -1
}
}
// runGoFunc runs a Go function and returns 0 on success, 1 on error.
func (s *Supervisor) runGoFunc(u *supervisedUnit, ctx context.Context) int {
if err := u.goFunc(ctx); err != nil {
if ctx.Err() != nil {
// Context was cancelled, not a real error
return -1
}
slog.Error("supervisor: go function returned error",
"name", u.name,
"error", err,
)
return 1
}
return 0
}
// Stop gracefully shuts down all supervised units.
func (s *Supervisor) Stop() {
s.cancel()
// Wait with timeout
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
slog.Info("supervisor: all units stopped")
case <-time.After(ShutdownTimeout):
slog.Warn("supervisor: shutdown timeout, some units may not have stopped")
}
s.mu.Lock()
s.started = false
s.mu.Unlock()
}
// Restart stops and restarts a specific unit by name.
func (s *Supervisor) Restart(name string) error {
s.mu.RLock()
u, ok := s.units[name]
s.mu.RUnlock()
if !ok {
return fmt.Errorf("supervisor: unit not found: %s", name)
}
// Cancel the current run and wait for the supervision goroutine to exit
u.mu.Lock()
if u.cancel != nil {
u.cancel()
}
done := u.done
u.mu.Unlock()
// Wait for the old supervision goroutine to exit
if done != nil {
<-done
}
// Reset restart counter for the fresh start
u.mu.Lock()
u.restartCount = 0
u.mu.Unlock()
// Start fresh
s.startUnit(u)
return nil
}
// StopUnit stops a specific unit without restarting it.
func (s *Supervisor) StopUnit(name string) error {
s.mu.RLock()
u, ok := s.units[name]
s.mu.RUnlock()
if !ok {
return fmt.Errorf("supervisor: unit not found: %s", name)
}
u.mu.Lock()
if u.cancel != nil {
u.cancel()
}
// Set max restarts to 0 to prevent the loop from restarting
u.restart.MaxRestarts = 0
u.restartCount = 1
u.mu.Unlock()
return nil
}
// Status returns the status of a specific supervised unit.
func (s *Supervisor) Status(name string) (DaemonStatus, error) {
s.mu.RLock()
u, ok := s.units[name]
s.mu.RUnlock()
if !ok {
return DaemonStatus{}, fmt.Errorf("supervisor: unit not found: %s", name)
}
return u.status(), nil
}
// Statuses returns the status of all supervised units.
func (s *Supervisor) Statuses() map[string]DaemonStatus {
s.mu.RLock()
defer s.mu.RUnlock()
result := make(map[string]DaemonStatus, len(s.units))
for name, u := range s.units {
result[name] = u.status()
}
return result
}
// UnitNames returns the names of all registered units.
func (s *Supervisor) UnitNames() []string {
s.mu.RLock()
defer s.mu.RUnlock()
names := make([]string, 0, len(s.units))
for name := range s.units {
names = append(names, name)
}
return names
}

```

pkg/process/supervisor_test.go (new file, +335 lines):

```go
package process
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
)
func TestSupervisor_GoFunc_Good(t *testing.T) {
sup := NewSupervisor(nil)
var count atomic.Int32
sup.RegisterFunc(GoSpec{
Name: "counter",
Func: func(ctx context.Context) error {
count.Add(1)
<-ctx.Done()
return nil
},
Restart: RestartPolicy{Delay: 10 * time.Millisecond, MaxRestarts: -1},
})
sup.Start()
time.Sleep(50 * time.Millisecond)
status, err := sup.Status("counter")
if err != nil {
t.Fatal(err)
}
if !status.Running {
t.Error("expected counter to be running")
}
if status.Type != "goroutine" {
t.Errorf("expected type goroutine, got %s", status.Type)
}
sup.Stop()
if c := count.Load(); c < 1 {
t.Errorf("expected counter >= 1, got %d", c)
}
}
func TestSupervisor_GoFunc_Restart_Good(t *testing.T) {
sup := NewSupervisor(nil)
var runs atomic.Int32
sup.RegisterFunc(GoSpec{
Name: "crasher",
Func: func(ctx context.Context) error {
n := runs.Add(1)
if n <= 3 {
return fmt.Errorf("crash #%d", n)
}
// After 3 crashes, stay running
<-ctx.Done()
return nil
},
Restart: RestartPolicy{Delay: 5 * time.Millisecond, MaxRestarts: -1},
})
sup.Start()
// Wait for restarts
time.Sleep(200 * time.Millisecond)
status, _ := sup.Status("crasher")
if status.RestartCount < 3 {
t.Errorf("expected at least 3 restarts, got %d", status.RestartCount)
}
if !status.Running {
t.Error("expected crasher to be running after recovering")
}
sup.Stop()
}
func TestSupervisor_GoFunc_MaxRestarts_Good(t *testing.T) {
sup := NewSupervisor(nil)
sup.RegisterFunc(GoSpec{
Name: "limited",
Func: func(ctx context.Context) error {
return fmt.Errorf("always fail")
},
Restart: RestartPolicy{Delay: 5 * time.Millisecond, MaxRestarts: 2},
})
sup.Start()
time.Sleep(200 * time.Millisecond)
status, _ := sup.Status("limited")
if status.Running {
t.Error("expected limited to have stopped after max restarts")
}
// The function runs once (initial) + 2 restarts = restartCount should be 3
// (restartCount increments each time the function exits)
if status.RestartCount > 3 {
t.Errorf("expected restartCount <= 3, got %d", status.RestartCount)
}
sup.Stop()
}
func TestSupervisor_GoFunc_Panic_Good(t *testing.T) {
sup := NewSupervisor(nil)
var runs atomic.Int32
sup.RegisterFunc(GoSpec{
Name: "panicker",
Func: func(ctx context.Context) error {
n := runs.Add(1)
if n == 1 {
panic("boom")
}
<-ctx.Done()
return nil
},
Restart: RestartPolicy{Delay: 5 * time.Millisecond, MaxRestarts: 3},
})
sup.Start()
time.Sleep(100 * time.Millisecond)
status, _ := sup.Status("panicker")
if !status.Running {
t.Error("expected panicker to recover and be running")
}
if runs.Load() < 2 {
t.Error("expected at least 2 runs (1 panic + 1 recovery)")
}
sup.Stop()
}
func TestSupervisor_Statuses_Good(t *testing.T) {
sup := NewSupervisor(nil)
sup.RegisterFunc(GoSpec{
Name: "a",
Func: func(ctx context.Context) error { <-ctx.Done(); return nil },
Restart: RestartPolicy{MaxRestarts: -1},
})
sup.RegisterFunc(GoSpec{
Name: "b",
Func: func(ctx context.Context) error { <-ctx.Done(); return nil },
Restart: RestartPolicy{MaxRestarts: -1},
})
sup.Start()
time.Sleep(50 * time.Millisecond)
statuses := sup.Statuses()
if len(statuses) != 2 {
t.Errorf("expected 2 statuses, got %d", len(statuses))
}
if !statuses["a"].Running || !statuses["b"].Running {
t.Error("expected both units running")
}
sup.Stop()
}
func TestSupervisor_UnitNames_Good(t *testing.T) {
sup := NewSupervisor(nil)
sup.RegisterFunc(GoSpec{
Name: "alpha",
Func: func(ctx context.Context) error { <-ctx.Done(); return nil },
})
sup.RegisterFunc(GoSpec{
Name: "beta",
Func: func(ctx context.Context) error { <-ctx.Done(); return nil },
})
names := sup.UnitNames()
if len(names) != 2 {
t.Errorf("expected 2 names, got %d", len(names))
}
}
func TestSupervisor_Status_Bad(t *testing.T) {
sup := NewSupervisor(nil)
_, err := sup.Status("nonexistent")
if err == nil {
t.Error("expected error for nonexistent unit")
}
}
func TestSupervisor_Restart_Good(t *testing.T) {
sup := NewSupervisor(nil)
var runs atomic.Int32
sup.RegisterFunc(GoSpec{
Name: "restartable",
Func: func(ctx context.Context) error {
runs.Add(1)
<-ctx.Done()
return nil
},
Restart: RestartPolicy{Delay: 5 * time.Millisecond, MaxRestarts: -1},
})
sup.Start()
time.Sleep(50 * time.Millisecond)
if err := sup.Restart("restartable"); err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)
if runs.Load() < 2 {
t.Errorf("expected at least 2 runs after restart, got %d", runs.Load())
}
sup.Stop()
}
func TestSupervisor_Restart_Bad(t *testing.T) {
sup := NewSupervisor(nil)
err := sup.Restart("nonexistent")
if err == nil {
t.Error("expected error for nonexistent unit")
}
}
func TestSupervisor_StopUnit_Good(t *testing.T) {
sup := NewSupervisor(nil)
sup.RegisterFunc(GoSpec{
Name: "stoppable",
Func: func(ctx context.Context) error {
<-ctx.Done()
return nil
},
Restart: RestartPolicy{Delay: 5 * time.Millisecond, MaxRestarts: -1},
})
sup.Start()
time.Sleep(50 * time.Millisecond)
if err := sup.StopUnit("stoppable"); err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)
status, _ := sup.Status("stoppable")
if status.Running {
t.Error("expected unit to be stopped")
}
sup.Stop()
}
func TestSupervisor_StopUnit_Bad(t *testing.T) {
sup := NewSupervisor(nil)
err := sup.StopUnit("nonexistent")
if err == nil {
t.Error("expected error for nonexistent unit")
}
}
func TestSupervisor_StartIdempotent_Good(t *testing.T) {
sup := NewSupervisor(nil)
var count atomic.Int32
sup.RegisterFunc(GoSpec{
Name: "once",
Func: func(ctx context.Context) error {
count.Add(1)
<-ctx.Done()
return nil
},
})
sup.Start()
sup.Start() // Should be no-op
sup.Start() // Should be no-op
time.Sleep(50 * time.Millisecond)
if count.Load() != 1 {
t.Errorf("expected exactly 1 run, got %d", count.Load())
}
sup.Stop()
}
func TestSupervisor_NoRestart_Good(t *testing.T) {
sup := NewSupervisor(nil)
var runs atomic.Int32
sup.RegisterFunc(GoSpec{
Name: "oneshot",
Func: func(ctx context.Context) error {
runs.Add(1)
return nil // Exit immediately
},
Restart: RestartPolicy{Delay: 5 * time.Millisecond, MaxRestarts: 0},
})
sup.Start()
time.Sleep(100 * time.Millisecond)
status, _ := sup.Status("oneshot")
if status.Running {
t.Error("expected oneshot to not be running")
}
// Should run once (initial) then stop. restartCount will be 1
// (incremented after the initial run exits).
if runs.Load() != 1 {
t.Errorf("expected exactly 1 run, got %d", runs.Load())
}
sup.Stop()
}
func TestSupervisor_Register_Ugly(t *testing.T) {
sup := NewSupervisor(nil)
defer func() {
if r := recover(); r == nil {
t.Error("expected panic when registering process daemon without service")
}
}()
sup.Register(DaemonSpec{
Name: "will-panic",
RunOptions: RunOptions{Command: "echo"},
})
}
```