1
0
Fork 0
forked from lthn/LEM
LEM/pkg/lem/distill.go
Snider 3606ff994b fix: memory, error handling, and signal improvements across pkg/lem
- Stream parquet export rows instead of unbounded memory allocation
- Replace QueryGoldenSet/QueryExpansionPrompts with iter.Seq2 iterators
- Remove legacy runtime.GC() calls from distill (go-mlx handles cleanup)
- Replace log.Fatalf with error return in tier_score.go
- Add SIGINT/SIGTERM signal handling to agent and worker daemon loops
- Add error checks for unchecked db.conn.Exec in import.go and tier_score.go
- Update tests for iterator-based database methods

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 04:46:51 +00:00

421 lines
12 KiB
Go

package lem
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"time"
"forge.lthn.ai/core/go-i18n/reversal"
"forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go-mlx"
)
// DistillProbe is a single input prompt for distillation.
type DistillProbe struct {
	// ID uniquely identifies the probe within its probe set.
	ID string `json:"id"`
	// Domain optionally groups probes by subject area; omitted from JSON when empty.
	Domain string `json:"domain,omitempty"`
	// Prompt is the text sent to the model for generation.
	Prompt string `json:"prompt"`
	// Source records the base name of the probe file this entry was read
	// from (see readProbeFile); never serialised.
	Source string `json:"-"`
}
// distillCandidate holds a single generation attempt with its scores.
type distillCandidate struct {
	// Response is the raw model output text for this attempt.
	Response string
	// Grammar is the go-i18n/reversal grammar score of the response;
	// Grammar.Composite is used to pick the best attempt per probe.
	Grammar GrammarScore
	// Delta holds the prompt-vs-response delta scores (uplift, echo, enrichment).
	Delta DeltaScore
	// Elapsed is the wall-clock generation time for this attempt.
	Elapsed time.Duration
}
// DistillOpts holds configuration for the distill command.
// Zero-valued fields fall back to the defaults loaded from ai.yaml,
// as noted per field.
type DistillOpts struct {
	Model      string  // Model config path (relative to .core/ai/models/)
	Probes     string  // Probe set name from probes.yaml, or path to JSON file
	Output     string  // Output JSONL path (defaults to model training dir)
	Lesson     int     // Lesson number to append to (defaults to probe set phase)
	MinScore   float64 // Min grammar composite (0 = use ai.yaml default)
	Runs       int     // Generations per probe (0 = use ai.yaml default)
	DryRun     bool    // Show plan and exit without generating
	Root       string  // Project root (for .core/ai/ config)
	CacheLimit int     // Metal cache limit in GB (0 = use ai.yaml default)
	MemLimit   int     // Metal memory limit in GB (0 = use ai.yaml default)
}
// RunDistill is the CLI entry point for the distill command.
// Generates responses via native Metal inference, scores with go-i18n/reversal,
// writes passing examples as training JSONL.
//
// Fixed: json.Marshal and out.Write results were previously discarded, so a
// marshal failure or a full-disk/closed-file write error would silently drop
// kept examples while still reporting success. Both are now checked and
// surfaced as errors.
func RunDistill(cfg DistillOpts) error {
	// Load configs.
	aiCfg, err := LoadAIConfig(cfg.Root)
	if err != nil {
		return fmt.Errorf("load ai config: %w", err)
	}
	// Apply config defaults for model and probes.
	model := cfg.Model
	if model == "" {
		model = aiCfg.Distill.Model
	}
	if model == "" {
		return fmt.Errorf("no model specified (use --model or set distill.model in ai.yaml)")
	}
	modelCfg, err := LoadModelConfig(cfg.Root, model)
	if err != nil {
		return fmt.Errorf("load model config: %w", err)
	}
	genCfg := MergeGenerate(aiCfg.Generate, modelCfg.Generate)
	// Apply flag overrides.
	minScore := cfg.MinScore
	if minScore == 0 {
		minScore = aiCfg.Scorer.MinScore
	}
	runs := cfg.Runs
	if runs == 0 {
		runs = aiCfg.Distill.Runs
	}
	cacheLimitGB := aiCfg.Distill.CacheLimit
	if cfg.CacheLimit > 0 {
		cacheLimitGB = cfg.CacheLimit
	}
	memLimitGB := aiCfg.Distill.MemoryLimit
	if cfg.MemLimit > 0 {
		memLimitGB = cfg.MemLimit
	}
	// Load probes.
	probeSet := cfg.Probes
	if probeSet == "" {
		probeSet = aiCfg.Distill.Probes
	}
	if probeSet == "" {
		return fmt.Errorf("no probe set specified (use --probes or set distill.probes in ai.yaml)")
	}
	probes, phase, err := loadDistillProbes(cfg.Root, probeSet)
	if err != nil {
		return fmt.Errorf("load probes: %w", err)
	}
	log.Printf("loaded %d probes", len(probes))
	// Determine output path.
	outputPath := cfg.Output
	if outputPath == "" {
		lesson := cfg.Lesson
		if lesson < 0 {
			lesson = phase
		}
		lessonFile, ok := modelCfg.Lessons[lesson]
		if !ok {
			lessonFile = fmt.Sprintf("lesson-%d.jsonl", lesson)
		}
		outputPath = filepath.Join(modelCfg.Training, lessonFile)
	}
	// Load kernel and signature — only if the model config specifies them.
	// LEM models have axioms in weights; loading the kernel violates A4.
	var kernel []byte
	var sig string
	if modelCfg.Kernel != "" {
		kernel, err = os.ReadFile(modelCfg.Kernel)
		if err != nil {
			return fmt.Errorf("read kernel: %w", err)
		}
		log.Printf("kernel: %d chars from %s (sandwich output)", len(kernel), modelCfg.Kernel)
		if modelCfg.Signature != "" {
			sigBytes, err := os.ReadFile(modelCfg.Signature)
			if err != nil {
				return fmt.Errorf("read signature: %w", err)
			}
			sig = strings.TrimSpace(string(sigBytes))
		}
	} else {
		log.Printf("no kernel configured — bare output (LEM model)")
	}
	// Dry run: print the plan and the first few probes, then exit.
	if cfg.DryRun {
		fmt.Printf("Model: %s (%s)\n", modelCfg.Name, modelCfg.Paths.Base)
		fmt.Printf("Backend: %s\n", aiCfg.Backend)
		fmt.Printf("Probes: %d\n", len(probes))
		fmt.Printf("Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs)
		fmt.Printf("Gate: grammar v3 composite >= %.1f\n", minScore)
		fmt.Printf("Generate: temp=%.2f max_tokens=%d top_p=%.2f\n",
			genCfg.Temperature, genCfg.MaxTokens, genCfg.TopP)
		fmt.Printf("Memory: cache=%dGB limit=%dGB\n", cacheLimitGB, memLimitGB)
		fmt.Printf("Output: %s\n", outputPath)
		fmt.Println()
		for i, p := range probes {
			if i >= 10 {
				fmt.Printf(" ... and %d more\n", len(probes)-10)
				break
			}
			prompt := p.Prompt
			if len(prompt) > 80 {
				prompt = prompt[:80] + "..."
			}
			fmt.Printf(" %s: %s\n", p.ID, prompt)
		}
		return nil
	}
	// Set Metal memory limits before loading model.
	if cacheLimitGB > 0 {
		mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
		log.Printf("metal cache limit: %dGB", cacheLimitGB)
	}
	if memLimitGB > 0 {
		mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
		log.Printf("metal memory limit: %dGB", memLimitGB)
	}
	// Load model via go-ml Backend (wraps go-inference with memory management).
	log.Printf("loading model: %s", modelCfg.Paths.Base)
	backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}
	defer backend.Close()
	log.Printf("model loaded via %s backend", backend.Name())
	// Build generation options from merged config.
	genOpts := ml.GenOpts{
		MaxTokens:     genCfg.MaxTokens,
		Temperature:   genCfg.Temperature,
		TopP:          genCfg.TopP,
		TopK:          genCfg.TopK,
		RepeatPenalty: genCfg.RepeatPenalty,
	}
	// Initialise grammar scorer.
	tok := reversal.NewTokeniser()
	// Open output for append.
	out, err := os.OpenFile(outputPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return fmt.Errorf("open output: %w", err)
	}
	defer out.Close()
	kept := 0
	skipped := 0
	deduped := 0
	totalStart := time.Now()
	ctx := context.Background()
	kernelStr := strings.TrimSpace(string(kernel))
	// Running duplicate index for grammar-profile deduplication.
	var dedupIdx *ScoreIndex
	for i, probe := range probes {
		var best *distillCandidate
		// Build output prompt: sandwich if kernel configured, bare if LEM model.
		var outputPrompt string
		if kernelStr != "" {
			outputPrompt = kernelStr + "\n\n" + probe.Prompt
			if sig != "" {
				outputPrompt += "\n\n" + sig
			}
		} else {
			outputPrompt = probe.Prompt
		}
		for run := range runs {
			fmt.Fprintf(os.Stderr, " [%d/%d] %s run %d/%d",
				i+1, len(probes), probe.ID, run+1, runs)
			// Inference uses bare probe — the model generates from its weights.
			// Sandwich wrapping is only for the training output format.
			messages := []ml.Message{
				{Role: "user", Content: probe.Prompt},
			}
			// Generate via go-ml Backend (memory-managed Metal inference).
			start := time.Now()
			result, err := backend.Chat(ctx, messages, genOpts)
			if err != nil {
				fmt.Fprintf(os.Stderr, " → ERROR: %v\n", err)
				continue
			}
			response := result.Text
			elapsed := time.Since(start)
			// Quick reject: empty/degenerate.
			if len(strings.TrimSpace(response)) < aiCfg.Distill.MinChars {
				fmt.Fprintf(os.Stderr, " → %d chars, EMPTY, %.1fs\n", len(response), elapsed.Seconds())
				continue
			}
			// Score with go-i18n/reversal.
			grammar := ScoreResponse(tok, response)
			delta := ComputeDelta(tok, probe.Prompt, response,
				aiCfg.Scorer.SycophancyEcho, aiCfg.Scorer.SycophancyUplift)
			tokPerSec := 0.0
			if result.Metrics != nil {
				tokPerSec = result.Metrics.DecodeTokensPerSec
			}
			fmt.Fprintf(os.Stderr, " → %d chars, g=%.1f up=%+.1f echo=%.2f enr=%+.1f, %.1fs (%.0f tok/s)\n",
				len(response), grammar.Composite,
				delta.Uplift, delta.Echo, delta.Enrichment,
				elapsed.Seconds(), tokPerSec)
			candidate := &distillCandidate{
				Response: response,
				Grammar:  grammar,
				Delta:    delta,
				Elapsed:  elapsed,
			}
			// Keep the highest grammar composite across runs for this probe.
			if best == nil || grammar.Composite > best.Grammar.Composite {
				best = candidate
			}
		}
		// Optional attention scoring — costs an extra prefill per probe.
		if best != nil && aiCfg.Scorer.Attention {
			snap, attErr := backend.InspectAttention(ctx, probe.Prompt)
			if attErr == nil {
				attResult := AnalyseAttention(snap)
				boScore := attResult.Composite()
				fmt.Fprintf(os.Stderr, " BO: coherence=%.2f phase=%.2f cross=%.2f composite=%d\n",
					attResult.MeanCoherence, attResult.PhaseLockScore, attResult.MeanCrossAlignment, boScore)
				if aiCfg.Scorer.AttentionMinScore > 0 && boScore < aiCfg.Scorer.AttentionMinScore {
					skipped++
					fmt.Fprintf(os.Stderr, " ✗ SKIP %s (BO composite %d < %d)\n",
						probe.ID, boScore, aiCfg.Scorer.AttentionMinScore)
					continue
				}
			}
		}
		// Quality gate.
		if best != nil && best.Grammar.Composite >= minScore {
			// Duplicate filter: reject if grammar profile is too similar to an already-kept entry.
			bestFeatures := GrammarFeatures(best.Grammar)
			if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
				deduped++
				fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID)
				continue
			}
			// Save with output prompt — sandwich if kernel, bare if LEM model.
			example := TrainingExample{
				Messages: []ChatMessage{
					{Role: "user", Content: outputPrompt},
					{Role: "assistant", Content: best.Response},
				},
			}
			line, err := json.Marshal(example)
			if err != nil {
				return fmt.Errorf("marshal example %s: %w", probe.ID, err)
			}
			if _, err := out.Write(append(line, '\n')); err != nil {
				return fmt.Errorf("write example %s: %w", probe.ID, err)
			}
			// Add to dedup index. Index errors are deliberately best-effort:
			// a nil/unchanged index merely disables dedup for later probes.
			entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar}
			if dedupIdx == nil {
				dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry})
			} else {
				_ = dedupIdx.Insert(entry)
			}
			kept++
			fmt.Fprintf(os.Stderr, " ✓ KEPT %s (g=%.1f, verbs=%d, nouns=%d, enr=%+.1f)\n",
				probe.ID, best.Grammar.Composite,
				best.Grammar.VerbDiversity, best.Grammar.NounDiversity,
				best.Delta.Enrichment)
		} else {
			skipped++
			score := 0.0
			if best != nil {
				score = best.Grammar.Composite
			}
			fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n",
				probe.ID, score, minScore)
		}
	}
	duration := time.Since(totalStart)
	fmt.Fprintf(os.Stderr, "\n=== Distillation Complete ===\n")
	fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name())
	fmt.Fprintf(os.Stderr, "Probes: %d\n", len(probes))
	fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs)
	fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", minScore)
	fmt.Fprintf(os.Stderr, "Kept: %d\n", kept)
	fmt.Fprintf(os.Stderr, "Deduped: %d\n", deduped)
	fmt.Fprintf(os.Stderr, "Skipped: %d\n", skipped)
	total := kept + deduped + skipped
	if total > 0 {
		fmt.Fprintf(os.Stderr, "Pass rate: %.0f%%\n", float64(kept)/float64(total)*100)
	}
	fmt.Fprintf(os.Stderr, "Output: %s\n", outputPath)
	fmt.Fprintf(os.Stderr, "Duration: %.0fs (%.1fm)\n", duration.Seconds(), duration.Minutes())
	return nil
}
// loadDistillProbes loads probes from a named set or a file path.
// Returns the probes and the default phase number for output routing.
func loadDistillProbes(root, spec string) ([]DistillProbe, int, error) {
	// A named set from probes.yaml takes precedence over a literal path.
	if probesCfg, cfgErr := LoadProbesConfig(root); cfgErr == nil {
		if set, found := probesCfg.Sets[spec]; found {
			return loadProbeSetFiles(root, set)
		}
	}
	// Not a known set name — treat spec as a direct file path (phase 0).
	probes, err := readProbeFile(spec)
	if err != nil {
		return nil, 0, err
	}
	return probes, 0, nil
}

