Replace passthrough() + stdlib flag.FlagSet anti-pattern with proper cobra integration. Every Run* function now takes a typed *Opts struct and returns error. Flags registered via cli.StringFlag/IntFlag/etc. Commands participate in Core lifecycle with full cobra flag parsing. - 6 command groups: gen, score, data, export, infra, mon - 25 commands converted, 0 passthrough() calls remain - Delete passthrough() helper from lem.go - Update export_test.go to use ExportOpts struct Co-Authored-By: Virgil <virgil@lethean.io>
97 lines
2.8 KiB
Go
package lem

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"time"

	"forge.lthn.ai/core/go-ml"
	"forge.lthn.ai/core/go-mlx"
)
// AttentionOpts holds configuration for the attention command.
|
|
type AttentionOpts struct {
|
|
Model string // Model config path (relative to .core/ai/models/)
|
|
Prompt string // Prompt text to analyse
|
|
JSON bool // Output as JSON
|
|
CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default)
|
|
MemLimit int // Metal memory limit in GB (0 = use ai.yaml default)
|
|
Root string // Project root (for .core/ai/ config)
|
|
}
|
|
|
|
// RunAttention is the entry point for `lem score attention`.
|
|
func RunAttention(cfg AttentionOpts) error {
|
|
if cfg.Prompt == "" {
|
|
return fmt.Errorf("--prompt is required")
|
|
}
|
|
|
|
// Load configs.
|
|
aiCfg, err := LoadAIConfig(cfg.Root)
|
|
if err != nil {
|
|
return fmt.Errorf("load ai config: %w", err)
|
|
}
|
|
|
|
modelCfg, err := LoadModelConfig(cfg.Root, cfg.Model)
|
|
if err != nil {
|
|
return fmt.Errorf("load model config: %w", err)
|
|
}
|
|
|
|
// Apply memory limits.
|
|
cacheLimitGB := aiCfg.Distill.CacheLimit
|
|
if cfg.CacheLimit > 0 {
|
|
cacheLimitGB = cfg.CacheLimit
|
|
}
|
|
memLimitGB := aiCfg.Distill.MemoryLimit
|
|
if cfg.MemLimit > 0 {
|
|
memLimitGB = cfg.MemLimit
|
|
}
|
|
if cacheLimitGB > 0 {
|
|
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
|
}
|
|
if memLimitGB > 0 {
|
|
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
|
}
|
|
|
|
// Load model.
|
|
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
|
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
|
if err != nil {
|
|
return fmt.Errorf("load model: %w", err)
|
|
}
|
|
defer backend.Close()
|
|
|
|
// Run attention inspection.
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
|
defer cancel()
|
|
|
|
snap, err := backend.InspectAttention(ctx, cfg.Prompt)
|
|
if err != nil {
|
|
return fmt.Errorf("inspect attention: %w", err)
|
|
}
|
|
|
|
result := AnalyseAttention(snap)
|
|
|
|
if cfg.JSON {
|
|
data, _ := json.MarshalIndent(result, "", " ")
|
|
fmt.Println(string(data))
|
|
return nil
|
|
}
|
|
|
|
// Human-readable output.
|
|
fmt.Printf("Q/K Bone Orientation Analysis\n")
|
|
fmt.Printf(" Architecture: %s\n", snap.Architecture)
|
|
fmt.Printf(" Layers: %d\n", snap.NumLayers)
|
|
fmt.Printf(" KV Heads: %d\n", snap.NumHeads)
|
|
fmt.Printf(" Sequence Length: %d tokens\n", snap.SeqLen)
|
|
fmt.Printf(" Head Dimension: %d\n", snap.HeadDim)
|
|
fmt.Println()
|
|
fmt.Printf(" Mean Coherence: %.3f\n", result.MeanCoherence)
|
|
fmt.Printf(" Cross-Layer Align: %.3f\n", result.MeanCrossAlignment)
|
|
fmt.Printf(" Head Entropy: %.3f\n", result.MeanHeadEntropy)
|
|
fmt.Printf(" Phase-Lock Score: %.3f\n", result.PhaseLockScore)
|
|
fmt.Printf(" Joint Collapses: %d\n", result.JointCollapseCount)
|
|
fmt.Printf(" Composite (0-10000): %d\n", result.Composite())
|
|
return nil
|
|
}
|