feat: add 'lem score attention' CLI for Q/K Bone Orientation analysis

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 00:29:41 +00:00
parent 28309b26dc
commit e3331920c4
2 changed files with 111 additions and 0 deletions

View file

@@ -27,6 +27,7 @@ func addScoreCommands(root *cli.Command) {
cli.StringFlag(compareCmd, &compareNew, "new", "", "", "New score file (required)")
scoreGroup.AddCommand(compareCmd)
scoreGroup.AddCommand(passthrough("attention", "Q/K Bone Orientation analysis for a prompt", lem.RunAttention))
scoreGroup.AddCommand(passthrough("tier", "Score expansion responses (heuristic/judge tiers)", lem.RunTierScore))
scoreGroup.AddCommand(passthrough("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", lem.RunAgent))

110
pkg/lem/cmd_attention.go Normal file
View file

@@ -0,0 +1,110 @@
package lem
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"time"
"forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go-mlx"
)
// RunAttention is the entry point for `lem score attention`.
func RunAttention(args []string) {
fs := flag.NewFlagSet("attention", flag.ExitOnError)
var (
modelFlag string
prompt string
jsonOutput bool
cacheLimit int
memLimit int
root string
)
fs.StringVar(&modelFlag, "model", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
fs.StringVar(&prompt, "prompt", "", "Prompt text to analyse")
fs.BoolVar(&jsonOutput, "json", false, "Output as JSON")
fs.IntVar(&cacheLimit, "cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
fs.IntVar(&memLimit, "mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
fs.StringVar(&root, "root", ".", "Project root (for .core/ai/ config)")
if err := fs.Parse(args); err != nil {
log.Fatalf("parse flags: %v", err)
}
if prompt == "" {
fmt.Fprintln(os.Stderr, "usage: lem score attention -model <path> -prompt <text>")
os.Exit(1)
}
// Load configs.
aiCfg, err := LoadAIConfig(root)
if err != nil {
log.Fatalf("load ai config: %v", err)
}
modelCfg, err := LoadModelConfig(root, modelFlag)
if err != nil {
log.Fatalf("load model config: %v", err)
}
// Apply memory limits.
cacheLimitGB := aiCfg.Distill.CacheLimit
if cacheLimit > 0 {
cacheLimitGB = cacheLimit
}
memLimitGB := aiCfg.Distill.MemoryLimit
if memLimit > 0 {
memLimitGB = memLimit
}
if cacheLimitGB > 0 {
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
}
if memLimitGB > 0 {
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
}
// Load model.
log.Printf("loading model: %s", modelCfg.Paths.Base)
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
if err != nil {
log.Fatalf("load model: %v", err)
}
defer backend.Close()
// Run attention inspection.
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
snap, err := backend.InspectAttention(ctx, prompt)
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
result := AnalyseAttention(snap)
if jsonOutput {
data, _ := json.MarshalIndent(result, "", " ")
fmt.Println(string(data))
return
}
// Human-readable output.
fmt.Printf("Q/K Bone Orientation Analysis\n")
fmt.Printf(" Architecture: %s\n", snap.Architecture)
fmt.Printf(" Layers: %d\n", snap.NumLayers)
fmt.Printf(" KV Heads: %d\n", snap.NumHeads)
fmt.Printf(" Sequence Length: %d tokens\n", snap.SeqLen)
fmt.Printf(" Head Dimension: %d\n", snap.HeadDim)
fmt.Println()
fmt.Printf(" Mean Coherence: %.3f\n", result.MeanCoherence)
fmt.Printf(" Cross-Layer Align: %.3f\n", result.MeanCrossAlignment)
fmt.Printf(" Head Entropy: %.3f\n", result.MeanHeadEntropy)
fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
fmt.Printf(" Composite (0-100): %.1f\n", result.Composite())
}