// loadProbeSetFiles concatenates every probe file in the set and resolves
// the set's default phase (0 when unset).
func loadProbeSetFiles(root string, set ProbeSet) ([]DistillProbe, int, error) {
	phase := 0
	if set.Phase != nil {
		phase = *set.Phase
	}
	var all []DistillProbe
	for _, name := range set.Files {
		// Files are relative to the training root.
		batch, err := readProbeFile(filepath.Join(root, "training", "lem", name))
		if err != nil {
			return nil, 0, fmt.Errorf("read %s: %w", name, err)
		}
		all = append(all, batch...)
	}
	return all, phase, nil
}
// readProbeFile reads probes from a JSON array file.
// Each resulting probe carries the file's base name as its Source.
func readProbeFile(path string) ([]DistillProbe, error) {
	payload, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	// Decode only the three serialised fields; Source is derived below.
	type probeRecord struct {
		ID     string `json:"id"`
		Domain string `json:"domain"`
		Prompt string `json:"prompt"`
	}
	var records []probeRecord
	if err := json.Unmarshal(payload, &records); err != nil {
		return nil, fmt.Errorf("parse %s: %w", filepath.Base(path), err)
	}
	source := filepath.Base(path)
	probes := make([]DistillProbe, 0, len(records))
	for _, rec := range records {
		probes = append(probes, DistillProbe{
			ID:     rec.ID,
			Domain: rec.Domain,
			Prompt: rec.Prompt,
			Source: source,
		})
	}
	return probes, nil
}