feat: add 'lem score attention' CLI for Q/K Bone Orientation analysis
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
28309b26dc
commit
e3331920c4
2 changed files with 111 additions and 0 deletions
|
|
@ -27,6 +27,7 @@ func addScoreCommands(root *cli.Command) {
|
|||
cli.StringFlag(compareCmd, &compareNew, "new", "", "", "New score file (required)")
|
||||
scoreGroup.AddCommand(compareCmd)
|
||||
|
||||
scoreGroup.AddCommand(passthrough("attention", "Q/K Bone Orientation analysis for a prompt", lem.RunAttention))
|
||||
scoreGroup.AddCommand(passthrough("tier", "Score expansion responses (heuristic/judge tiers)", lem.RunTierScore))
|
||||
scoreGroup.AddCommand(passthrough("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", lem.RunAgent))
|
||||
|
||||
|
|
|
|||
110
pkg/lem/cmd_attention.go
Normal file
110
pkg/lem/cmd_attention.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
package lem
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-ml"
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
// RunAttention is the entry point for `lem score attention`.
|
||||
func RunAttention(args []string) {
|
||||
fs := flag.NewFlagSet("attention", flag.ExitOnError)
|
||||
var (
|
||||
modelFlag string
|
||||
prompt string
|
||||
jsonOutput bool
|
||||
cacheLimit int
|
||||
memLimit int
|
||||
root string
|
||||
)
|
||||
fs.StringVar(&modelFlag, "model", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
|
||||
fs.StringVar(&prompt, "prompt", "", "Prompt text to analyse")
|
||||
fs.BoolVar(&jsonOutput, "json", false, "Output as JSON")
|
||||
fs.IntVar(&cacheLimit, "cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
||||
fs.IntVar(&memLimit, "mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
||||
fs.StringVar(&root, "root", ".", "Project root (for .core/ai/ config)")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
log.Fatalf("parse flags: %v", err)
|
||||
}
|
||||
|
||||
if prompt == "" {
|
||||
fmt.Fprintln(os.Stderr, "usage: lem score attention -model <path> -prompt <text>")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Load configs.
|
||||
aiCfg, err := LoadAIConfig(root)
|
||||
if err != nil {
|
||||
log.Fatalf("load ai config: %v", err)
|
||||
}
|
||||
|
||||
modelCfg, err := LoadModelConfig(root, modelFlag)
|
||||
if err != nil {
|
||||
log.Fatalf("load model config: %v", err)
|
||||
}
|
||||
|
||||
// Apply memory limits.
|
||||
cacheLimitGB := aiCfg.Distill.CacheLimit
|
||||
if cacheLimit > 0 {
|
||||
cacheLimitGB = cacheLimit
|
||||
}
|
||||
memLimitGB := aiCfg.Distill.MemoryLimit
|
||||
if memLimit > 0 {
|
||||
memLimitGB = memLimit
|
||||
}
|
||||
if cacheLimitGB > 0 {
|
||||
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
||||
}
|
||||
if memLimitGB > 0 {
|
||||
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
||||
}
|
||||
|
||||
// Load model.
|
||||
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
||||
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
||||
if err != nil {
|
||||
log.Fatalf("load model: %v", err)
|
||||
}
|
||||
defer backend.Close()
|
||||
|
||||
// Run attention inspection.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
snap, err := backend.InspectAttention(ctx, prompt)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
result := AnalyseAttention(snap)
|
||||
|
||||
if jsonOutput {
|
||||
data, _ := json.MarshalIndent(result, "", " ")
|
||||
fmt.Println(string(data))
|
||||
return
|
||||
}
|
||||
|
||||
// Human-readable output.
|
||||
fmt.Printf("Q/K Bone Orientation Analysis\n")
|
||||
fmt.Printf(" Architecture: %s\n", snap.Architecture)
|
||||
fmt.Printf(" Layers: %d\n", snap.NumLayers)
|
||||
fmt.Printf(" KV Heads: %d\n", snap.NumHeads)
|
||||
fmt.Printf(" Sequence Length: %d tokens\n", snap.SeqLen)
|
||||
fmt.Printf(" Head Dimension: %d\n", snap.HeadDim)
|
||||
fmt.Println()
|
||||
fmt.Printf(" Mean Coherence: %.3f\n", result.MeanCoherence)
|
||||
fmt.Printf(" Cross-Layer Align: %.3f\n", result.MeanCrossAlignment)
|
||||
fmt.Printf(" Head Entropy: %.3f\n", result.MeanHeadEntropy)
|
||||
fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
|
||||
fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
|
||||
fmt.Printf(" Composite (0-100): %.1f\n", result.Composite())
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue