diff --git a/cmd/lemcmd/score.go b/cmd/lemcmd/score.go
index 509110d..cf25317 100644
--- a/cmd/lemcmd/score.go
+++ b/cmd/lemcmd/score.go
@@ -27,6 +27,7 @@ func addScoreCommands(root *cli.Command) {
 	cli.StringFlag(compareCmd, &compareNew, "new", "", "", "New score file (required)")
 	scoreGroup.AddCommand(compareCmd)
 
+	scoreGroup.AddCommand(passthrough("attention", "Q/K Bone Orientation analysis for a prompt", lem.RunAttention))
 	scoreGroup.AddCommand(passthrough("tier", "Score expansion responses (heuristic/judge tiers)", lem.RunTierScore))
 	scoreGroup.AddCommand(passthrough("agent", "ROCm scoring daemon (polls M3, scores checkpoints)", lem.RunAgent))
 
diff --git a/pkg/lem/cmd_attention.go b/pkg/lem/cmd_attention.go
new file mode 100644
index 0000000..3f4ec5d
--- /dev/null
+++ b/pkg/lem/cmd_attention.go
@@ -0,0 +1,127 @@
+package lem
+
+import (
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"log"
+	"os"
+	"time"
+
+	"forge.lthn.ai/core/go-ml"
+	"forge.lthn.ai/core/go-mlx"
+)
+
+// RunAttention is the entry point for `lem score attention`.
+// It delegates to runAttention so that deferred cleanup
+// (backend.Close, context cancel) always runs before the process exits;
+// calling os.Exit or log.Fatalf mid-function would skip those defers.
+func RunAttention(args []string) {
+	if err := runAttention(args); err != nil {
+		fmt.Fprintf(os.Stderr, "error: %v\n", err)
+		os.Exit(1)
+	}
+}
+
+// runAttention parses flags, loads the configured model, runs the
+// attention inspection for one prompt, and prints the analysis as
+// JSON or human-readable text. Returns an error instead of exiting
+// so deferred resource cleanup runs on every failure path.
+func runAttention(args []string) error {
+	fs := flag.NewFlagSet("attention", flag.ExitOnError)
+	var (
+		modelFlag  string
+		prompt     string
+		jsonOutput bool
+		cacheLimit int
+		memLimit   int
+		root       string
+	)
+	fs.StringVar(&modelFlag, "model", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
+	fs.StringVar(&prompt, "prompt", "", "Prompt text to analyse")
+	fs.BoolVar(&jsonOutput, "json", false, "Output as JSON")
+	fs.IntVar(&cacheLimit, "cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
+	fs.IntVar(&memLimit, "mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
+	fs.StringVar(&root, "root", ".", "Project root (for .core/ai/ config)")
+
+	// flag.ExitOnError makes Parse exit on failure itself; this check is defensive.
+	if err := fs.Parse(args); err != nil {
+		return fmt.Errorf("parse flags: %w", err)
+	}
+
+	if prompt == "" {
+		fmt.Fprintln(os.Stderr, "usage: lem score attention -model <path> -prompt <text>")
+		os.Exit(1)
+	}
+
+	// Load configs from .core/ai/ under the project root.
+	aiCfg, err := LoadAIConfig(root)
+	if err != nil {
+		return fmt.Errorf("load ai config: %w", err)
+	}
+	modelCfg, err := LoadModelConfig(root, modelFlag)
+	if err != nil {
+		return fmt.Errorf("load model config: %w", err)
+	}
+
+	// Apply memory limits: explicit flags override the ai.yaml defaults.
+	cacheLimitGB := aiCfg.Distill.CacheLimit
+	if cacheLimit > 0 {
+		cacheLimitGB = cacheLimit
+	}
+	memLimitGB := aiCfg.Distill.MemoryLimit
+	if memLimit > 0 {
+		memLimitGB = memLimit
+	}
+	if cacheLimitGB > 0 {
+		mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
+	}
+	if memLimitGB > 0 {
+		mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
+	}
+
+	// Load model.
+	log.Printf("loading model: %s", modelCfg.Paths.Base)
+	backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
+	if err != nil {
+		return fmt.Errorf("load model: %w", err)
+	}
+	defer backend.Close()
+
+	// Run attention inspection with a hard timeout.
+	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
+	defer cancel()
+
+	snap, err := backend.InspectAttention(ctx, prompt)
+	if err != nil {
+		return err
+	}
+
+	result := AnalyseAttention(snap)
+
+	if jsonOutput {
+		data, err := json.MarshalIndent(result, "", " ")
+		if err != nil {
+			return fmt.Errorf("marshal result: %w", err)
+		}
+		fmt.Println(string(data))
+		return nil
+	}
+
+	// Human-readable output.
+	fmt.Printf("Q/K Bone Orientation Analysis\n")
+	fmt.Printf(" Architecture: %s\n", snap.Architecture)
+	fmt.Printf(" Layers: %d\n", snap.NumLayers)
+	// NOTE(review): label reads "KV Heads" but the value printed is
+	// snap.NumHeads — confirm the field counts KV heads, not query heads.
+	fmt.Printf(" KV Heads: %d\n", snap.NumHeads)
+	fmt.Printf(" Sequence Length: %d tokens\n", snap.SeqLen)
+	fmt.Printf(" Head Dimension: %d\n", snap.HeadDim)
+	fmt.Println()
+	fmt.Printf(" Mean Coherence: %.3f\n", result.MeanCoherence)
+	fmt.Printf(" Cross-Layer Align: %.3f\n", result.MeanCrossAlignment)
+	fmt.Printf(" Head Entropy: %.3f\n", result.MeanHeadEntropy)
+	fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
+	fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
+	fmt.Printf(" Composite (0-100): %.1f\n", result.Composite())
+	return nil
+}