110 lines
3 KiB
Go
110 lines
3 KiB
Go
package lem
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"time"
|
|
|
|
"forge.lthn.ai/core/go-ml"
|
|
"forge.lthn.ai/core/go-mlx"
|
|
)
|
|
|
|
// RunAttention is the entry point for `lem score attention`.
|
|
func RunAttention(args []string) {
|
|
fs := flag.NewFlagSet("attention", flag.ExitOnError)
|
|
var (
|
|
modelFlag string
|
|
prompt string
|
|
jsonOutput bool
|
|
cacheLimit int
|
|
memLimit int
|
|
root string
|
|
)
|
|
fs.StringVar(&modelFlag, "model", "gemma3/1b", "Model config path (relative to .core/ai/models/)")
|
|
fs.StringVar(&prompt, "prompt", "", "Prompt text to analyse")
|
|
fs.BoolVar(&jsonOutput, "json", false, "Output as JSON")
|
|
fs.IntVar(&cacheLimit, "cache-limit", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
|
fs.IntVar(&memLimit, "mem-limit", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
|
fs.StringVar(&root, "root", ".", "Project root (for .core/ai/ config)")
|
|
|
|
if err := fs.Parse(args); err != nil {
|
|
log.Fatalf("parse flags: %v", err)
|
|
}
|
|
|
|
if prompt == "" {
|
|
fmt.Fprintln(os.Stderr, "usage: lem score attention -model <path> -prompt <text>")
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Load configs.
|
|
aiCfg, err := LoadAIConfig(root)
|
|
if err != nil {
|
|
log.Fatalf("load ai config: %v", err)
|
|
}
|
|
|
|
modelCfg, err := LoadModelConfig(root, modelFlag)
|
|
if err != nil {
|
|
log.Fatalf("load model config: %v", err)
|
|
}
|
|
|
|
// Apply memory limits.
|
|
cacheLimitGB := aiCfg.Distill.CacheLimit
|
|
if cacheLimit > 0 {
|
|
cacheLimitGB = cacheLimit
|
|
}
|
|
memLimitGB := aiCfg.Distill.MemoryLimit
|
|
if memLimit > 0 {
|
|
memLimitGB = memLimit
|
|
}
|
|
if cacheLimitGB > 0 {
|
|
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
|
}
|
|
if memLimitGB > 0 {
|
|
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
|
}
|
|
|
|
// Load model.
|
|
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
|
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
|
if err != nil {
|
|
log.Fatalf("load model: %v", err)
|
|
}
|
|
defer backend.Close()
|
|
|
|
// Run attention inspection.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
|
|
snap, err := backend.InspectAttention(ctx, prompt)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
result := AnalyseAttention(snap)
|
|
|
|
if jsonOutput {
|
|
data, _ := json.MarshalIndent(result, "", " ")
|
|
fmt.Println(string(data))
|
|
return
|
|
}
|
|
|
|
// Human-readable output.
|
|
fmt.Printf("Q/K Bone Orientation Analysis\n")
|
|
fmt.Printf(" Architecture: %s\n", snap.Architecture)
|
|
fmt.Printf(" Layers: %d\n", snap.NumLayers)
|
|
fmt.Printf(" KV Heads: %d\n", snap.NumHeads)
|
|
fmt.Printf(" Sequence Length: %d tokens\n", snap.SeqLen)
|
|
fmt.Printf(" Head Dimension: %d\n", snap.HeadDim)
|
|
fmt.Println()
|
|
fmt.Printf(" Mean Coherence: %.3f\n", result.MeanCoherence)
|
|
fmt.Printf(" Cross-Layer Align: %.3f\n", result.MeanCrossAlignment)
|
|
fmt.Printf(" Head Entropy: %.3f\n", result.MeanHeadEntropy)
|
|
fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
|
|
fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
|
|
fmt.Printf(" Composite (0-100): %.1f\n", result.Composite())
|
|
}
|