- Add backend_mlxlm.go blank import to register mlx-lm subprocess backend
- Select backend from ai.yaml config (metal, mlx_lm, rocm, api)
- Only set Metal cache/memory limits when using metal backend
- Add --no-dedup flag to disable grammar-profile deduplication (trained models with consistent voice trigger false positives at 0.02)
- Add --context-len flag and context_len config for KV cache sizing
- Pass WithBackend and WithContextLen to go-ml backend loader

Co-Authored-By: Virgil <virgil@lethean.io>
446 lines
13 KiB
Go
446 lines
13 KiB
Go
package lem
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"forge.lthn.ai/core/go-i18n/reversal"
|
|
"forge.lthn.ai/core/go-inference"
|
|
"forge.lthn.ai/core/go-ml"
|
|
"forge.lthn.ai/core/go-mlx"
|
|
)
|
|
|
|
// DistillProbe is a single input prompt for distillation.
type DistillProbe struct {
	ID     string `json:"id"`               // Unique probe identifier, echoed in progress logs and dedup entries.
	Domain string `json:"domain,omitempty"` // Optional topic grouping, carried into the dedup index.
	Prompt string `json:"prompt"`           // Prompt text sent to the model as the user message.
	Source string `json:"-"`                // Originating probe file basename; not serialised to JSON.
}
|
|
|
|
// distillCandidate holds a single generation attempt with its scores.
// The best candidate per probe (highest Grammar.Composite) is kept.
type distillCandidate struct {
	Response string        // Raw model output text.
	Grammar  GrammarScore  // Grammar v3 score from ScoreResponse (go-i18n/reversal).
	Delta    DeltaScore    // Prompt-vs-response delta metrics (uplift, echo, enrichment).
	Elapsed  time.Duration // Wall-clock generation time for this attempt.
}
|
|
|
|
// DistillOpts holds configuration for the distill command.
// Zero values generally mean "fall back to the ai.yaml default"; RunDistill
// resolves each setting as flag > config > auto.
type DistillOpts struct {
	Model      string  // Model config path (relative to .core/ai/models/)
	Probes     string  // Probe set name from probes.yaml, or path to JSON file
	Output     string  // Output JSONL path (defaults to model training dir)
	Lesson     int     // Lesson number to append to; negative means "use probe set phase"
	MinScore   float64 // Min grammar composite (0 = use ai.yaml default)
	Runs       int     // Generations per probe (0 = use ai.yaml default)
	DryRun     bool    // Show plan and exit without generating
	NoDedup    bool    // Disable grammar-profile deduplication
	Root       string  // Project root (for .core/ai/ config)
	CacheLimit int     // Metal cache limit in GB (0 = use ai.yaml default)
	MemLimit   int     // Metal memory limit in GB (0 = use ai.yaml default)
	ContextLen int     // KV cache context window (0 = auto: max_tokens * 2)
}
|
|
|
|
// RunDistill is the CLI entry point for the distill command.
|
|
// Generates responses via native Metal inference, scores with go-i18n/reversal,
|
|
// writes passing examples as training JSONL.
|
|
func RunDistill(cfg DistillOpts) error {
|
|
// Load configs.
|
|
aiCfg, err := LoadAIConfig(cfg.Root)
|
|
if err != nil {
|
|
return fmt.Errorf("load ai config: %w", err)
|
|
}
|
|
|
|
// Apply config defaults for model and probes.
|
|
model := cfg.Model
|
|
if model == "" {
|
|
model = aiCfg.Distill.Model
|
|
}
|
|
if model == "" {
|
|
return fmt.Errorf("no model specified (use --model or set distill.model in ai.yaml)")
|
|
}
|
|
|
|
modelCfg, err := LoadModelConfig(cfg.Root, model)
|
|
if err != nil {
|
|
return fmt.Errorf("load model config: %w", err)
|
|
}
|
|
|
|
genCfg := MergeGenerate(aiCfg.Generate, modelCfg.Generate)
|
|
|
|
// Apply flag overrides.
|
|
minScore := cfg.MinScore
|
|
if minScore == 0 {
|
|
minScore = aiCfg.Scorer.MinScore
|
|
}
|
|
runs := cfg.Runs
|
|
if runs == 0 {
|
|
runs = aiCfg.Distill.Runs
|
|
}
|
|
cacheLimitGB := aiCfg.Distill.CacheLimit
|
|
if cfg.CacheLimit > 0 {
|
|
cacheLimitGB = cfg.CacheLimit
|
|
}
|
|
memLimitGB := aiCfg.Distill.MemoryLimit
|
|
if cfg.MemLimit > 0 {
|
|
memLimitGB = cfg.MemLimit
|
|
}
|
|
contextLen := aiCfg.Distill.ContextLen
|
|
if cfg.ContextLen > 0 {
|
|
contextLen = cfg.ContextLen
|
|
}
|
|
if contextLen == 0 {
|
|
// Auto: 2x max_tokens covers kernel + prompt + generation.
|
|
contextLen = genCfg.MaxTokens * 2
|
|
}
|
|
|
|
// Load probes.
|
|
probeSet := cfg.Probes
|
|
if probeSet == "" {
|
|
probeSet = aiCfg.Distill.Probes
|
|
}
|
|
if probeSet == "" {
|
|
return fmt.Errorf("no probe set specified (use --probes or set distill.probes in ai.yaml)")
|
|
}
|
|
probes, phase, err := loadDistillProbes(cfg.Root, probeSet)
|
|
if err != nil {
|
|
return fmt.Errorf("load probes: %w", err)
|
|
}
|
|
log.Printf("loaded %d probes", len(probes))
|
|
|
|
// Determine output path.
|
|
outputPath := cfg.Output
|
|
if outputPath == "" {
|
|
lesson := cfg.Lesson
|
|
if lesson < 0 {
|
|
lesson = phase
|
|
}
|
|
lessonFile, ok := modelCfg.Lessons[lesson]
|
|
if !ok {
|
|
lessonFile = fmt.Sprintf("lesson-%d.jsonl", lesson)
|
|
}
|
|
outputPath = filepath.Join(modelCfg.Training, lessonFile)
|
|
}
|
|
|
|
// Load kernel and signature — only if the model config specifies them.
|
|
// LEM models have axioms in weights; loading the kernel violates A4.
|
|
var kernel []byte
|
|
var sig string
|
|
if modelCfg.Kernel != "" {
|
|
kernel, err = os.ReadFile(modelCfg.Kernel)
|
|
if err != nil {
|
|
return fmt.Errorf("read kernel: %w", err)
|
|
}
|
|
log.Printf("kernel: %d chars from %s (sandwich output)", len(kernel), modelCfg.Kernel)
|
|
|
|
if modelCfg.Signature != "" {
|
|
sigBytes, err := os.ReadFile(modelCfg.Signature)
|
|
if err != nil {
|
|
return fmt.Errorf("read signature: %w", err)
|
|
}
|
|
sig = strings.TrimSpace(string(sigBytes))
|
|
}
|
|
} else {
|
|
log.Printf("no kernel configured — bare output (LEM model)")
|
|
}
|
|
|
|
// Dry run.
|
|
if cfg.DryRun {
|
|
fmt.Printf("Model: %s (%s)\n", modelCfg.Name, modelCfg.Paths.Base)
|
|
fmt.Printf("Backend: %s\n", aiCfg.Backend)
|
|
fmt.Printf("Probes: %d\n", len(probes))
|
|
fmt.Printf("Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs)
|
|
fmt.Printf("Gate: grammar v3 composite >= %.1f\n", minScore)
|
|
fmt.Printf("Generate: temp=%.2f max_tokens=%d top_p=%.2f\n",
|
|
genCfg.Temperature, genCfg.MaxTokens, genCfg.TopP)
|
|
fmt.Printf("Memory: cache=%dGB limit=%dGB context_len=%d\n", cacheLimitGB, memLimitGB, contextLen)
|
|
fmt.Printf("Output: %s\n", outputPath)
|
|
fmt.Println()
|
|
for i, p := range probes {
|
|
if i >= 10 {
|
|
fmt.Printf(" ... and %d more\n", len(probes)-10)
|
|
break
|
|
}
|
|
prompt := p.Prompt
|
|
if len(prompt) > 80 {
|
|
prompt = prompt[:80] + "..."
|
|
}
|
|
fmt.Printf(" %s: %s\n", p.ID, prompt)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Set Metal memory limits (only relevant for native Metal backend).
|
|
backendName := aiCfg.Backend
|
|
if backendName == "" {
|
|
backendName = "metal"
|
|
}
|
|
if backendName == "metal" {
|
|
if cacheLimitGB > 0 {
|
|
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
|
log.Printf("metal cache limit: %dGB", cacheLimitGB)
|
|
}
|
|
if memLimitGB > 0 {
|
|
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
|
log.Printf("metal memory limit: %dGB", memLimitGB)
|
|
}
|
|
}
|
|
|
|
// Load model via go-ml Backend (wraps go-inference with backend selection).
|
|
// WithContextLen bounds the KV cache for native Metal; mlx_lm handles its own memory.
|
|
log.Printf("loading model: %s (backend=%s, context_len=%d)", modelCfg.Paths.Base, backendName, contextLen)
|
|
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base,
|
|
inference.WithBackend(backendName),
|
|
inference.WithContextLen(contextLen),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("load model: %w", err)
|
|
}
|
|
defer backend.Close()
|
|
|
|
log.Printf("model loaded via %s backend", backend.Name())
|
|
|
|
// Build generation options from merged config.
|
|
genOpts := ml.GenOpts{
|
|
MaxTokens: genCfg.MaxTokens,
|
|
Temperature: genCfg.Temperature,
|
|
TopP: genCfg.TopP,
|
|
TopK: genCfg.TopK,
|
|
RepeatPenalty: genCfg.RepeatPenalty,
|
|
}
|
|
|
|
// Initialise grammar scorer.
|
|
tok := reversal.NewTokeniser()
|
|
|
|
// Open output for append.
|
|
out, err := os.OpenFile(outputPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
|
if err != nil {
|
|
return fmt.Errorf("open output: %w", err)
|
|
}
|
|
defer out.Close()
|
|
|
|
kept := 0
|
|
skipped := 0
|
|
deduped := 0
|
|
totalStart := time.Now()
|
|
ctx := context.Background()
|
|
kernelStr := strings.TrimSpace(string(kernel))
|
|
|
|
// Running duplicate index for grammar-profile deduplication.
|
|
var dedupIdx *ScoreIndex
|
|
|
|
for i, probe := range probes {
|
|
var best *distillCandidate
|
|
|
|
// Build output prompt: sandwich if kernel configured, bare if LEM model.
|
|
var outputPrompt string
|
|
if kernelStr != "" {
|
|
outputPrompt = kernelStr + "\n\n" + probe.Prompt
|
|
if sig != "" {
|
|
outputPrompt += "\n\n" + sig
|
|
}
|
|
} else {
|
|
outputPrompt = probe.Prompt
|
|
}
|
|
|
|
for run := range runs {
|
|
fmt.Fprintf(os.Stderr, " [%d/%d] %s run %d/%d",
|
|
i+1, len(probes), probe.ID, run+1, runs)
|
|
|
|
// Inference uses bare probe — the model generates from its weights.
|
|
// Sandwich wrapping is only for the training output format.
|
|
messages := []ml.Message{
|
|
{Role: "user", Content: probe.Prompt},
|
|
}
|
|
|
|
// Generate via go-ml Backend (memory-managed Metal inference).
|
|
start := time.Now()
|
|
result, err := backend.Chat(ctx, messages, genOpts)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, " → ERROR: %v\n", err)
|
|
continue
|
|
}
|
|
response := result.Text
|
|
elapsed := time.Since(start)
|
|
|
|
// Quick reject: empty/degenerate.
|
|
if len(strings.TrimSpace(response)) < aiCfg.Distill.MinChars {
|
|
fmt.Fprintf(os.Stderr, " → %d chars, EMPTY, %.1fs\n", len(response), elapsed.Seconds())
|
|
continue
|
|
}
|
|
|
|
// Score with go-i18n/reversal.
|
|
grammar := ScoreResponse(tok, response)
|
|
delta := ComputeDelta(tok, probe.Prompt, response,
|
|
aiCfg.Scorer.SycophancyEcho, aiCfg.Scorer.SycophancyUplift)
|
|
|
|
tokPerSec := 0.0
|
|
if result.Metrics != nil {
|
|
tokPerSec = result.Metrics.DecodeTokensPerSec
|
|
}
|
|
fmt.Fprintf(os.Stderr, " → %d chars, g=%.1f up=%+.1f echo=%.2f enr=%+.1f, %.1fs (%.0f tok/s)\n",
|
|
len(response), grammar.Composite,
|
|
delta.Uplift, delta.Echo, delta.Enrichment,
|
|
elapsed.Seconds(), tokPerSec)
|
|
|
|
candidate := &distillCandidate{
|
|
Response: response,
|
|
Grammar: grammar,
|
|
Delta: delta,
|
|
Elapsed: elapsed,
|
|
}
|
|
|
|
if best == nil || grammar.Composite > best.Grammar.Composite {
|
|
best = candidate
|
|
}
|
|
}
|
|
|
|
// Optional attention scoring — costs an extra prefill per probe.
|
|
if best != nil && aiCfg.Scorer.Attention {
|
|
snap, attErr := backend.InspectAttention(ctx, probe.Prompt)
|
|
if attErr == nil {
|
|
attResult := AnalyseAttention(snap)
|
|
boScore := attResult.Composite()
|
|
fmt.Fprintf(os.Stderr, " BO: coherence=%.2f phase=%.2f cross=%.2f composite=%d\n",
|
|
attResult.MeanCoherence, attResult.PhaseLockScore, attResult.MeanCrossAlignment, boScore)
|
|
if aiCfg.Scorer.AttentionMinScore > 0 && boScore < aiCfg.Scorer.AttentionMinScore {
|
|
skipped++
|
|
fmt.Fprintf(os.Stderr, " ✗ SKIP %s (BO composite %d < %d)\n",
|
|
probe.ID, boScore, aiCfg.Scorer.AttentionMinScore)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
// Quality gate.
|
|
if best != nil && best.Grammar.Composite >= minScore {
|
|
// Duplicate filter: reject if grammar profile is too similar to an already-kept entry.
|
|
if !cfg.NoDedup {
|
|
bestFeatures := GrammarFeatures(best.Grammar)
|
|
if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
|
|
deduped++
|
|
fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID)
|
|
continue
|
|
}
|
|
}
|
|
|
|
// Save with output prompt — sandwich if kernel, bare if LEM model.
|
|
example := TrainingExample{
|
|
Messages: []ChatMessage{
|
|
{Role: "user", Content: outputPrompt},
|
|
{Role: "assistant", Content: best.Response},
|
|
},
|
|
}
|
|
line, _ := json.Marshal(example)
|
|
out.Write(append(line, '\n'))
|
|
|
|
// Add to dedup index (unless disabled).
|
|
if !cfg.NoDedup {
|
|
entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar}
|
|
if dedupIdx == nil {
|
|
dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry})
|
|
} else {
|
|
_ = dedupIdx.Insert(entry)
|
|
}
|
|
}
|
|
|
|
kept++
|
|
fmt.Fprintf(os.Stderr, " ✓ KEPT %s (g=%.1f, verbs=%d, nouns=%d, enr=%+.1f)\n",
|
|
probe.ID, best.Grammar.Composite,
|
|
best.Grammar.VerbDiversity, best.Grammar.NounDiversity,
|
|
best.Delta.Enrichment)
|
|
} else {
|
|
skipped++
|
|
score := 0.0
|
|
if best != nil {
|
|
score = best.Grammar.Composite
|
|
}
|
|
fmt.Fprintf(os.Stderr, " ✗ SKIP %s (best g=%.1f < %.1f)\n",
|
|
probe.ID, score, minScore)
|
|
}
|
|
}
|
|
|
|
duration := time.Since(totalStart)
|
|
|
|
fmt.Fprintf(os.Stderr, "\n=== Distillation Complete ===\n")
|
|
fmt.Fprintf(os.Stderr, "Model: %s (%s)\n", modelCfg.Name, backend.Name())
|
|
fmt.Fprintf(os.Stderr, "Probes: %d\n", len(probes))
|
|
fmt.Fprintf(os.Stderr, "Runs: %d per probe (%d total generations)\n", runs, len(probes)*runs)
|
|
fmt.Fprintf(os.Stderr, "Scorer: go-i18n/reversal grammar v3, gate >= %.1f\n", minScore)
|
|
fmt.Fprintf(os.Stderr, "Kept: %d\n", kept)
|
|
fmt.Fprintf(os.Stderr, "Deduped: %d\n", deduped)
|
|
fmt.Fprintf(os.Stderr, "Skipped: %d\n", skipped)
|
|
total := kept + deduped + skipped
|
|
if total > 0 {
|
|
fmt.Fprintf(os.Stderr, "Pass rate: %.0f%%\n", float64(kept)/float64(total)*100)
|
|
}
|
|
fmt.Fprintf(os.Stderr, "Output: %s\n", outputPath)
|
|
fmt.Fprintf(os.Stderr, "Duration: %.0fs (%.1fm)\n", duration.Seconds(), duration.Minutes())
|
|
return nil
|
|
}
|
|
|
|
// loadDistillProbes loads probes from a named set or a file path.
|
|
// Returns the probes and the default phase number for output routing.
|
|
func loadDistillProbes(root, spec string) ([]DistillProbe, int, error) {
|
|
// Try as a probe set name first.
|
|
probesCfg, cfgErr := LoadProbesConfig(root)
|
|
if cfgErr == nil {
|
|
if set, ok := probesCfg.Sets[spec]; ok {
|
|
phase := 0
|
|
if set.Phase != nil {
|
|
phase = *set.Phase
|
|
}
|
|
var probes []DistillProbe
|
|
for _, f := range set.Files {
|
|
// Files are relative to the training root.
|
|
ps, err := readProbeFile(filepath.Join(root, "training", "lem", f))
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("read %s: %w", f, err)
|
|
}
|
|
probes = append(probes, ps...)
|
|
}
|
|
return probes, phase, nil
|
|
}
|
|
}
|
|
|
|
// Fall back to direct file path.
|
|
probes, err := readProbeFile(spec)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return probes, 0, nil
|
|
}
|
|
|
|
// readProbeFile reads probes from a JSON array file.
|
|
func readProbeFile(path string) ([]DistillProbe, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var raw []struct {
|
|
ID string `json:"id"`
|
|
Domain string `json:"domain"`
|
|
Prompt string `json:"prompt"`
|
|
}
|
|
if err := json.Unmarshal(data, &raw); err != nil {
|
|
return nil, fmt.Errorf("parse %s: %w", filepath.Base(path), err)
|
|
}
|
|
|
|
probes := make([]DistillProbe, len(raw))
|
|
for i, r := range raw {
|
|
probes[i] = DistillProbe{
|
|
ID: r.ID,
|
|
Domain: r.Domain,
|
|
Prompt: r.Prompt,
|
|
Source: filepath.Base(path),
|
|
}
|
|
}
|
|
return probes, nil
|
|
}
|