diff --git a/.core/ai/ai.yaml b/.core/ai/ai.yaml index 7bc3400..59e3334 100644 --- a/.core/ai/ai.yaml +++ b/.core/ai/ai.yaml @@ -4,8 +4,8 @@ version: 1 # Used by: lem distill, lem score, lem chat, lem expand # Default inference backend. -# Options: metal (go-mlx), rocm (go-rocm), api (OpenAI-compatible HTTP) -backend: metal +# Options: metal (go-mlx native CGO), mlx_lm (Python subprocess), rocm (go-rocm), api (OpenAI-compatible HTTP) +backend: mlx_lm # Scorer configuration. scorer: @@ -28,5 +28,6 @@ distill: probes: core # Default probe set from probes.yaml runs: 3 # Generations per probe (best kept) min_chars: 20 # Reject responses shorter than this - cache_limit: 8 # Metal cache limit in GB (0 = no limit) - memory_limit: 16 # Metal memory limit in GB (0 = no limit) + cache_limit: 8 # Metal cache limit in GB (0 = MLX default) + memory_limit: 16 # Metal memory limit in GB (0 = MLX default) + context_len: 0 # KV cache context window (0 = auto: max_tokens * 2) diff --git a/cmd/lemcmd/gen.go b/cmd/lemcmd/gen.go index 7a2cd5b..d177c7d 100644 --- a/cmd/lemcmd/gen.go +++ b/cmd/lemcmd/gen.go @@ -22,9 +22,11 @@ func addGenCommands(root *cli.Command) { cli.Float64Flag(distillCmd, &distillCfg.MinScore, "min-score", "", 0, "Min grammar composite (0 = use ai.yaml default)") cli.IntFlag(distillCmd, &distillCfg.Runs, "runs", "r", 0, "Generations per probe (0 = use ai.yaml default)") cli.BoolFlag(distillCmd, &distillCfg.DryRun, "dry-run", "", false, "Show plan and exit without generating") + cli.BoolFlag(distillCmd, &distillCfg.NoDedup, "no-dedup", "", false, "Disable grammar-profile deduplication") cli.StringFlag(distillCmd, &distillCfg.Root, "root", "", ".", "Project root (for .core/ai/ config)") cli.IntFlag(distillCmd, &distillCfg.CacheLimit, "cache-limit", "", 0, "Metal cache limit in GB (0 = use ai.yaml default)") cli.IntFlag(distillCmd, &distillCfg.MemLimit, "mem-limit", "", 0, "Metal memory limit in GB (0 = use ai.yaml default)") + cli.IntFlag(distillCmd, &distillCfg.ContextLen, "context-len", "", 0, "KV cache context window (0 = auto: max_tokens * 2)") genGroup.AddCommand(distillCmd) // expand — generate expansion responses via trained LEM model. diff --git a/pkg/lem/backend_mlxlm.go b/pkg/lem/backend_mlxlm.go new file mode 100644 index 0000000..5ec3398 --- /dev/null +++ b/pkg/lem/backend_mlxlm.go @@ -0,0 +1,6 @@ +package lem + +// Blank import registers the mlx-lm subprocess backend with go-inference. +// This spawns a Python process using mlx-lm for inference — handles memory +// management natively via Python's refcounting (2.4 GB vs 17+ GB in CGO). +import _ "forge.lthn.ai/core/go-mlx/mlxlm" diff --git a/pkg/lem/config.go b/pkg/lem/config.go index e1fe4aa..6e87be0 100644 --- a/pkg/lem/config.go +++ b/pkg/lem/config.go @@ -39,12 +39,13 @@ type GenerateConfig struct { // DistillConfig holds distillation defaults. type DistillConfig struct { - Model string `yaml:"model"` // Default model config path (relative to .core/ai/models/) - Probes string `yaml:"probes"` // Default probe set name from probes.yaml + Model string `yaml:"model"` // Default model config path (relative to .core/ai/models/) + Probes string `yaml:"probes"` // Default probe set name from probes.yaml Runs int `yaml:"runs"` MinChars int `yaml:"min_chars"` CacheLimit int `yaml:"cache_limit"` // Metal cache limit in GB (0 = no limit) MemoryLimit int `yaml:"memory_limit"` // Metal memory limit in GB (0 = no limit) + ContextLen int `yaml:"context_len"` // KV cache context window (0 = auto: max_tokens * 2) } // ModelConfig is a .core/ai/models/{family}/{size}.yaml file. diff --git a/pkg/lem/distill.go b/pkg/lem/distill.go index b2b9621..2be4938 100644 --- a/pkg/lem/distill.go +++ b/pkg/lem/distill.go @@ -11,6 +11,7 @@ import ( "time" "forge.lthn.ai/core/go-i18n/reversal" + "forge.lthn.ai/core/go-inference" "forge.lthn.ai/core/go-ml" "forge.lthn.ai/core/go-mlx" ) @@ -40,9 +41,11 @@ type DistillOpts struct { MinScore float64 // Min grammar composite (0 = use ai.yaml default) Runs int // Generations per probe (0 = use ai.yaml default) DryRun bool // Show plan and exit without generating + NoDedup bool // Disable grammar-profile deduplication Root string // Project root (for .core/ai/ config) CacheLimit int // Metal cache limit in GB (0 = use ai.yaml default) MemLimit int // Metal memory limit in GB (0 = use ai.yaml default) + ContextLen int // KV cache context window (0 = auto: max_tokens * 2) } // RunDistill is the CLI entry point for the distill command. @@ -88,6 +91,14 @@ func RunDistill(cfg DistillOpts) error { if cfg.MemLimit > 0 { memLimitGB = cfg.MemLimit } + contextLen := aiCfg.Distill.ContextLen + if cfg.ContextLen > 0 { + contextLen = cfg.ContextLen + } + if contextLen == 0 { + // Auto: 2x max_tokens covers kernel + prompt + generation. + contextLen = genCfg.MaxTokens * 2 + } // Load probes. probeSet := cfg.Probes @@ -148,7 +159,7 @@ func RunDistill(cfg DistillOpts) error { fmt.Printf("Gate: grammar v3 composite >= %.1f\n", minScore) fmt.Printf("Generate: temp=%.2f max_tokens=%d top_p=%.2f\n", genCfg.Temperature, genCfg.MaxTokens, genCfg.TopP) - fmt.Printf("Memory: cache=%dGB limit=%dGB\n", cacheLimitGB, memLimitGB) + fmt.Printf("Memory: cache=%dGB limit=%dGB context_len=%d\n", cacheLimitGB, memLimitGB, contextLen) fmt.Printf("Output: %s\n", outputPath) fmt.Println() for i, p := range probes { @@ -165,19 +176,29 @@ func RunDistill(cfg DistillOpts) error { return nil } - // Set Metal memory limits before loading model. - if cacheLimitGB > 0 { - mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024) - log.Printf("metal cache limit: %dGB", cacheLimitGB) + // Set Metal memory limits (only relevant for native Metal backend). + backendName := aiCfg.Backend + if backendName == "" { + backendName = "metal" } - if memLimitGB > 0 { - mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024) - log.Printf("metal memory limit: %dGB", memLimitGB) + if backendName == "metal" { + if cacheLimitGB > 0 { + mlx.SetCacheLimit(uint64(cacheLimitGB) * 1024 * 1024 * 1024) + log.Printf("metal cache limit: %dGB", cacheLimitGB) + } + if memLimitGB > 0 { + mlx.SetMemoryLimit(uint64(memLimitGB) * 1024 * 1024 * 1024) + log.Printf("metal memory limit: %dGB", memLimitGB) + } } - // Load model via go-ml Backend (wraps go-inference with memory management). - log.Printf("loading model: %s", modelCfg.Paths.Base) - backend, err := ml.NewMLXBackend(modelCfg.Paths.Base) + // Load model via go-ml Backend (wraps go-inference with backend selection). + // WithContextLen bounds the KV cache for native Metal; mlx_lm handles its own memory. + log.Printf("loading model: %s (backend=%s, context_len=%d)", modelCfg.Paths.Base, backendName, contextLen) + backend, err := ml.NewMLXBackend(modelCfg.Paths.Base, + inference.WithBackend(backendName), + inference.WithContextLen(contextLen), + ) if err != nil { return fmt.Errorf("load model: %w", err) } @@ -300,11 +321,13 @@ func RunDistill(cfg DistillOpts) error { // Quality gate. if best != nil && best.Grammar.Composite >= minScore { // Duplicate filter: reject if grammar profile is too similar to an already-kept entry. - bestFeatures := GrammarFeatures(best.Grammar) - if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) { - deduped++ - fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID) - continue + if !cfg.NoDedup { + bestFeatures := GrammarFeatures(best.Grammar) + if dedupIdx != nil && dedupIdx.IsDuplicate(bestFeatures, 0.02) { + deduped++ + fmt.Fprintf(os.Stderr, " ~ DEDUP %s (grammar profile too similar to existing)\n", probe.ID) + continue + } } // Save with output prompt — sandwich if kernel, bare if LEM model. @@ -317,12 +340,14 @@ func RunDistill(cfg DistillOpts) error { line, _ := json.Marshal(example) out.Write(append(line, '\n')) - // Add to dedup index. - entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar} - if dedupIdx == nil { - dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry}) - } else { - _ = dedupIdx.Insert(entry) + // Add to dedup index (unless disabled). + if !cfg.NoDedup { + entry := ScoredEntry{ID: probe.ID, Domain: probe.Domain, Grammar: best.Grammar} + if dedupIdx == nil { + dedupIdx, _ = NewScoreIndex([]ScoredEntry{entry}) + } else { + _ = dedupIdx.Insert(entry) + } } kept++