feat: add mlx_lm subprocess backend and distill improvements
- Add backend_mlxlm.go blank import to register mlx-lm subprocess backend - Select backend from ai.yaml config (metal, mlx_lm, rocm, api) - Only set Metal cache/memory limits when using metal backend - Add --no-dedup flag to disable grammar-profile deduplication (trained models with consistent voice trigger false positives at 0.02) - Add --context-len flag and context_len config for KV cache sizing - Pass WithBackend and WithContextLen to go-ml backend loader Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
035985f031
commit
d2cf891f15
5 changed files with 63 additions and 28 deletions
|
|
@ -4,8 +4,8 @@ version: 1
|
|||
# Used by: lem distill, lem score, lem chat, lem expand
|
||||
|
||||
# Default inference backend.
|
||||
# Options: metal (go-mlx), rocm (go-rocm), api (OpenAI-compatible HTTP)
|
||||
backend: metal
|
||||
# Options: metal (go-mlx native CGO), mlx_lm (Python subprocess), rocm (go-rocm), api (OpenAI-compatible HTTP)
|
||||
backend: mlx_lm
|
||||
|
||||
# Scorer configuration.
|
||||
scorer:
|
||||
|
|
@ -28,5 +28,6 @@ distill:
|
|||
probes: core # Default probe set from probes.yaml
|
||||
runs: 3 # Generations per probe (best kept)
|
||||
min_chars: 20 # Reject responses shorter than this
|
||||
cache_limit: 8 # Metal cache limit in GB (0 = no limit)
|
||||
memory_limit: 16 # Metal memory limit in GB (0 = no limit)
|
||||
cache_limit: 8 # Metal cache limit in GB (0 = MLX default)
|
||||
memory_limit: 16 # Metal memory limit in GB (0 = MLX default)
|
||||
context_len: 0 # KV cache context window (0 = auto: max_tokens * 2)
|
||||
|
|
|
|||
|
|
@ -22,9 +22,11 @@ func addGenCommands(root *cli.Command) {
|
|||
cli.Float64Flag(distillCmd, &distillCfg.MinScore, "min-score", "", 0, "Min grammar composite (0 = use ai.yaml default)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.Runs, "runs", "r", 0, "Generations per probe (0 = use ai.yaml default)")
|
||||
cli.BoolFlag(distillCmd, &distillCfg.DryRun, "dry-run", "", false, "Show plan and exit without generating")
|
||||
cli.BoolFlag(distillCmd, &distillCfg.NoDedup, "no-dedup", "", false, "Disable grammar-profile deduplication")
|
||||
cli.StringFlag(distillCmd, &distillCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)")
|
||||
cli.IntFlag(distillCmd, &distillCfg.ContextLen, "context-len", "", 0, "KV cache context window (0 = auto: max_tokens * 2)")
|
||||
genGroup.AddCommand(distillCmd)
|
||||
|
||||
// expand — generate expansion responses via trained LEM model.
|
||||
|
|
|
|||
6
pkg/lem/backend_mlxlm.go
Normal file
6
pkg/lem/backend_mlxlm.go
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
package lem
|
||||
|
||||
// Blank import registers the mlx-lm subprocess backend with go-inference.
|
||||
// This spawns a Python process using mlx-lm for inference — handles memory
|
||||
// management natively via Python's refcounting (2.4 GB vs 17+ GB in CGO).
|
||||
import _ "forge.lthn.ai/core/go-mlx/mlxlm"
|
||||
|
|
@ -39,12 +39,13 @@ type GenerateConfig struct {
|
|||
|
||||
// DistillConfig holds distillation defaults loaded from ai.yaml.
type DistillConfig struct {
	Model       string `yaml:"model"`        // Default model config path (relative to .core/ai/models/)
	Probes      string `yaml:"probes"`       // Default probe set name from probes.yaml
	Runs        int    `yaml:"runs"`         // Generations per probe (best kept)
	MinChars    int    `yaml:"min_chars"`    // Reject responses shorter than this
	CacheLimit  int    `yaml:"cache_limit"`  // Metal cache limit in GB (0 = MLX default)
	MemoryLimit int    `yaml:"memory_limit"` // Metal memory limit in GB (0 = MLX default)
	ContextLen  int    `yaml:"context_len"`  // KV cache context window (0 = auto: max_tokens * 2)
}
|
||||
|
||||
// ModelConfig is a .core/ai/models/{family}/{size}.yaml file.
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import (
|
|||
"time"
|
||||
|
||||
"forge.lthn.ai/core/go-i18n/reversal"
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
"forge.lthn.ai/core/go-ml"
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
|
@ -40,9 +41,11 @@ type DistillOpts struct {
|
|||
MinScore float64 // Min grammar composite (0 = use ai.yaml default)
|
||||
Runs int // Generations per probe (0 = use ai.yaml default)
|
||||
DryRun bool // Show plan and exit without generating
|
||||
NoDedup bool // Disable grammar-profile deduplication
|
||||
Root string // Project root (for .core/ai/ config)
|
||||
CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default)
|
||||
MemLimit int // Metal memory limit in GB (0 = use ai.yaml default)
|
||||
ContextLen int // KV cache context window (0 = auto: max_tokens * 2)
|
||||
}
|
||||
|
||||
// RunDistill is the CLI entry point for the distill command.
|
||||
|
|
@ -88,6 +91,14 @@ func RunDistill(cfg DistillOpts) error {
|
|||
if cfg.MemLimit > 0 {
|
||||
memLimitGB = cfg.MemLimit
|
||||
}
|
||||
contextLen := aiCfg.Distill.ContextLen
|
||||
if cfg.ContextLen > 0 {
|
||||
contextLen = cfg.ContextLen
|
||||
}
|
||||
if contextLen == 0 {
|
||||
// Auto: 2x max_tokens covers kernel + prompt + generation.
|
||||
contextLen = genCfg.MaxTokens * 2
|
||||
}
|
||||
|
||||
// Load probes.
|
||||
probeSet := cfg.Probes
|
||||
|
|
@ -148,7 +159,7 @@ func RunDistill(cfg DistillOpts) error {
|
|||
fmt.Printf("Gate: grammar v3 composite >= %.1f\n", minScore)
|
||||
fmt.Printf("Generate: temp=%.2f max_tokens=%d top_p=%.2f\n",
|
||||
genCfg.Temperature, genCfg.MaxTokens, genCfg.TopP)
|
||||
fmt.Printf("Memory: cache=%dGB limit=%dGB\n", cacheLimitGB, memLimitGB)
|
||||
fmt.Printf("Memory: cache=%dGB limit=%dGB context_len=%d\n", cacheLimitGB, memLimitGB, contextLen)
|
||||
fmt.Printf("Output: %s\n", outputPath)
|
||||
fmt.Println()
|
||||
for i, p := range probes {
|
||||
|
|
@ -165,19 +176,29 @@ func RunDistill(cfg DistillOpts) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// Set Metal memory limits before loading model.
|
||||
if cacheLimitGB > 0 {
|
||||
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal cache limit: %dGB", cacheLimitGB)
|
||||
// Set Metal memory limits (only relevant for native Metal backend).
|
||||
backendName := aiCfg.Backend
|
||||
if backendName == "" {
|
||||
backendName = "metal"
|
||||
}
|
||||
if memLimitGB > 0 {
|
||||
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal memory limit: %dGB", memLimitGB)
|
||||
if backendName == "metal" {
|
||||
if cacheLimitGB > 0 {
|
||||
mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal cache limit: %dGB", cacheLimitGB)
|
||||
}
|
||||
if memLimitGB > 0 {
|
||||
mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024)
|
||||
log.Printf("metal memory limit: %dGB", memLimitGB)
|
||||
}
|
||||
}
|
||||
|
||||
// Load model via go-ml Backend (wraps go-inference with memory management).
|
||||
log.Printf("loading model: %s", modelCfg.Paths.Base)
|
||||
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base)
|
||||
// Load model via go-ml Backend (wraps go-inference with backend selection).
|
||||
// WithContextLen bounds the KV cache for native Metal; mlx_lm handles its own memory.
|
||||
log.Printf("loading model: %s (backend=%s, context_len=%d)", modelCfg.Paths.Base, backendName, contextLen)
|
||||
backend, err := ml.NewMLXBackend(modelCfg.Paths.Base,
|
||||
inference.WithBackend(backendName),
|
||||
inference.WithContextLen(contextLen),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load model: %w", err)
|
||||
}
|
||||
|
|
@ -300,11 +321,13 @@ func RunDistill(cfg DistillOpts) error {
|
|||
// Quality gate.
|
||||
if best != nil && best.Grammar.Composite >= minScore {
|
||||
// Duplicate filter: reject if grammar profile is too similar to an already-kept entry.
|
||||
bestFeatures := GrammarFeatures(best.Grammar)
|
||||
if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
|
||||
deduped++
|
||||
fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID)
|
||||
continue
|
||||
if !cfg.NoDedup {
|
||||
bestFeatures := GrammarFeatures(best.Grammar)
|
||||
if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) {
|
||||
deduped++
|
||||
fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Save with output prompt — sandwich if kernel, bare if LEM model.
|
||||
|
|
@ -317,12 +340,14 @@ func RunDistill(cfg DistillOpts) error {
|
|||
line, _ := json.Marshal(example)
|
||||
out.Write(append(line, '\n'))
|
||||
|
||||
// Add to dedup index.
|
||||
entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar}
|
||||
if dedupIdx == nil {
|
||||
dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry})
|
||||
} else {
|
||||
_ = dedupIdx.Insert(entry)
|
||||
// Add to dedup index (unless disabled).
|
||||
if !cfg.NoDedup {
|
||||
entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar}
|
||||
if dedupIdx == nil {
|
||||
dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry})
|
||||
} else {
|
||||
_ = dedupIdx.Insert(entry)
|
||||
}
|
||||
}
|
||||
|
||||
kept++
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue