LEM/pkg/lem/cmd_attention.go
Snider e3331920c4 feat: add 'lem score attention' CLI for Q/K Bone Orientation analysis
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 00:29:41 +00:00

110 lines
3 KiB
Go

package lem
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"time"
"forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go-mlx"
)
// RunAttention is the entry point for `lem score attention`.
//
// It parses the command-line flags in args, loads the AI and model
// configuration from the project root, applies the Metal cache and memory
// limits (flag values override the ai.yaml defaults when greater than zero),
// loads the model, runs attention inspection on the prompt with a 60 second
// timeout, and prints the analysis either as JSON (-json) or as a
// human-readable summary.
//
// The function does not return an error: any failure is reported on stderr
// (directly or through the log package) and terminates the process with a
// non-zero exit status.
func RunAttention(args []string) {
	fs := flag.NewFlagSet("attention", flag.ExitOnError)

	var (
		modelFlag  string
		prompt     string
		jsonOutput bool
		cacheLimit int
		memLimit   int
		root       string
	)
	fs.StringVar(&modelFlag, "model", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
	fs.StringVar(&prompt, "prompt", "", "Prompt text to analyse")
	fs.BoolVar(&jsonOutput, "json", false, "Output as JSON")
	fs.IntVar(&cacheLimit, "cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
	fs.IntVar(&memLimit, "mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
	fs.StringVar(&root, "root", ".", "Project root (for .core/ai/ config)")

	// With flag.ExitOnError the flag set exits on a parse failure by itself;
	// the check is kept as a defensive guard.
	if err := fs.Parse(args); err != nil {
		log.Fatalf("parse flags: %v", err)
	}

	// A prompt is mandatory; there is nothing to inspect without one.
	if prompt == "" {
		fmt.Fprintln(os.Stderr, "usage: lem score attention -model <path> -prompt <text>")
		os.Exit(1)
	}

	// Load configs.
	aiCfg, err := LoadAIConfig(root)
	if err != nil {
		log.Fatalf("load ai config: %v", err)
	}
	modelCfg, err := LoadModelConfig(root, modelFlag)
	if err != nil {
		log.Fatalf("load model config: %v", err)
	}

	// Apply memory limits. Start from the configured values and let a
	// positive command-line flag take precedence.
	cacheLimitGB := aiCfg.Distill.CacheLimit
	if cacheLimit > 0 {
		cacheLimitGB = cacheLimit
	}
	memLimitGB := aiCfg.Distill.MemoryLimit
	if memLimit > 0 {
		memLimitGB = memLimit
	}
	// Limits are expressed in GB and converted to bytes; a non-positive
	// value leaves the corresponding limit untouched.
	if cacheLimitGB > 0 {
		mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
	}
	if memLimitGB > 0 {
		mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
	}

	// Load model.
	log.Printf("loading model: %s", modelCfg.Paths.Base)
	backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
	if err != nil {
		log.Fatalf("load model: %v", err)
	}
	defer backend.Close()

	// Run attention inspection.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	// os.Exit does not run deferred calls, so every failure after the
	// backend has been loaded goes through fail, which releases the context
	// and the backend explicitly before terminating the process.
	fail := func(err error) {
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
		cancel()
		backend.Close()
		os.Exit(1)
	}

	snap, err := backend.InspectAttention(ctx, prompt)
	if err != nil {
		fail(err)
	}

	result := AnalyseAttention(snap)

	if jsonOutput {
		// A marshalling failure must not be reported as success with an
		// empty line on stdout.
		data, err := json.MarshalIndent(result, "", " ")
		if err != nil {
			fail(fmt.Errorf("marshal result: %w", err))
		}
		fmt.Println(string(data))
		return
	}

	// Human-readable output.
	fmt.Printf("Q/K Bone Orientation Analysis\n")
	fmt.Printf(" Architecture: %s\n", snap.Architecture)
	fmt.Printf(" Layers: %d\n", snap.NumLayers)
	fmt.Printf(" KV Heads: %d\n", snap.NumHeads)
	fmt.Printf(" Sequence Length: %d tokens\n", snap.SeqLen)
	fmt.Printf(" Head Dimension: %d\n", snap.HeadDim)
	fmt.Println()
	fmt.Printf(" Mean Coherence: %.3f\n", result.MeanCoherence)
	fmt.Printf(" Cross-Layer Align: %.3f\n", result.MeanCrossAlignment)
	fmt.Printf(" Head Entropy: %.3f\n", result.MeanHeadEntropy)
	fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
	fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
	fmt.Printf(" Composite (0-100): %.1f\n", result.Composite())
}