1
0
Fork 0
forked from lthn/LEM

feat: add mlx_lm subprocess backend and distill improvements

- Add backend_mlxlm.go blank import to register mlx-lm subprocess backend
- Select backend from ai.yaml config (metal, mlx_lm, rocm, api)
- Only set Metal cache/memory limits when using metal backend
- Add --no-dedup flag to disable grammar-profile deduplication
  (trained models with consistent voice trigger false positives at 0.02)
- Add --context-len flag and context_len config for KV cache sizing
- Pass WithBackend and WithContextLen to go-ml backend loader

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 18:37:12 +00:00
parent 035985f031
commit d2cf891f15
5 changed files with 63 additions and 28 deletions

View file

@@ -4,8 +4,8 @@ version: 1
# Used by: lem distill, lem score, lem chat, lem expand
# Default inference backend.
# Options: metal (go-mlx), rocm (go-rocm), api (OpenAI-compatible HTTP)
backend: metal
# Options: metal (go-mlx native CGO), mlx_lm (Python subprocess), rocm (go-rocm), api (OpenAI-compatible HTTP)
backend: mlx_lm
# Scorer configuration.
scorer:
@@ -28,5 +28,6 @@ distill:
probes: core # Default probe set from probes.yaml
runs: 3 # Generations per probe (best kept)
min_chars: 20 # Reject responses shorter than this
cache_limit: 8 # Metal cache limit in GB (0 = no limit)
memory_limit: 16 # Metal memory limit in GB (0 = no limit)
cache_limit: 8 # Metal cache limit in GB (0 = MLX default)
memory_limit: 16 # Metal memory limit in GB (0 = MLX default)
context_len: 0 # KV cache context window (0 = auto: max_tokens * 2)

View file

