1
0
Fork 0
forked from lthn/LEM
LEM/pkg/lem/cmd_attention.go
Snider 56eda1a081 refactor: migrate all 25 commands from passthrough to cobra framework
Replace passthrough() + stdlib flag.FlagSet anti-pattern with proper
cobra integration. Every Run* function now takes a typed *Opts struct
and returns error. Flags registered via cli.StringFlag/IntFlag/etc.
Commands participate in Core lifecycle with full cobra flag parsing.

- 6 command groups: gen, score, data, export, infra, mon
- 25 commands converted, 0 passthrough() calls remain
- Delete passthrough() helper from lem.go
- Update export_test.go to use ExportOpts struct

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 03:32:53 +00:00

97 lines
2.8 KiB
Go

package lem
import (
"context"
"encoding/json"
"fmt"
"log"
"time"
"forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go-mlx"
)
// AttentionOpts holds configuration for the attention command.
//
// Zero values fall back to defaults loaded from .core/ai/ config where
// noted; Prompt is the only required field (enforced by RunAttention).
type AttentionOpts struct {
	Model      string // Model config path, relative to .core/ai/models/
	Prompt     string // Prompt text to analyse (required; RunAttention errors if empty)
	JSON       bool   // Emit the analysis result as indented JSON instead of human-readable text
	CacheLimit int    // Metal cache limit in GB; 0 means use the ai.yaml default
	MemLimit   int    // Metal memory limit in GB; 0 means use the ai.yaml default
	Root       string // Project root used to locate .core/ai/ config files
}
// RunAttention is the entry point for `lem score attention`.
//
// It loads the AI and model configs from cfg.Root, applies optional Metal
// cache/memory limits (CLI flags override the ai.yaml defaults), loads the
// base model via the MLX backend, captures an attention snapshot for
// cfg.Prompt (bounded by a 60s timeout), and prints the analysis either as
// JSON or in human-readable form. Returns an error if the prompt is empty
// or any config/model/inference step fails.
func RunAttention(cfg AttentionOpts) error {
	if cfg.Prompt == "" {
		return fmt.Errorf("--prompt is required")
	}
	// Load configs.
	aiCfg, err := LoadAIConfig(cfg.Root)
	if err != nil {
		return fmt.Errorf("load ai config: %w", err)
	}
	modelCfg, err := LoadModelConfig(cfg.Root, cfg.Model)
	if err != nil {
		return fmt.Errorf("load model config: %w", err)
	}
	// Apply memory limits: CLI flags (>0) take precedence over ai.yaml defaults.
	cacheLimitGB := aiCfg.Distill.CacheLimit
	if cfg.CacheLimit > 0 {
		cacheLimitGB = cfg.CacheLimit
	}
	memLimitGB := aiCfg.Distill.MemoryLimit
	if cfg.MemLimit > 0 {
		memLimitGB = cfg.MemLimit
	}
	// A limit of 0 means "leave the backend default untouched".
	if cacheLimitGB > 0 {
		mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
	}
	if memLimitGB > 0 {
		mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
	}
	// Load model.
	log.Printf("loading model: %s", modelCfg.Paths.Base)
	backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
	if err != nil {
		return fmt.Errorf("load model: %w", err)
	}
	defer backend.Close()
	// Run attention inspection with a hard 60s cap so a wedged backend
	// cannot hang the CLI indefinitely.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()
	snap, err := backend.InspectAttention(ctx, cfg.Prompt)
	if err != nil {
		return fmt.Errorf("inspect attention: %w", err)
	}
	result := AnalyseAttention(snap)
	if cfg.JSON {
		// Fix: previously the marshal error was silently discarded with
		// `data, _ :=`; surface it so callers see a failure instead of
		// empty output.
		data, err := json.MarshalIndent(result, "", " ")
		if err != nil {
			return fmt.Errorf("marshal result: %w", err)
		}
		fmt.Println(string(data))
		return nil
	}
	// Human-readable output.
	fmt.Printf("Q/K Bone Orientation Analysis\n")
	fmt.Printf(" Architecture: %s\n", snap.Architecture)
	fmt.Printf(" Layers: %d\n", snap.NumLayers)
	// NOTE(review): label says "KV Heads" but the field is NumHeads —
	// confirm snap.NumHeads actually counts KV heads, not query heads.
	fmt.Printf(" KV Heads: %d\n", snap.NumHeads)
	fmt.Printf(" Sequence Length: %d tokens\n", snap.SeqLen)
	fmt.Printf(" Head Dimension: %d\n", snap.HeadDim)
	fmt.Println()
	fmt.Printf(" Mean Coherence: %.3f\n", result.MeanCoherence)
	fmt.Printf(" Cross-Layer Align: %.3f\n", result.MeanCrossAlignment)
	fmt.Printf(" Head Entropy: %.3f\n", result.MeanHeadEntropy)
	fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
	fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
	fmt.Printf(" Composite (0-10000): %d\n", result.Composite())
	return nil
}