Commit graph

31 commits

Author SHA1 Message Date
Claude
694e78ca34 chore: sort.Slice → slices.SortFunc
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 16:29:48 +00:00
Claude
9f6dd9d4eb chore: fmt.Errorf(static) → errors.New
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 16:28:27 +00:00
Snider
51ac442a09 fix: add deterministic GPU memory cleanup across inference paths
- defer freeCaches() in Generate and InspectAttention
- Free orphaned arrays during KVCache growth and slice updates
- Free per-token scalar intermediates in samplers and ops
- Free intermediate arrays in applyRepeatPenalty
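The deferred-cleanup pattern the first bullet describes can be sketched as follows; `kvCache` and `freeCaches` here are simplified stand-ins, not the repo's actual types:

```go
package main

import "fmt"

type kvCache struct{ freed bool }

func (c *kvCache) Free() { c.freed = true }

// freeCaches releases every non-nil cache; nil-safe so partial
// initialisation does not panic during cleanup.
func freeCaches(caches []*kvCache) {
	for _, c := range caches {
		if c != nil {
			c.Free()
		}
	}
}

func generate() []*kvCache {
	caches := []*kvCache{{}, {}}
	// defer guarantees GPU-side cache memory is released even when
	// generation returns early on error or context cancellation.
	defer freeCaches(caches)
	return caches
}

func main() {
	caches := generate()
	fmt.Println(caches[0].freed, caches[1].freed) // true true
}
```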

Found by 3-way review: Claude explorer, Codex (gpt-5.3), Gemini Ultra.
Gemini implemented the fixes.

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 05:08:02 +00:00
Snider
c2177f754a feat: implement AttentionInspector via KV cache extraction after prefill
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 00:37:29 +00:00
Snider
5004ac258a refactor: apply go fix modernizers for Go 1.26
Automated fixes: interface{} → any, range-over-int, t.Context(),
wg.Go(), strings.SplitSeq, strings.Builder, slices.Contains,
maps helpers, min/max builtins.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-22 21:00:16 +00:00
Snider
ceb966b66b feat(metal): expose model metadata via Info()
Return architecture, vocab size, layer count, hidden dimension, and
quantisation config (bits + group size) for loaded models.

Gemma3-1B 4-bit: arch=gemma3, vocab=262144, layers=26, hidden=1152,
quant=4-bit/group64.
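A hypothetical shape for the metadata Info() returns; the field names are illustrative, but the values are the ones this commit reports for Gemma3-1B:

```go
package main

import "fmt"

// ModelInfo is a sketch, not the repo's actual struct.
type ModelInfo struct {
	Architecture string
	VocabSize    int
	Layers       int
	HiddenDim    int
	QuantBits    int // 0 when unquantised
	QuantGroup   int
}

func main() {
	info := ModelInfo{
		Architecture: "gemma3",
		VocabSize:    262144,
		Layers:       26,
		HiddenDim:    1152,
		QuantBits:    4,
		QuantGroup:   64,
	}
	fmt.Printf("%s: %d layers, %d-bit/group%d\n",
		info.Architecture, info.Layers, info.QuantBits, info.QuantGroup)
}
```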

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:36:23 +00:00
Snider
a44e9f5789 feat(metal): add inference metrics (timing, throughput, memory)
Instrument Generate, Classify, and BatchGenerate with:
- Prefill/decode timing (separate phases)
- Token counts (prompt + generated)
- Throughput (tok/s for each phase)
- Peak and active GPU memory via Metal allocator

Wire through metalAdapter.Metrics() to go-inference interface.
Test validates all fields populated after generation.

Gemma3-1B 4-bit on M3 Ultra: prefill 246 tok/s, decode 82 tok/s,
peak 6.2 GB GPU memory.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:34:40 +00:00
Snider
5644857034 feat(metal): implement batch inference (Classify, BatchGenerate)
- Add ForwardMasked to InternalModel, Gemma3 and Qwen3 architectures
- Thread attention mask through decoder layers and SDPA calls
- Use ScaledDotProductAttentionWithMask when explicit mask provided
- Create batch.go with padded batching, mask construction, Classify
  (prefill-only) and BatchGenerate (autoregressive) implementations
- Wire Classify/BatchGenerate through metalAdapter to go-inference
- Tests: mask unit tests (shape, values, multi-batch), Classify with
  4 prompts (152 prompts/s), WithLogits, BatchGenerate with 2 prompts
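The padded-batching mask construction can be sketched as below: each prompt is padded to the longest in the batch, and padded positions get a large negative additive bias so attention ignores them. This is a simplification of the batch.go approach, not its exact code:

```go
package main

import "fmt"

// buildMask returns an additive attention mask per batch row:
// 0 for real tokens, a large negative value for padding.
func buildMask(lengths []int, maxLen int) [][]float32 {
	const negInf = float32(-1e9)
	mask := make([][]float32, len(lengths))
	for i, n := range lengths {
		row := make([]float32, maxLen)
		for j := n; j < maxLen; j++ {
			row[j] = negInf // mask out padding positions
		}
		mask[i] = row
	}
	return mask
}

func main() {
	// Batch of two prompts with lengths 2 and 3, padded to 3.
	mask := buildMask([]int{2, 3}, 3)
	fmt.Println(mask[0][2] < 0, mask[1][2] < 0) // true false
}
```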

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:28:15 +00:00
Snider
e3fbc221ce feat(metal): add mixed precision training via LoRAConfig.DType (Phase 3)
LoRA A/B matrices can now be created in BFloat16 or Float16 for mixed
precision training. DType field added to LoRAConfig, passed through
ApplyLoRA and NewLoRALinear. MLX auto-promotes for cross-dtype ops.
BFloat16 validated: loss 7.15→6.29, matches Float32 accuracy with
half param memory.
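A hypothetical sketch of the knob this commit adds: a DType field on the LoRA config selecting the precision of the A/B matrices. Names and the enum are illustrative, not the exact API:

```go
package main

import "fmt"

type DType int

const (
	Float32 DType = iota
	Float16
	BFloat16
)

// LoRAConfig sketch: DType controls the precision the adapter
// matrices are created in.
type LoRAConfig struct {
	Rank  int
	DType DType
}

// bytesPerParam shows why the half-precision dtypes halve adapter
// parameter memory relative to Float32.
func bytesPerParam(d DType) int {
	if d == Float32 {
		return 4
	}
	return 2 // Float16 and BFloat16 are both 2 bytes per parameter
}

func main() {
	cfg := LoRAConfig{Rank: 8, DType: BFloat16}
	fmt.Println(cfg.Rank, bytesPerParam(cfg.DType)) // 8 2
}
```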

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:13:49 +00:00
Snider
fa08ed1e2a test(metal): validate gradient checkpointing with real model (Phase 3)
Checkpoint() wraps forward pass to recompute activations during
backward, trading compute for memory. Verified with Gemma3-1B LoRA
training: produces correct gradients (loss 7.15→7.08, matches
non-checkpointed initial loss). Unit test confirms gradient
correctness on simple function (sum(x^2), grad=[2,4,6]).
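The unit-test check above (d/dx sum(x²) = 2x, so grad([1,2,3]) = [2,4,6]) can be reproduced without autograd via a central finite difference:

```go
package main

import "fmt"

func sumSquares(x []float64) float64 {
	s := 0.0
	for _, v := range x {
		s += v * v
	}
	return s
}

// numericGrad approximates the gradient with a central difference,
// (f(x+h) - f(x-h)) / 2h, per coordinate.
func numericGrad(x []float64) []float64 {
	const h = 1e-6
	g := make([]float64, len(x))
	for i := range x {
		xp := append([]float64(nil), x...)
		xm := append([]float64(nil), x...)
		xp[i] += h
		xm[i] -= h
		g[i] = (sumSquares(xp) - sumSquares(xm)) / (2 * h)
	}
	return g
}

func main() {
	g := numericGrad([]float64{1, 2, 3})
	fmt.Printf("%.0f %.0f %.0f\n", g[0], g[1], g[2]) // 2 4 6
}
```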

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:11:15 +00:00
Snider
fb0692baf3 test(metal): add LoRA end-to-end training pipeline test (Phase 3)
Validates full pipeline: load Gemma3-1B → apply LoRA (rank=8, 745K
params across 52 layers) → train 5 steps with cross-entropy loss
(7.15→6.31) → save adapter to safetensors → reload and verify all
weights match. Uses ValueAndGrad for autograd + AdamW optimiser.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:09:16 +00:00
Snider
19c4823b04 feat(metal): add Llama 3 model support (Llama 3.1 8B validated)
Llama shares the Qwen3 loader (same decoder: pre-norm, SwiGLU, GQA).
Model type now detected from config.json model_type field instead of
weight-only heuristic. Llama 3 chat template and EOS token added.
Model tests now clear Metal GPU cache between runs.

Llama 3.1 8B Instruct 4-bit: 30 tok/s on M3 Ultra.
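The config-driven detection this commit switches to can be sketched as below; the dispatch list and function name are illustrative, not the repo's model.go:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// detectArch reads model_type from config.json instead of inferring
// the architecture from which weight tensors happen to be present.
func detectArch(configJSON []byte) (string, error) {
	var cfg struct {
		ModelType string `json:"model_type"`
	}
	if err := json.Unmarshal(configJSON, &cfg); err != nil {
		return "", err
	}
	switch cfg.ModelType {
	case "llama", "qwen2", "qwen3":
		// These share one decoder family: pre-norm, SwiGLU, GQA.
		return cfg.ModelType, nil
	default:
		return "", fmt.Errorf("unsupported model_type %q", cfg.ModelType)
	}
}

func main() {
	arch, err := detectArch([]byte(`{"model_type": "llama"}`))
	fmt.Println(arch, err) // llama <nil>
}
```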

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 23:06:43 +00:00
Snider
535b04d5d6 feat(metal): add Qwen2 model support (DeepSeek R1 validated)
Qwen2 and Qwen3 share the same architecture — Qwen3 adds Q/K RMS
normalization which Qwen2 lacks. The loader auto-detects the variant
from weight presence and reports the correct ModelType().

- Add "qwen2" to architecture dispatch in model.go
- Make Q/K norm optional in attention forward (nil-safe check)
- Store detected model type on Qwen3Model struct
- Add "qwen2" to chat template routing
- DeepSeek R1 7B (4-bit): 27 tok/s on M3 Ultra
- 2 new tests: inference + chat
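The nil-safe optional-norm pattern from the second bullet, with simplified stand-in types:

```go
package main

import "fmt"

type rmsNorm struct{ scale float32 }

func (n *rmsNorm) apply(x float32) float32 { return x * n.scale }

// maybeNorm applies the Q/K norm only when the model has one:
// Qwen3 sets it, Qwen2 leaves the field nil and the input passes
// through unchanged.
func maybeNorm(n *rmsNorm, x float32) float32 {
	if n == nil {
		return x // Qwen2 path
	}
	return n.apply(x) // Qwen3 path
}

func main() {
	fmt.Println(maybeNorm(nil, 2), maybeNorm(&rmsNorm{scale: 0.5}, 2)) // 2 1
}
```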

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:55:56 +00:00
Snider
a2493e0242 test(metal): add model loading robustness tests (Phase 2)
24 new tests covering error paths in model loading:
- Missing/invalid config.json, unsupported architecture
- Missing tokenizer.json for both Gemma3 and Qwen3
- Missing safetensors: was a nil-pointer panic in precomputeScaledWeights,
  fixed with early error return in both LoadGemma3 and LoadQwen3
- Config parsing: defaults, quantization, nested text_config
- isLayerSliding sliding window pattern logic
- resolveWeight with language_model. prefix fallback

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:49:07 +00:00
Snider
18e8dca9f8 feat(metal): validate Gemma3-1B inference end-to-end (Phase 2)
- Fix model_type "gemma3_text" not matched in architecture dispatch
- Fix GPT-2 BPE false detection on large SentencePiece vocabs (Gemma3
  262K vocab contains Ġ but uses ▁ for spaces — check "Ġthe" not bare "Ġ")
- Add TestGemma3_1B_Inference: greedy decode, 46 tok/s, coherent output
- Add TestGemma3_1B_Chat: validates chat template formatting
- Add TestGemma3_1B_ContextCancel: validates ctx.Done() stops generation

4-bit quantised Gemma3-1B loads in ~700ms, generates at 46 tok/s on M3 Ultra.
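The fixed detection heuristic, sketched in isolation: a large SentencePiece vocab can contain "Ġ" incidentally, so probing for a common byte-level merge like "Ġthe" (GPT-2 BPE marks leading spaces with Ġ) is the stronger signal:

```go
package main

import "fmt"

// looksLikeGPT2BPE checks for a frequent byte-level token rather
// than the bare Ġ character, avoiding false positives on large
// SentencePiece vocabularies that use ▁ for spaces.
func looksLikeGPT2BPE(vocab map[string]int) bool {
	_, ok := vocab["\u0120the"] // "Ġthe"
	return ok
}

func main() {
	gpt2 := map[string]int{"\u0120the": 262}
	// SentencePiece-style vocab with a stray Ġ entry, like Gemma3's.
	spm := map[string]int{"\u2581the": 506, "\u0120": 9817}
	fmt.Println(looksLikeGPT2BPE(gpt2), looksLikeGPT2BPE(spm)) // true false
}
```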

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:44:28 +00:00
Snider
443347a2f8 fix(metal): address 4 minor code review items
- Rename New() → newArray() to signal internal-only intent (112 usages)
- Remove unused Collect() function and its test
- Fix discarded json.Unmarshal error in qwen3.go
- Document AsStrided stride formula in gemma3.go

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:36:40 +00:00
Snider
fb95cde30c fix(metal): address 5 important code review items
1. RepeatPenalty: implemented applyRepeatPenalty() — tracks generated
   token IDs, deduplicates, divides positive logits by penalty and
   multiplies negative logits by penalty. 2 new tests.

2. DefaultGPUStream/DefaultCPUStream: now cached with sync.Once,
   no more C stream allocation on every call.

3. CompileShapeless: removed dead C closure, callback, sync.Map,
   and nextID infrastructure. CompiledFunc is now a plain function
   wrapper with mutex. API unchanged.

4. Tokenizer BPE: implemented bpeMerge() — standard BPE algorithm
   using merge rank lookup. Both SentencePiece and GPT-2 Encode paths
   now apply merges instead of falling back to character-level lookup.
   3 new tests.

5. KV cache lifecycle: documented in Generate() godoc — fresh caches
   per call, ClearCache() between turns for prompt Metal reclaim.
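The repeat-penalty rule item 1 describes can be sketched as below; this is a plain-Go illustration of the algorithm, not the repo's array-based implementation:

```go
package main

import "fmt"

// applyRepeatPenalty deduplicates the generated token IDs, then
// divides positive logits by the penalty and multiplies negative
// logits by it; both directions push repeated tokens toward
// "less likely".
func applyRepeatPenalty(logits []float32, generated []int, penalty float32) {
	seen := make(map[int]struct{}, len(generated))
	for _, id := range generated {
		seen[id] = struct{}{} // deduplicate
	}
	for id := range seen {
		if logits[id] > 0 {
			logits[id] /= penalty
		} else {
			logits[id] *= penalty
		}
	}
}

func main() {
	logits := []float32{2.0, -1.0, 3.0}
	applyRepeatPenalty(logits, []int{0, 1, 0}, 2.0)
	fmt.Println(logits) // [1 -2 3]
}
```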

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:31:45 +00:00
Snider
c96f9bd006 fix(metal): address 3 critical code review items
1. Error handler thread safety: last_mlx_error now uses _Atomic(const char*)
   with atomic_store_explicit/atomic_exchange_explicit (release/acquire).

2. macOS version minimum: -mmacosx-version-min changed from 26.0 to 13.3
   (MLX's own minimum), no longer locks out macOS 14/15 users.

3. LoadOption applied in metalBackend.LoadModel(): calls ApplyLoadOpts(),
   passes ContextLen through to Model which replaces unbounded KVCache
   with RotatingKVCache when set. GPULayers=0 logs a warning.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:24:10 +00:00
Snider
f13a8c9289 feat(metal): deterministic Close() and Array.Iter()
Model.Close() now walks the full model tree (Gemma3/Qwen3) and
explicitly frees all weight arrays. Handles tied output weights,
nil safety, idempotent double-close. Helpers: freeLinear,
freeEmbedding, freeRMSNorm, freeCaches, closeGemma, closeQwen3.

Array.Iter() returns iter.Seq[float32] for range-over-func iteration.
Handles non-contiguous arrays and supports early break.

192 tests passing (12 new: 8 close, 4 iter).
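A sketch of the idempotent, nil-safe Close with tied-weight handling; the types are simplified stand-ins for the real model tree:

```go
package main

import "fmt"

type array struct{ freed bool }

type model struct {
	embed, output *array // may alias when output weights are tied
	closed        bool
}

// Close is nil-safe and idempotent: a second call is a no-op, and
// a tied output weight is freed exactly once via the seen set.
func (m *model) Close() {
	if m == nil || m.closed {
		return
	}
	m.closed = true
	seen := map[*array]bool{}
	for _, a := range []*array{m.embed, m.output} {
		if a != nil && !seen[a] {
			a.freed = true
			seen[a] = true
		}
	}
}

func main() {
	w := &array{}
	m := &model{embed: w, output: w} // tied weights
	m.Close()
	m.Close() // double-close is safe
	fmt.Println(w.freed, m.closed)   // true true
}
```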

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 21:09:39 +00:00
Snider
754d6e2f93 fix(metal): error handling audit — propagate MLX errors instead of swallowing
Replace checkError() log+swallow with lastError() that returns real MLX
error messages. Add Eval/EvalAsync as error-returning variants of
Materialize. Generate loop now propagates GPU errors via model.Err().
LoadAllSafetensors returns (map, error). Model loaders check lastError()
after safetensors load. 180 tests passing.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 20:59:37 +00:00
Snider
ff01175a62 bench(metal): add 29 benchmarks baselined on M3 Ultra
MatMul (128² to 4096², token projection), Softmax, element-wise
ops, fused Metal kernels (RMSNorm, LayerNorm, RoPE, SDPA), Linear,
Embedding, reductions, and full sampler chain. CGO floor ~170μs.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 20:47:25 +00:00
Snider
ca6b16eaf2 feat(metal): bind memory diagnostics and device info
New bindings from mlx-c memory.h and metal.h:
- GetCacheMemory() — current allocator cache size
- ResetPeakMemory() — reset high-water mark
- SetWiredLimit() — control wired memory limit
- GetDeviceInfo() — GPU architecture, max buffer, memory size

All exposed at root package level via register_metal.go delegates.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 20:39:51 +00:00
Snider
f39126f6bd feat(metal): bind CumSum, implement TopP and MinP sampling
New ops: CumSum, Sort, Argsort, Greater, MaxAxis — all bound to mlx-c.

TopP (nucleus) sampling now fully implemented: sorts probabilities
descending, computes cumulative sum, masks tokens beyond the threshold,
and scatters the mask back to original positions via argsort.

MinP sampling now fully implemented: computes softmax, finds max
probability, masks tokens below min_p * max_prob.

Both were previously stubs that passed through logits unchanged.

10 new tests (CumSum variants, Sort, Argsort, Greater, MaxAxis,
TopP, MinP). 176 total tests passing.
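The TopP steps above (sort descending, cumulative sum, mask beyond the threshold, scatter back to original positions) can be shown in plain Go; the real implementation runs on MLX arrays via CumSum/Argsort rather than host slices:

```go
package main

import (
	"cmp"
	"fmt"
	"slices"
)

// topPKeep returns which token positions survive nucleus sampling:
// tokens are taken in descending probability order until the
// cumulative mass reaches p (always keeping at least one).
func topPKeep(probs []float64, p float64) []bool {
	idx := make([]int, len(probs))
	for i := range idx {
		idx[i] = i
	}
	slices.SortFunc(idx, func(a, b int) int {
		return cmp.Compare(probs[b], probs[a]) // descending
	})
	keep := make([]bool, len(probs))
	cum := 0.0
	for _, i := range idx {
		keep[i] = true // scatter back to the original position
		cum += probs[i]
		if cum >= p {
			break
		}
	}
	return keep
}

func main() {
	fmt.Println(topPKeep([]float64{0.5, 0.3, 0.15, 0.05}, 0.8))
	// [true true false false]
}
```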

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 20:39:44 +00:00
Snider
df0b300b1a fix(metal): auto-contiguous data access for non-contiguous arrays
Bind mlx_contiguous and _mlx_array_is_row_contiguous from mlx-c.
Floats(), DataInt32(), and Ints() now automatically handle non-contiguous
arrays (from Transpose, BroadcastTo, SliceAxis, etc.) by checking
IsRowContiguous() and making a contiguous copy when needed.

Previously these methods returned silently wrong data for view arrays.
The old workaround of Reshape(arr, totalSize) is no longer needed.

7 new tests for contiguous handling (transpose, broadcast, slice views).
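Why the check matters can be modelled with plain stride arithmetic: a transposed view shares the parent buffer with swapped strides, so a row-contiguous read would return wrong data. This models the row-contiguity test in Go; the real check is MLX's `IsRowContiguous`:

```go
package main

import "fmt"

// isRowContiguous reports whether strides match a dense row-major
// layout for the given shape (innermost stride 1, each outer stride
// the product of the inner dimensions).
func isRowContiguous(shape, strides []int) bool {
	expect := 1
	for i := len(shape) - 1; i >= 0; i-- {
		if strides[i] != expect {
			return false
		}
		expect *= shape[i]
	}
	return true
}

func main() {
	fmt.Println(isRowContiguous([]int{2, 3}, []int{3, 1})) // dense 2x3: true
	fmt.Println(isRowContiguous([]int{3, 2}, []int{1, 3})) // its transpose view: false
}
```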

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 20:39:36 +00:00
Snider
bff97ccf19 feat(api): migrate to go-inference shared interfaces
Replace local TextModel, Backend, Token, Message, and option types with
forge.lthn.ai/core/go-inference. go-mlx is now a pure backend that
registers "metal" into the shared inference registry via init().

Deleted: textmodel.go, options.go, backend.go
Updated: register_metal.go (implements inference.Backend with Available()),
  mlx_test.go (uses inference.* types, 4 new tests), go.mod,
  internal/metal/generate.go (added RepeatPenalty)

159 tests passing (148 internal/metal + 11 root).
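The init()-time registration this migration relies on can be sketched as below; the interface and registry here are stand-ins, not go-inference's actual API:

```go
package main

import "fmt"

type Backend interface{ Available() bool }

var registry = map[string]Backend{}

func Register(name string, b Backend) { registry[name] = b }

type metalBackend struct{}

func (metalBackend) Available() bool { return true }

// Importing the backend package (even blank-imported) runs init(),
// which makes "metal" resolvable by name from the shared registry.
func init() { Register("metal", metalBackend{}) }

func main() {
	b, ok := registry["metal"]
	fmt.Println(ok, b.Available()) // true true
}
```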

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 20:15:42 +00:00
Snider
4d1bff3d78 refactor(api): clean root package — interfaces only, metal auto-registered
Root package now contains only:
- mlx.go: package doc + go:generate directives
- textmodel.go: TextModel, Token, Message interfaces
- options.go: GenerateOption, LoadOption functional options
- backend.go: Backend interface, Register/Get/Default/LoadModel
- register_metal.go: build-tagged init() + adapter + memory delegates
- mlx_stub.go: non-darwin fallback

internal/metal/ has its own Token, GenerateConfig, Model types.
register_metal.go adapts between the two via metalAdapter.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 20:04:19 +00:00
Snider
c612c3e060 refactor(metal): move all tests to internal/metal (148 tests passing)
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 20:00:02 +00:00
Snider
08976aa504 refactor(metal): flatten model, tokenizer, sample, cache into internal/metal
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 19:51:14 +00:00
Snider
a669d1d9c1 refactor(metal): move nn, io, grad, lora, optim to internal/metal
Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 19:39:58 +00:00
Snider
d6a49544bd refactor(metal): move ops, slice, random, fast, compile to internal/metal
Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 19:39:49 +00:00
Snider
1cf5178c80 refactor(metal): move dtype, array, metal, stream to internal/metal
Move foundation CGO files from root package to internal/metal/ package.
Changes package declaration from `package mlx` to `package metal`.
Updates CGO SRCDIR paths to account for new location (two levels deeper).
Extracts go:generate directives into root generate.go.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 19:34:38 +00:00