@@ -22,9 +22,11 @@ func addGenCommands(root *cli.Command) {
cli.Float64Flag(distillCmd, &distillCfg.MinScore, "min-score", "", 0, "Min grammar composite (0 = use ai.yaml default)")
cli.IntFlag(distillCmd, &distillCfg.Runs, "runs", "r", 0, "Generations per probe (0 = use ai.yaml default)")
cli.BoolFlag(distillCmd, &distillCfg.DryRun, "dry-run", "", false, "Show plan and exit without generating")
cli.BoolFlag(distillCmd, &distillCfg.NoDedup, "no-dedup", "", false, "Disable grammar-profile deduplication")
cli.StringFlag(distillCmd, &distillCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)")
cli.IntFlag(distillCmd, &distillCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
cli.IntFlag(distillCmd, &distillCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
cli.IntFlag(distillCmd, &distillCfg.ContextLen, "context-len", "", 0, "KV cache context window (0 = auto: max_tokens * 2)")
genGroup.AddCommand(distillCmd)
// expand — generate expansion responses via trained LEM model.

6
pkg/lem/backend_mlxlm.go Normal file
View file

@@ -0,0 +1,6 @@
package lem

// Blank import registers the mlx-lm subprocess backend with go-inference.
// The import is for its package init side effects only; no identifiers from
// the package are referenced directly.
// This spawns a Python process using mlx-lm for inference — handles memory
// management natively via Python's refcounting (2.4 GB vs 17+ GB in CGO).
import _ "forge.lthn.ai/core/go-mlx/mlxlm"

View file

@@ -39,12 +39,13 @@ type GenerateConfig struct {
// DistillConfig holds distillation defaults.
//
// Fields mirror the distill section of ai.yaml; zero values mean
// "use the built-in default" where the comment says so.
type DistillConfig struct {
	Model       string `yaml:"model"`        // Default model config path (relative to .core/ai/models/)
	Probes      string `yaml:"probes"`       // Default probe set name from probes.yaml
	Runs        int    `yaml:"runs"`         // Generations per probe (best kept)
	MinChars    int    `yaml:"min_chars"`    // Reject responses shorter than this
	CacheLimit  int    `yaml:"cache_limit"`  // Metal cache limit in GB (0 = no limit)
	MemoryLimit int    `yaml:"memory_limit"` // Metal memory limit in GB (0 = no limit)
	ContextLen  int    `yaml:"context_len"`  // KV cache context window (0 = auto: max_tokens * 2)
}
// ModelConfig is a .core/ai/models/{family}/{size}.yaml file.

View file

@@ -11,6 +11,7 @@ import (
"time"
"forge.lthn.ai/core/go-i18n/reversal"
"forge.lthn.ai/core/go-inference"
"forge.lthn.ai/core/go-ml"
"forge.lthn.ai/core/go-mlx"
)
@@ -40,9 +41,11 @@ type DistillOpts struct {
MinScore float64 // Min grammar composite (0 = use ai.yaml default)
Runs int // Generations per probe (0 = use ai.yaml default)
DryRun bool // Show plan and exit without generating
NoDedup bool // Disable grammar-profile deduplication
Root string // Project root (for .core/ai/ config)
CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default)
MemLimit int // Metal memory limit in GB (0 = use ai.yaml default)
ContextLen int // KV cache context window (0 = auto: max_tokens * 2)
}
// RunDistill is the CLI entry point for the distill command.
@@ -88,6 +91,14 @@ func RunDistill(cfg DistillOpts) error {
if cfg.MemLimit > 0 {
memLimitGB = cfg.MemLimit
}
contextLen := aiCfg.Distill.ContextLen
if cfg.ContextLen > 0 {
contextLen = cfg.ContextLen
}
if contextLen == 0 {
// Auto: 2x max_tokens covers kernel + prompt + generation.
contextLen = genCfg.MaxTokens * 2
}
// Load probes.
probeSet := cfg.Probes
@@ -148,7 +159,7 @@ func RunDistill(cfg DistillOpts) error {
fmt.Printf("Gate: grammar v3 composite >= %.1f\n", minScore)
fmt.Printf("Generate: temp=%.2f max_tokens=%d top_p=%.2f\n",
genCfg.Temperature, genCfg.MaxTokens, genCfg.TopP)
fmt.Printf("Memory: cache=%dGB limit=%dGB\n", cacheLimitGB, memLimitGB)
fmt.Printf("Memory: cache=%dGB limit=%dGB context_len=%d\n", cacheLimitGB, memLimitGB, contextLen)
fmt.Printf("Output: %s\n", outputPath)
fmt.Println()
for i, p := range probes {
@@ -165,19 +176,29 @@ func RunDistill(cfg DistillOpts) error {
return nil
}
// Set Metal memory limits before loading model.
if cacheLimitGB > 0 {
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
log.Printf("metal cache limit: %dGB", cacheLimitGB)
// Set Metal memory limits (only relevant for native Metal backend).
backendName := aiCfg.Backend
if backendName == "" {
backendName = "metal"
}
if memLimitGB > 0 {
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
log.Printf("metal memory limit: %dGB", memLimitGB)
if backendName == "metal" {
if cacheLimitGB > 0 {
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
log.Printf("metal cache limit: %dGB", cacheLimitGB)
}
if memLimitGB > 0 {
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
log.Printf("metal memory limit: %dGB", memLimitGB)
}
}
// Load model via go-ml Backend (wraps go-inference with memory management).
log.Printf("loading model: %s", modelCfg.Paths.Base)
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
// Load model via go-ml Backend (wraps go-inference with backend selection).
// WithContextLen bounds the KV cache for native Metal; mlx_lm handles its own memory.
log.Printf("loading model: %s (backend=%s, context_len=%d)", modelCfg.Paths.Base, backendName, contextLen)
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base,
inference.WithBackend(backendName),
inference.WithContextLen(contextLen),
)
if err != nil {
return fmt.Errorf("load model: %w", err)
}
@@ -300,11 +321,13 @@ func RunDistill(cfg DistillOpts) error {
// Quality gate.
if best != nil && best.Grammar.Composite >= minScore {
// Duplicate filter: reject if grammar profile is too similar to an already-kept entry.
bestFeatures := GrammarFeatures(best.Grammar)
if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
deduped++
fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID)
continue
if !cfg.NoDedup {
bestFeatures := GrammarFeatures(best.Grammar)
if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
deduped++
fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID)
continue
}
}
// Save with output prompt — sandwich if kernel, bare if LEM model.
@@ -317,12 +340,14 @@ func RunDistill(cfg DistillOpts) error {
line, _ := json.Marshal(example)
out.Write(append(line, '\n'))
// Add to dedup index.
entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar}
if dedupIdx == nil {
dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry})
} else {
_ = dedupIdx.Insert(entry)
// Add to dedup index (unless disabled).
if !cfg.NoDedup {
entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar}
if dedupIdx == nil {
dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry})
} else {
_ = dedupIdx.Insert(entry)
}
}
kept++