From cae7ef05e86969902a7fa4f20906486e7dd5f5dc Mon Sep 17 00:00:00 2001 From: Snider Date: Thu, 19 Feb 2026 17:57:37 +0000 Subject: [PATCH] feat: extract go-mlx from go-ai as standalone Metal inference package MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split mlx/ directory from forge.lthn.ai/core/go-ai into its own module. Rewrites import paths, adds CLAUDE.md/TODO.md/FINDINGS.md for dedicated Claude sessions. Zero external Go deps — pure CGO + mlx-c v0.4.1. Co-Authored-By: Virgil --- .gitignore | 16 ++ CLAUDE.md | 131 ++++++++++++ CMakeLists.txt | 28 +++ FINDINGS.md | 116 +++++++++++ TODO.md | 44 ++++ array.go | 261 ++++++++++++++++++++++++ cache/cache.go | 201 ++++++++++++++++++ compile.go | 90 +++++++++ dtype.go | 83 ++++++++ fast.go | 79 ++++++++ go.mod | 3 + grad.go | 353 ++++++++++++++++++++++++++++++++ grad_test.go | 332 ++++++++++++++++++++++++++++++ io.go | 63 ++++++ lora.go | 237 ++++++++++++++++++++++ lora_test.go | 316 +++++++++++++++++++++++++++++ mlx.go | 115 +++++++++++ mlx_stub.go | 10 + model/gemma3.go | 448 +++++++++++++++++++++++++++++++++++++++++ model/model.go | 78 +++++++ model/qwen3.go | 337 +++++++++++++++++++++++++++++++ nn.go | 115 +++++++++++ ops.go | 353 ++++++++++++++++++++++++++++++++ optim.go | 106 ++++++++++ optim_test.go | 175 ++++++++++++++++ random.go | 46 +++++ sample/sample.go | 90 +++++++++ slice.go | 63 ++++++ stream.go | 79 ++++++++ tokenizer/tokenizer.go | 324 +++++++++++++++++++++++++++++ 30 files changed, 4692 insertions(+) create mode 100644 .gitignore create mode 100644 CLAUDE.md create mode 100644 CMakeLists.txt create mode 100644 FINDINGS.md create mode 100644 TODO.md create mode 100644 array.go create mode 100644 cache/cache.go create mode 100644 compile.go create mode 100644 dtype.go create mode 100644 fast.go create mode 100644 go.mod create mode 100644 grad.go create mode 100644 grad_test.go create mode 100644 io.go create mode 100644 lora.go create mode 100644 
lora_test.go create mode 100644 mlx.go create mode 100644 mlx_stub.go create mode 100644 model/gemma3.go create mode 100644 model/model.go create mode 100644 model/qwen3.go create mode 100644 nn.go create mode 100644 ops.go create mode 100644 optim.go create mode 100644 optim_test.go create mode 100644 random.go create mode 100644 sample/sample.go create mode 100644 slice.go create mode 100644 stream.go create mode 100644 tokenizer/tokenizer.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e34251e --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +# Build artifacts +build/ +*.dylib +*.so +*.a + +# CMake +CMakeCache.txt +CMakeFiles/ +cmake_install.cmake +Makefile + +# IDE +.idea/ +.vscode/ +*.swp diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..4de07f7 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,131 @@ +# CLAUDE.md + +## What This Is + +Native Apple Metal GPU inference via mlx-c bindings. Module: `forge.lthn.ai/core/go-mlx` + +Pure Go + CGO package that wraps Apple's [MLX framework](https://github.com/ml-explore/mlx) through the [mlx-c](https://github.com/ml-explore/mlx-c) C API. Runs LLM inference on Apple Silicon GPUs (M1-M4) using Metal compute shaders. + +## Platform + +**darwin/arm64 only.** All files carry `//go:build darwin && arm64`. A stub (`mlx_stub.go`) provides `MetalAvailable() bool` returning false on other platforms. + +## Build + +```bash +# Step 1: Build mlx-c C library via CMake (fetches mlx-c v0.4.1) +cd /tmp/core-go-mlx && go generate ./... + +# Step 2: Run tests (must be on Apple Silicon) +go test ./... + +# Step 3: Use in another module +# Import "forge.lthn.ai/core/go-mlx" and "forge.lthn.ai/core/go-mlx/model" +``` + +### CGO Flags (auto-set via go:generate) + +CMake installs to `dist/` inside the package directory. 
The `#cgo` directives in `mlx.go` reference: +- `CPPFLAGS: -I${SRCDIR}/dist/include` +- `LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx` +- Frameworks: Foundation, Metal, Accelerate + +### CMake Config + +`CMakeLists.txt` fetches mlx-c v0.4.1 from GitHub. Key settings: +- `MLX_BUILD_SAFETENSORS=ON` (model loading) +- `MLX_BUILD_GGUF=OFF` (GGUF is for llama.cpp backend in go-ai/ml/) +- `BUILD_SHARED_LIBS=ON` +- macOS deployment target: 26.0 + +## Architecture + +``` +go-mlx/ +├── Core Layer (Metal GPU) +│ ├── mlx.go — Init, Materialize, MetalAvailable (CGO bridge) +│ ├── mlx_stub.go — Non-Apple fallback +│ ├── array.go — MLX array wrapper (create, reshape, data access) +│ ├── dtype.go — Data types (Float16, Float32, BFloat16, Int32, etc.) +│ ├── stream.go — Metal stream/queue management +│ └── CMakeLists.txt — Fetches mlx-c v0.4.1 +│ +├── Operations +│ ├── ops.go — Element-wise: Add, Multiply, MatMul, Softmax, etc. +│ ├── fast.go — Fused Metal kernels: RMSNorm, RoPE, ScaledDotProductAttention +│ ├── nn.go — Neural network layers: Linear, Embedding, RMSNorm +│ ├── compile.go — Compiled function closures for kernel fusion +│ ├── slice.go — Array slicing/indexing +│ └── random.go — Categorical sampling +│ +├── Training +│ ├── grad.go — VJP (Vector-Jacobian Product) gradient computation +│ ├── lora.go — LoRA adapter (rank decomposition fine-tuning) +│ └── optim.go — AdamW optimiser +│ +├── Model Support +│ ├── model/ +│ │ ├── model.go — Base model interface (LoadModel, Generate) +│ │ ├── gemma3.go — Google Gemma3 decoder +│ │ └── qwen3.go — Alibaba Qwen3 decoder +│ ├── tokenizer/ +│ │ └── tokenizer.go — BPE tokenizer (sentencepiece format) +│ ├── sample/ +│ │ └── sample.go — Sampling strategies (temperature, top-k, top-p) +│ └── cache/ +│ └── cache.go — KV cache for autoregressive inference +│ +└── I/O + └── io.go — Safetensors model loading +``` + +## Key Interfaces + +```go +// model/model.go — base interface for all models +type Model interface { + Forward(x *mlx.Array, 
cache *cache.KVCache) *mlx.Array +} + +// Top-level — GPU materialisation +mlx.Materialize(outputs...) // Sync GPU eval +mlx.MaterializeAsync(outputs...) // Async GPU eval +mlx.MetalAvailable() bool // Check Metal GPU + +// Array operations +mlx.NewArray(data, shape, dtype) +mlx.MatMul(a, b) +mlx.Softmax(x, axis) +mlx.Add(a, b), mlx.Multiply(a, b), mlx.Divide(a, b) +``` + +## Dependencies + +- **mlx-c v0.4.1** (fetched by CMake at build time) +- **Apple frameworks**: Foundation, Metal, Accelerate +- **Go stdlib only** — no external Go dependencies + +## Downstream Consumers + +- `forge.lthn.ai/core/go-ai/ml` — `backend_mlx.go` uses this for Metal inference +- `forge.lthn.ai/core/go-i18n` — Phase 2a needs Gemma3-1B inference for domain classification (blocked on this package being importable) + +## Model Format + +**Safetensors** (HuggingFace format). NOT GGUF. +- Example: `/Volumes/Data/lem/safetensors/gemma-3/` +- Models must be in safetensors format with matching tokenizer config + +## Coding Standards + +- UK English (colour, organisation, centre) +- `go test ./...` must pass before commit +- Conventional commits: `type(scope): description` +- Co-Author: `Co-Authored-By: Virgil ` +- Licence: EUPL-1.2 + +## Task Queue + +See `TODO.md` for prioritised work. +See `FINDINGS.md` for research notes. +See the [wiki](https://forge.lthn.ai/core/go-mlx/wiki) for architecture docs. 
diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..e1cf221 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,28 @@ +cmake_minimum_required(VERSION 3.24) + +project(mlx) + +set(CMAKE_OSX_DEPLOYMENT_TARGET "26.0" CACHE STRING "Minimum macOS version") + +if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) + set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE) +endif() + +set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE) +set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE) +set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) +set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) + +set(CMAKE_INSTALL_RPATH "@loader_path") + +include(FetchContent) + +set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "") + +FetchContent_Declare( + mlx-c + GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" + GIT_TAG ${MLX_C_GIT_TAG} +) + +FetchContent_MakeAvailable(mlx-c) diff --git a/FINDINGS.md b/FINDINGS.md new file mode 100644 index 0000000..e282b49 --- /dev/null +++ b/FINDINGS.md @@ -0,0 +1,116 @@ +# FINDINGS.md — go-mlx Research & Discovery + +Record findings, gaps, and architectural decisions here as work progresses. + +--- + +## 2026-02-19: Split from go-ai (Virgil) + +### Origin + +This package was extracted from `forge.lthn.ai/core/go-ai/mlx/`. The split was motivated by: + +1. **Platform isolation** — mlx is darwin/arm64 only with CGO + CMake build. Keeping it in go-ai forces the entire AI package to deal with platform-specific build complexity. +2. **Dependency chain** — go-i18n Phase 2a needs MLX inference for Gemma3-1B domain classification. A standalone go-mlx module can be imported directly without pulling in all of go-ai (DuckDB, Parquet, gRPC, Ollama, etc.). +3. **Build tag simplicity** — Every file is `//go:build darwin && arm64`. As a standalone module, this is clean. Inside go-ai, it was a special case that required careful handling. 
+ +### What Was Extracted + +| Directory | Files | LOC | Purpose | +|-----------|-------|-----|---------| +| Root (`mlx/`) | 16 | ~2,500 | Core MLX bindings, ops, training | +| `model/` | 3 | ~800 | Gemma3, Qwen3 model implementations | +| `tokenizer/` | 1 | ~324 | BPE tokenizer | +| `sample/` | 1 | ~150 | Sampling strategies | +| `cache/` | 1 | ~201 | KV cache for inference | +| **Total** | **22** | **~4,354** | | + +### Import Path Changes + +All internal imports rewritten: +- `forge.lthn.ai/core/go-ai/mlx` → `forge.lthn.ai/core/go-mlx` +- `forge.lthn.ai/core/go-ai/mlx/cache` → `forge.lthn.ai/core/go-mlx/cache` +- `forge.lthn.ai/core/go-ai/mlx/tokenizer` → `forge.lthn.ai/core/go-mlx/tokenizer` +- `forge.lthn.ai/core/go-ai/mlx/model` → `forge.lthn.ai/core/go-mlx/model` +- `forge.lthn.ai/core/go-ai/mlx/sample` → `forge.lthn.ai/core/go-mlx/sample` + +### Upstream Consumer + +`go-ai/ml/backend_mlx.go` is the only file outside mlx/ that imports it. After split, go-ai needs either: +- A `replace` directive: `replace forge.lthn.ai/core/go-mlx => ../go-mlx` +- Or a published module version + +### What Stayed in go-ai + +- `ml/backend_mlx.go` (253 LOC) — the Backend adapter that calls go-mlx. This stays in go-ai because it implements the go-ai-specific `Backend` interface. +- `test-mlx.go` — integration test utility (go-ai root). Needs updating to import from go-mlx. +- `TEST-RESULTS.md` — comprehensive test report (stays as historical record). 
+ +--- + +## 2026-02-19: Test Coverage Assessment + +### Tested (3 test files) + +| File | Tests | Coverage | +|------|-------|---------| +| `grad_test.go` | VJP/gradient computation | Good — tests forward+backward pass | +| `lora_test.go` | LoRA adapter | Good — tests apply/merge/save | +| `optim_test.go` | AdamW optimiser | Good — tests step/state | + +### Not Tested (critical gaps) + +| File | LOC | Risk | Notes | +|------|-----|------|-------| +| `ops.go` | 353 | **High** | MatMul, Softmax, element-wise ops — core of everything | +| `array.go` | 261 | **High** | Array creation, reshape, data access — foundational | +| `nn.go` | ~150 | Medium | Linear, Embedding, RMSNorm layers | +| `fast.go` | ~100 | Medium | Fused Metal kernels (RoPE, ScaledDotProduct) | +| `model/*.go` | ~800 | **High** | No tests for Gemma3/Qwen3 forward pass | +| `tokenizer/` | 324 | **High** | No BPE encode/decode tests | +| `sample/` | ~150 | Medium | No sampling tests | +| `cache/` | 201 | Medium | No KV cache tests | +| `io.go` | ~100 | Medium | No safetensors load tests | + +### Error Handling + +The error handler in `mlx.go` stores the last error in a C static variable and logs it via `slog.Error`. This is **not propagated to Go callers**. Functions like `MatMul`, `Softmax`, etc. return `*Array` with no error — if the C operation fails, the caller gets a nil/invalid array with no indication why. + +### Memory Management + +Arrays use `runtime.SetFinalizer` for C-side deallocation. Under sustained inference (1000+ tokens), this relies on GC pressure to trigger finalizers. No explicit `Close()` or `Free()` method exists on Array — could leak under high throughput if GC doesn't keep up. 
+ +--- + +## 2026-02-19: Dependency Chain + +``` +go-i18n (Phase 2a: domain classification) + └── needs Gemma3-1B inference + └── go-mlx (this package) + └── mlx-c v0.4.1 (CMake, fetched from GitHub) + └── Apple MLX (Metal GPU compute) + +go-ai/ml/backend_mlx.go + └── imports go-mlx + └── implements go-ai Backend interface +``` + +### LEM Lab Connection + +LEM Lab (the native MLX chat UI at `localhost:8090`) also uses this code path. Currently working with Qwen3-8B streaming. The model/ directory supports both Gemma3 and Qwen3. + +--- + +## 2026-02-19: Hardware Test Results (from go-ai TEST-RESULTS.md) + +Tested on Mac Studio M3 Ultra (32-core CPU, 60-core GPU, 96GB unified memory): +- All 84 go-ai tests pass (including 3 mlx tests) +- MLX grad, lora, optim tests all pass +- Go 1.25.7, mlx-c v0.4.1 + +### Model Inventory (safetensors) + +Available on `/Volumes/Data/lem/safetensors/`: +- Gemma3-1B, Gemma3-4B, Gemma3-27B +- Qwen3-8B (used by LEM Lab) diff --git a/TODO.md b/TODO.md new file mode 100644 index 0000000..21d6483 --- /dev/null +++ b/TODO.md @@ -0,0 +1,44 @@ +# TODO.md — go-mlx Task Queue + +Dispatched from core/go orchestration. Pick up tasks in order. + +--- + +## Phase 1: Standalone Package Hardening + +- [ ] **Verify go generate → test round-trip** — Run `go generate ./...` to build mlx-c, then `go test ./...`. Confirm all 3 test files (grad, lora, optim) pass on Apple Silicon. Document any CMake version requirements. +- [ ] **Add missing tests for core operations** — `ops.go` (353 LOC), `array.go` (261 LOC), `nn.go`, `compile.go`, `fast.go` have zero tests. Priority: ops (MatMul, Softmax, Add) and array (create, reshape, data access). +- [ ] **Add missing tests for model/tokenizer/sample/cache** — `model/`, `tokenizer/`, `sample/`, `cache/` have zero test files. Priority: tokenizer (BPE round-trip), sample (temperature/top-k), cache (KV append/trim). +- [ ] **Benchmark suite** — No benchmarks exist. 
Add: MatMul (various sizes), Softmax, model.Forward (single token), tokenizer.Encode/Decode, full Generate (tokens/sec). Baseline on M3 Ultra. + +## Phase 2: Model Support + +- [ ] **Gemma3-1B inference validation** — The go-i18n Phase 2a needs 1B model inference for domain classification at ~5K sentences/sec. Validate Gemma3-1B loads and generates correctly via `model.LoadModel()` + `model.Generate()`. Report tokens/sec. +- [ ] **Model loading robustness** — Test with missing files, corrupted safetensors, wrong dtype. Currently no error handling tests for `io.go`. +- [ ] **Add Llama model support** — Only Gemma3 and Qwen3 exist. Llama architecture would cover Meta's model family (Llama 3, CodeLlama). + +## Phase 3: Training Pipeline + +- [ ] **LoRA fine-tuning end-to-end** — `lora.go` has the adapter but no integration test showing: load base model → apply LoRA → train on small dataset → save adapter → reload. Critical for LEM Lab. +- [ ] **Gradient checkpointing** — `grad.go` has VJP but large models will OOM without checkpointing. Add selective recomputation. +- [ ] **Mixed precision training** — MLX supports BFloat16/Float16. Add dtype selection for training (currently inference uses model's native dtype). + +## Phase 4: API Polish + +- [ ] **Error handling audit** — `checkError()` only logs. Should return errors to callers instead of silent logging. The C error handler stores last error but Go code doesn't propagate it. +- [ ] **Memory management audit** — Array finalizers use `runtime.SetFinalizer`. Verify no leaks under sustained inference (1000+ tokens). Check C-side deallocation. +- [ ] **Documentation** — Public API has minimal godoc. Add examples for common workflows: load model, generate text, fine-tune with LoRA. 
+ +--- + +## Upstream Dependencies + +- **go-i18n Phase 2a** is blocked on this package providing working Gemma3-1B inference +- **go-ai/ml/backend_mlx.go** imports this package — after split, go-ai needs `replace` directive or published module + +## Workflow + +1. Virgil in core/go writes tasks here after research +2. This repo's session picks up tasks in phase order +3. Mark `[x]` when done, note commit hash +4. New discoveries → add tasks, flag in FINDINGS.md diff --git a/array.go b/array.go new file mode 100644 index 0000000..0968bda --- /dev/null +++ b/array.go @@ -0,0 +1,261 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include +#include "mlx/c/mlx.h" +*/ +import "C" + +import ( + "encoding/binary" + "reflect" + "runtime" + "strings" + "unsafe" +) + +// Array wraps an mlx_array handle. +// Memory management relies on Go GC finalizers to call mlx_array_free, +// which decrements MLX-C's internal reference count. MLX-C handles all +// cross-array references internally — the Go wrapper does not track them. +type Array struct { + ctx C.mlx_array + name string // debug label +} + +// New creates a named Array and registers a GC finalizer. +// The inputs parameter is accepted for API compatibility but not stored — +// MLX-C tracks inter-array references via its own refcounting. +func New(name string, inputs ...*Array) *Array { + t := &Array{name: name} + runtime.SetFinalizer(t, finalizeArray) + return t +} + +// finalizeArray is called by Go GC to release the underlying C array handle. +func finalizeArray(t *Array) { + if t != nil && t.ctx.ctx != nil { + C.mlx_array_free(t.ctx) + t.ctx.ctx = nil + } +} + +type scalarTypes interface { + ~bool | ~int | ~float32 | ~float64 | ~complex64 +} + +// FromValue creates a scalar Array from a Go value. 
+func FromValue[T scalarTypes](t T) *Array { + Init() + tt := New("") + switch v := any(t).(type) { + case bool: + tt.ctx = C.mlx_array_new_bool(C.bool(v)) + case int: + tt.ctx = C.mlx_array_new_int(C.int(v)) + case float32: + tt.ctx = C.mlx_array_new_float32(C.float(v)) + case float64: + tt.ctx = C.mlx_array_new_float64(C.double(v)) + case complex64: + tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v))) + default: + panic("mlx: unsupported scalar type") + } + return tt +} + +type arrayTypes interface { + ~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 | + ~int8 | ~int16 | ~int32 | ~int64 | + ~float32 | ~float64 | + ~complex64 +} + +// FromValues creates an Array from a Go slice with the given shape. +func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array { + Init() + if len(shape) == 0 { + panic("mlx: shape required for non-scalar tensors") + } + + cShape := make([]C.int, len(shape)) + for i := range shape { + cShape[i] = C.int(shape[i]) + } + + var dtype DType + switch reflect.TypeOf(s).Elem().Kind() { + case reflect.Bool: + dtype = DTypeBool + case reflect.Uint8: + dtype = DTypeUint8 + case reflect.Uint16: + dtype = DTypeUint16 + case reflect.Uint32: + dtype = DTypeUint32 + case reflect.Uint64: + dtype = DTypeUint64 + case reflect.Int8: + dtype = DTypeInt8 + case reflect.Int16: + dtype = DTypeInt16 + case reflect.Int32: + dtype = DTypeInt32 + case reflect.Int64: + dtype = DTypeInt64 + case reflect.Float32: + dtype = DTypeFloat32 + case reflect.Float64: + dtype = DTypeFloat64 + case reflect.Complex64: + dtype = DTypeComplex64 + default: + panic("mlx: unsupported element type") + } + + bts := make([]byte, binary.Size(s)) + if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil { + panic(err) + } + + tt := New("") + tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype)) + return tt +} + +// Zeros creates a zero-filled Array with the given shape and dtype. 
+func Zeros(shape []int32, dtype DType) *Array { + Init() + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + tt := New("ZEROS") + C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx) + return tt +} + +// Set replaces this array's C handle with another's. +func (t *Array) Set(other *Array) { + C.mlx_array_set(&t.ctx, other.ctx) +} + +// Clone creates a new Go wrapper sharing the same C handle (increments C refcount). +func (t *Array) Clone() *Array { + tt := New(t.name) + C.mlx_array_set(&tt.ctx, t.ctx) + return tt +} + +// Valid reports whether this Array has a non-nil mlx handle. +func (t *Array) Valid() bool { + return t.ctx.ctx != nil +} + +// String returns a human-readable representation of the array. +func (t *Array) String() string { + str := C.mlx_string_new() + defer C.mlx_string_free(str) + C.mlx_array_tostring(&str, t.ctx) + return strings.TrimSpace(C.GoString(C.mlx_string_data(str))) +} + +// Shape returns the dimensions as int32 slice. +func (t *Array) Shape() []int32 { + dims := make([]int32, t.NumDims()) + for i := range dims { + dims[i] = int32(t.Dim(i)) + } + return dims +} + +// Size returns the total number of elements. +func (t Array) Size() int { return int(C.mlx_array_size(t.ctx)) } + +// NumBytes returns the total byte size. +func (t Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) } + +// NumDims returns the number of dimensions. +func (t Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) } + +// Dim returns the size of dimension i. +func (t Array) Dim(i int) int { return int(C.mlx_array_dim(t.ctx, C.int(i))) } + +// Dims returns all dimensions as int slice. +func (t Array) Dims() []int { + dims := make([]int, t.NumDims()) + for i := range dims { + dims[i] = t.Dim(i) + } + return dims +} + +// Dtype returns the array's data type. 
+func (t Array) Dtype() DType { return DType(C.mlx_array_dtype(t.ctx)) }
+
+// Int extracts a scalar value, read from C as int64 and narrowed to a Go int.
+func (t Array) Int() int {
+	var item C.int64_t
+	C.mlx_array_item_int64(&item, t.ctx)
+	return int(item)
+}
+
+// Float extracts a scalar float64 value.
+// Reads float32 directly; any other dtype goes through mlx_array_item_float64 — NOTE(review): confirm mlx-c converts f16/bf16 items here.
+func (t Array) Float() float64 {
+	switch t.Dtype() {
+	case DTypeFloat32:
+		var item C.float
+		C.mlx_array_item_float32(&item, t.ctx)
+		return float64(item)
+	default:
+		var item C.double
+		C.mlx_array_item_float64(&item, t.ctx)
+		return float64(item)
+	}
+}
+
+// Ints extracts all elements as an int slice (reads the buffer via mlx_array_data_int32 — presumes an int32 array).
+func (t Array) Ints() []int {
+	ints := make([]int, t.Size())
+	for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
+		ints[i] = int(f)
+	}
+	return ints
+}
+
+// DataInt32 extracts all elements as an int32 slice (presumes an int32 array — no dtype check is performed).
+func (t Array) DataInt32() []int32 {
+	data := make([]int32, t.Size())
+	for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) {
+		data[i] = int32(f)
+	}
+	return data
+}
+
+// Floats extracts all elements as a float32 slice (presumes a float32 array — no dtype check is performed).
+func (t Array) Floats() []float32 {
+	floats := make([]float32, t.Size())
+	for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
+		floats[i] = float32(f)
+	}
+	return floats
+}
+
+// Free explicitly releases C array handles and returns the number of bytes released. Does not cascade — MLX-C's
+// internal refcounting handles dependent arrays automatically.
+func Free(s ...*Array) int {
+	var n int
+	for _, t := range s {
+		if t != nil && t.Valid() {
+			n += t.NumBytes()
+			C.mlx_array_free(t.ctx)
+			t.ctx.ctx = nil
+			runtime.SetFinalizer(t, nil) // cancel finalizer
+		}
+	}
+	return n
+}
diff --git a/cache/cache.go b/cache/cache.go
new file mode 100644
index 0000000..28e15f5
--- /dev/null
+++ b/cache/cache.go
@@ -0,0 +1,201 @@
+//go:build darwin && arm64
+
+// Package cache provides KV cache implementations for transformer inference.
+package cache + +import "forge.lthn.ai/core/go-mlx" + +// Cache manages key-value pairs for transformer attention layers. +type Cache interface { + // Update adds new key/value tensors and returns the full cached K/V. + Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) + // Offset returns the total number of tokens processed. + Offset() int + // Len returns the number of cached tokens (may differ from Offset for rotating caches). + Len() int + // State returns the cached K/V arrays, or nil if empty. + State() []*mlx.Array + // Reset clears the cache for a new generation session. + Reset() +} + +// KVCache implements an unbounded cache that grows as needed. +// Pre-allocates in chunks of `step` tokens to reduce allocations. +type KVCache struct { + keys, values *mlx.Array + offset int + step int +} + +// NewKVCache creates a new unbounded KV cache with 256-token chunks. +func NewKVCache() *KVCache { + return &KVCache{step: 256} +} + +func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { + prev := c.offset + shape := k.Shape() + if len(shape) < 4 { + // K/V must be [B, H, L, D] — if not, pass through unchanged + if c.keys == nil { + c.keys, c.values = k, v + } + c.offset += seqLen + return c.keys, c.values + } + B, H, Dk := shape[0], shape[1], shape[3] + Dv := v.Shape()[3] + + // Grow buffer if needed. 
+ if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) { + nSteps := (c.step + seqLen - 1) / c.step + newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype()) + newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype()) + + if c.keys != nil { + if prev%c.step != 0 { + c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk}) + c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv}) + } + c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) + c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) + } else { + c.keys, c.values = newK, newV + } + } + + c.offset += seqLen + c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk}) + c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv}) + + return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}), + mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv}) +} + +func (c *KVCache) State() []*mlx.Array { + if c.keys == nil { + return nil + } + return []*mlx.Array{c.keys, c.values} +} + +func (c *KVCache) Offset() int { return c.offset } +func (c *KVCache) Len() int { return c.offset } + +func (c *KVCache) Reset() { + c.keys = nil + c.values = nil + c.offset = 0 +} + +// RotatingKVCache implements a bounded sliding window cache. +type RotatingKVCache struct { + keys, values *mlx.Array + offset int + maxSize int + step int + idx int +} + +// NewRotatingKVCache creates a cache bounded to maxSize tokens. 
+func NewRotatingKVCache(maxSize int) *RotatingKVCache { + return &RotatingKVCache{maxSize: maxSize, step: 256} +} + +func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { + if seqLen > 1 { + return c.updateConcat(k, v, seqLen) + } + return c.updateInPlace(k, v) +} + +func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) { + shape := k.Shape() + if len(shape) < 4 { + if c.keys == nil { + c.keys, c.values = k, v + } + c.offset++ + return c.keys, c.values + } + B, H, Dk := shape[0], shape[1], shape[3] + Dv := v.Shape()[3] + + if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) { + var cap int + if c.keys != nil { + cap = int(c.keys.Shape()[2]) + } + newSize := min(c.step, c.maxSize-cap) + newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype()) + newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype()) + if c.keys != nil { + c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) + c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) + } else { + c.keys, c.values = newK, newV + } + } + + if c.idx >= c.maxSize { + c.idx = 0 + } + + c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk}) + c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv}) + + c.offset++ + c.idx++ + + validLen := int32(min(c.offset, c.maxSize)) + return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}), + mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv}) +} + +func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { + shape := k.Shape() + if len(shape) < 4 { + // K/V must be [B, H, L, D] — if not, pass through unchanged + if c.keys == nil { + c.keys, c.values = k, v + } + c.offset += seqLen + return c.keys, c.values + } + B, H, Dk := shape[0], shape[1], shape[3] + Dv := v.Shape()[3] + + 
if c.keys == nil { + c.keys, c.values = k, v + } else { + c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2) + c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2) + } + c.offset += seqLen + + cap := int(c.keys.Shape()[2]) + if trim := cap - c.maxSize; trim > 0 { + c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk}) + c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv}) + } + + c.idx = int(c.keys.Shape()[2]) + return c.keys, c.values +} + +func (c *RotatingKVCache) State() []*mlx.Array { + if c.keys == nil { + return nil + } + return []*mlx.Array{c.keys, c.values} +} + +func (c *RotatingKVCache) Offset() int { return c.offset } +func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) } + +func (c *RotatingKVCache) Reset() { + c.keys = nil + c.values = nil + c.offset = 0 + c.idx = 0 +} diff --git a/compile.go b/compile.go new file mode 100644 index 0000000..8440f4b --- /dev/null +++ b/compile.go @@ -0,0 +1,90 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include "mlx/c/mlx.h" + +// Callback for compiled functions. +extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload); + +static mlx_closure new_closure(void *payload) { + return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL); +} +*/ +import "C" + +import ( + "sync" + "unsafe" +) + +// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls. 
+type CompiledFunc struct {
+	fn      func([]*Array) []*Array
+	closure C.mlx_closure // closure wrapping goCompiledFunc — not currently exercised by Call
+	mu      sync.Mutex
+}
+
+var compiledFuncs sync.Map
+
+//export goCompiledFunc
+func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
+	id := uintptr(payload)
+	fnI, ok := compiledFuncs.Load(id)
+	if !ok {
+		return 1 // unknown closure id — signal failure to mlx-c
+	}
+	fn := fnI.(func([]*Array) []*Array)
+
+	// Convert the C input vector into Go Array wrappers
+	nInputs := int(C.mlx_vector_array_size(inputs))
+	goInputs := make([]*Array, nInputs)
+	for i := 0; i < nInputs; i++ {
+		a := New("INPUT")
+		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
+		goInputs[i] = a
+	}
+
+	// Call user function
+	goOutputs := fn(goInputs)
+
+	// The output vector arrives with ctx=nullptr. Create a valid vector,
+	// fill it, then copy to the output pointer.
+	tmp := C.mlx_vector_array_new()
+	for _, out := range goOutputs {
+		C.mlx_vector_array_append_value(tmp, out.ctx)
+	}
+	C.mlx_vector_array_set(outputs, tmp)
+	C.mlx_vector_array_free(tmp)
+	return 0
+}
+
+var nextID uintptr
+var nextIDMu sync.Mutex
+
+// CompileShapeless compiles a function for efficient repeated execution.
+// NOTE(review): the shapeless flag is accepted but currently unused, and Call runs fn directly rather than through the compiled closure.
+func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
+	nextIDMu.Lock()
+	nextID++
+	id := nextID
+	nextIDMu.Unlock()
+
+	compiledFuncs.Store(id, fn) // NOTE(review): entries are never deleted — registry grows with every compile
+
+	cf := &CompiledFunc{fn: fn}
+	cf.closure = C.new_closure(unsafe.Pointer(id)) // registry key smuggled as a pointer — flagged by go vet's unsafeptr check
+	return cf
+}
+
+// Call executes the compiled function with the given inputs.
+func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
+	cf.mu.Lock()
+	defer cf.mu.Unlock()
+
+	// Fall back to direct call — compilation is an optimisation.
+	// The compiled closure can be used via mlx_compiled but the
+	// direct path is simpler and still benefits from MLX's lazy evaluation.
+ return cf.fn(inputs) +} diff --git a/dtype.go b/dtype.go new file mode 100644 index 0000000..eae583f --- /dev/null +++ b/dtype.go @@ -0,0 +1,83 @@ +//go:build darwin && arm64 + +package mlx + +// #include "mlx/c/mlx.h" +import "C" + +import "encoding/json" + +// DType represents an MLX array data type. +type DType C.mlx_dtype + +const ( + DTypeBool DType = C.MLX_BOOL + DTypeUint8 DType = C.MLX_UINT8 + DTypeUint16 DType = C.MLX_UINT16 + DTypeUint32 DType = C.MLX_UINT32 + DTypeUint64 DType = C.MLX_UINT64 + DTypeInt8 DType = C.MLX_INT8 + DTypeInt16 DType = C.MLX_INT16 + DTypeInt32 DType = C.MLX_INT32 + DTypeInt64 DType = C.MLX_INT64 + DTypeFloat16 DType = C.MLX_FLOAT16 + DTypeFloat32 DType = C.MLX_FLOAT32 + DTypeFloat64 DType = C.MLX_FLOAT64 + DTypeBFloat16 DType = C.MLX_BFLOAT16 + DTypeComplex64 DType = C.MLX_COMPLEX64 +) + +var dtypeNames = map[DType]string{ + DTypeBool: "bool", + DTypeUint8: "uint8", + DTypeUint16: "uint16", + DTypeUint32: "uint32", + DTypeUint64: "uint64", + DTypeInt8: "int8", + DTypeInt16: "int16", + DTypeInt32: "int32", + DTypeInt64: "int64", + DTypeFloat16: "float16", + DTypeFloat32: "float32", + DTypeFloat64: "float64", + DTypeBFloat16: "bfloat16", + DTypeComplex64: "complex64", +} + +func (d DType) String() string { + if s, ok := dtypeNames[d]; ok { + return s + } + return "unknown" +} + +var dtypeFromString = map[string]DType{ + "bool": DTypeBool, "BOOL": DTypeBool, + "uint8": DTypeUint8, "U8": DTypeUint8, + "uint16": DTypeUint16, "U16": DTypeUint16, + "uint32": DTypeUint32, "U32": DTypeUint32, + "uint64": DTypeUint64, "U64": DTypeUint64, + "int8": DTypeInt8, "I8": DTypeInt8, + "int16": DTypeInt16, "I16": DTypeInt16, + "int32": DTypeInt32, "I32": DTypeInt32, + "int64": DTypeInt64, "I64": DTypeInt64, + "float16": DTypeFloat16, "F16": DTypeFloat16, + "float32": DTypeFloat32, "F32": DTypeFloat32, + "float64": DTypeFloat64, "F64": DTypeFloat64, + "bfloat16": DTypeBFloat16, "BF16": DTypeBFloat16, + "complex64": DTypeComplex64, +} + +// 
UnmarshalJSON parses a DType from JSON strings like "F32", "BF16", etc. +func (d *DType) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + if dt, ok := dtypeFromString[s]; ok { + *d = dt + return nil + } + *d = DTypeFloat32 // default + return nil +} diff --git a/fast.go b/fast.go new file mode 100644 index 0000000..e1abeba --- /dev/null +++ b/fast.go @@ -0,0 +1,79 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include +#include "mlx/c/mlx.h" +*/ +import "C" + +import "unsafe" + +// RMSNorm applies Root Mean Square normalization using a fused Metal kernel. +func RMSNorm(x, weight *Array, eps float32) *Array { + out := New("FAST_RMSNORM", x) + C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx) + return out +} + +// LayerNorm applies Layer normalization using a fused Metal kernel. +func LayerNorm(x, weight, bias *Array, eps float32) *Array { + out := New("FAST_LAYERNORM", x) + C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx) + return out +} + +// RoPE applies Rotary Position Embeddings using a fused Metal kernel. +func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array { + out := New("FAST_ROPE", x) + freqs := C.mlx_array_new() + defer C.mlx_array_free(freqs) + C.mlx_fast_rope( + &out.ctx, + x.ctx, + C.int(dims), + C._Bool(traditional), + C.mlx_optional_float{ + value: C.float(base), + has_value: C._Bool(base != 0), + }, + C.float(scale), + C.int(offset), + freqs, + DefaultStream().ctx, + ) + return out +} + +// ScaledDotProductAttention computes attention using a fused Metal kernel. 
+func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array { + mode := "" + if causal { + mode = "causal" + } + cMode := C.CString(mode) + defer C.free(unsafe.Pointer(cMode)) + + maskArr := C.mlx_array_new() + defer C.mlx_array_free(maskArr) + sinksArr := C.mlx_array_new() + defer C.mlx_array_free(sinksArr) + + out := New("FAST_SDPA", query, key, value) + C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx) + return out +} + +// ScaledDotProductAttentionWithMask computes attention with an explicit mask. +func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array { + cMode := C.CString("array") + defer C.free(unsafe.Pointer(cMode)) + + sinksArr := C.mlx_array_new() + defer C.mlx_array_free(sinksArr) + + out := New("FAST_SDPA", query, key, value, mask) + C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx) + return out +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..18dacff --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module forge.lthn.ai/core/go-mlx + +go 1.25.5 diff --git a/grad.go b/grad.go new file mode 100644 index 0000000..12ca786 --- /dev/null +++ b/grad.go @@ -0,0 +1,353 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include "mlx/c/mlx.h" + +// Callback for gradient closures — same signature as goCompiledFunc. 
+extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload); + +static mlx_closure new_grad_closure(void *payload) { + return mlx_closure_new_func_payload(&goGradFunc, payload, NULL); +} +*/ +import "C" + +import ( + "fmt" + "sync" + "sync/atomic" + "unsafe" +) + +// --- Closure registry (separate from compile.go's registry) --- + +var ( + gradFuncs sync.Map + gradNextID atomic.Uintptr +) + +//export goGradFunc +func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int { + id := uintptr(payload) + fnI, ok := gradFuncs.Load(id) + if !ok { + return 1 + } + fn := fnI.(func([]*Array) []*Array) + + nInputs := int(C.mlx_vector_array_size(inputs)) + goInputs := make([]*Array, nInputs) + for i := 0; i < nInputs; i++ { + a := New("GRAD_INPUT") + C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i)) + goInputs[i] = a + } + + goOutputs := fn(goInputs) + + // The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_). + // Create a valid vector, fill it, then copy to the output pointer. + tmp := C.mlx_vector_array_new() + for _, out := range goOutputs { + C.mlx_vector_array_append_value(tmp, out.ctx) + } + C.mlx_vector_array_set(outputs, tmp) + C.mlx_vector_array_free(tmp) + return 0 +} + +// newClosure registers a Go function as an MLX closure for autograd. +func newClosure(fn func([]*Array) []*Array) C.mlx_closure { + id := gradNextID.Add(1) + gradFuncs.Store(id, fn) + return C.new_grad_closure(unsafe.Pointer(id)) +} + +// --- VJP (Vector-Jacobian Product) — Reverse Mode --- + +// VJP computes the vector-Jacobian product (reverse-mode autodiff). +// Given a function fn, input primals, and output cotangents (upstream gradients), +// returns (outputs, gradients) where gradients are w.r.t. the primals. +// +// This is the fundamental backward pass operation. 
+func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + // Pack primals into vector + primalsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(primalsVec) + for _, p := range primals { + C.mlx_vector_array_append_value(primalsVec, p.ctx) + } + + // Pack cotangents into vector + cotangentsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(cotangentsVec) + for _, c := range cotangents { + C.mlx_vector_array_append_value(cotangentsVec, c.ctx) + } + + // Call mlx_vjp + var outVec, vjpVec C.mlx_vector_array + outVec = C.mlx_vector_array_new() + vjpVec = C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + defer C.mlx_vector_array_free(vjpVec) + + rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec) + if rc != 0 { + checkError() + return nil, nil, fmt.Errorf("mlx: vjp failed") + } + + outputs = vectorToArrays(outVec) + vjps = vectorToArrays(vjpVec) + return outputs, vjps, nil +} + +// --- JVP (Jacobian-Vector Product) — Forward Mode --- + +// JVP computes the Jacobian-vector product (forward-mode autodiff). +// Given a function fn, input primals, and input tangents (perturbation directions), +// returns (outputs, output_tangents). +// +// Useful for directional derivatives and Hessian-vector products. 
+func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + primalsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(primalsVec) + for _, p := range primals { + C.mlx_vector_array_append_value(primalsVec, p.ctx) + } + + tangentsVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(tangentsVec) + for _, t := range tangents { + C.mlx_vector_array_append_value(tangentsVec, t.ctx) + } + + var outVec, jvpVec C.mlx_vector_array + outVec = C.mlx_vector_array_new() + jvpVec = C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + defer C.mlx_vector_array_free(jvpVec) + + rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec) + if rc != 0 { + checkError() + return nil, nil, fmt.Errorf("mlx: jvp failed") + } + + outputs = vectorToArrays(outVec) + jvps = vectorToArrays(jvpVec) + return outputs, jvps, nil +} + +// --- ValueAndGrad — Combined Forward + Backward --- + +// GradFn is a function that computes both the loss value and gradients +// with respect to specified arguments. Call it with inputs to get +// (values, gradients). +type GradFn struct { + cls C.mlx_closure_value_and_grad +} + +// ValueAndGrad creates a GradFn that computes both the function value and +// gradients with respect to the arguments at the given indices. +// +// The returned GradFn can be called repeatedly with different inputs. +// This is the primary API for training loops: +// +// lossFn := func(params []*Array) []*Array { return []*Array{loss} } +// grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg +// values, grads := grad.Apply(params...) +func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + // Default: differentiate w.r.t. 
first argument + if len(argnums) == 0 { + argnums = []int{0} + } + + cArgs := make([]C.int, len(argnums)) + for i, a := range argnums { + cArgs[i] = C.int(a) + } + + g := &GradFn{} + C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs))) + return g +} + +// Apply calls the gradient function with the given inputs. +// Returns (values, gradients) where values are the function outputs +// and gradients are w.r.t. the arguments specified in ValueAndGrad. +func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) { + inputVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(inputVec) + for _, in := range inputs { + C.mlx_vector_array_append_value(inputVec, in.ctx) + } + + var valVec, gradVec C.mlx_vector_array + valVec = C.mlx_vector_array_new() + gradVec = C.mlx_vector_array_new() + defer C.mlx_vector_array_free(valVec) + defer C.mlx_vector_array_free(gradVec) + + rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec) + if rc != 0 { + checkError() + return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed") + } + + values = vectorToArrays(valVec) + grads = vectorToArrays(gradVec) + return values, grads, nil +} + +// Free releases the underlying C closure. +func (g *GradFn) Free() { + if g.cls.ctx != nil { + C.mlx_closure_value_and_grad_free(g.cls) + g.cls.ctx = nil + } +} + +// --- Checkpoint — Memory-Efficient Gradient Recomputation --- + +// Checkpoint wraps a function so that during backward pass, intermediate +// activations are recomputed rather than stored. Trades compute for memory. +// +// Use this for memory-constrained training (large models on limited VRAM). 
+func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array { + Init() + + closure := newClosure(fn) + defer C.mlx_closure_free(closure) + + var checkpointed C.mlx_closure + C.mlx_checkpoint(&checkpointed, closure) + + // Register the checkpointed closure so it stays alive + id := gradNextID.Add(1) + cpClosure := checkpointed // capture for the returned function + + // Return a Go function that calls through the checkpointed closure + return func(inputs []*Array) []*Array { + _ = id // prevent GC of closure reference + + inputVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(inputVec) + for _, in := range inputs { + C.mlx_vector_array_append_value(inputVec, in.ctx) + } + + outVec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(outVec) + + C.mlx_closure_apply(&outVec, cpClosure, inputVec) + return vectorToArrays(outVec) + } +} + +// --- Loss Functions --- + +// CrossEntropyLoss computes cross-entropy loss between logits and integer targets. +// logits: [..., V] (raw model output, pre-softmax, last dim = vocab) +// targets: [...] (integer token IDs, same shape as logits minus last dim) +// Returns scalar loss averaged over all positions. +// +// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i] +func CrossEntropyLoss(logits, targets *Array) *Array { + Init() + // LogSumExp along last axis (vocab dimension) for numerical stability + lse := LogSumExp(logits, -1, false) + + // Gather the logit at the target index: logits[..., target] + // targets needs to be [..., 1] for take_along_axis + tgtExpanded := ExpandDims(targets, -1) + gathered := TakeAlongAxis(logits, tgtExpanded, -1) + gathered = Squeeze(gathered, -1) // back to [...] shape + + // Per-position loss: logsumexp - logit_at_target + perPos := Subtract(lse, gathered) + + // Mean over all positions + return MeanAll(perPos) +} + +// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions. 
+// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore) +// Returns scalar loss averaged over masked positions only. +func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array { + Init() + lse := LogSumExp(logits, -1, false) + tgtExpanded := ExpandDims(targets, -1) + gathered := TakeAlongAxis(logits, tgtExpanded, -1) + gathered = Squeeze(gathered, -1) + perPos := Subtract(lse, gathered) // [B, L] + + // Apply mask and average over masked positions + masked := Mul(perPos, mask) + return Divide(SumAll(masked), SumAll(mask)) +} + +// MSELoss computes the mean squared error loss: mean((predictions - targets)^2). +func MSELoss(predictions, targets *Array) *Array { + diff := Subtract(predictions, targets) + sq := Square(diff) + return MeanAll(sq) +} + +// --- Helpers --- + +// vectorToArrays extracts all arrays from an mlx_vector_array. +func vectorToArrays(vec C.mlx_vector_array) []*Array { + n := int(C.mlx_vector_array_size(vec)) + out := make([]*Array, n) + for i := 0; i < n; i++ { + a := New("VEC_OUT") + C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i)) + out[i] = a + } + return out +} + +// Log returns element-wise natural logarithm. +func Log(a *Array) *Array { + out := New("LOG", a) + C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// SumAll reduces by summation over all elements, returning a scalar. +func SumAll(a *Array) *Array { + flat := Reshape(a, -1) + return Sum(flat, 0, false) +} + +// MeanAll reduces by averaging over all elements, returning a scalar. +func MeanAll(a *Array) *Array { + flat := Reshape(a, -1) + return Mean(flat, 0, false) +} + +// OnesLike creates an array of ones with the same shape and type as the input. 
+func OnesLike(a *Array) *Array { + out := New("ONES_LIKE", a) + C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} diff --git a/grad_test.go b/grad_test.go new file mode 100644 index 0000000..0a00751 --- /dev/null +++ b/grad_test.go @@ -0,0 +1,332 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "testing" +) + +func TestVJP_SimpleSquare(t *testing.T) { + // f(x) = x^2, df/dx = 2x + // At x=3: f(3)=9, df/dx=6 + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} + } + + x := FromValue(float32(3.0)) + cotangent := FromValue(float32(1.0)) // upstream grad = 1 + + outputs, grads, err := VJP(fn, []*Array{x}, []*Array{cotangent}) + if err != nil { + t.Fatalf("VJP failed: %v", err) + } + + Materialize(outputs[0], grads[0]) + + got := outputs[0].Float() + if math.Abs(got-9.0) > 1e-5 { + t.Errorf("output = %f, want 9.0", got) + } + + grad := grads[0].Float() + if math.Abs(grad-6.0) > 1e-5 { + t.Errorf("grad = %f, want 6.0", grad) + } +} + +func TestVJP_Addition(t *testing.T) { + // f(x, y) = x + y, df/dx = 1, df/dy = 1 + fn := func(inputs []*Array) []*Array { + return []*Array{Add(inputs[0], inputs[1])} + } + + x := FromValue(float32(2.0)) + y := FromValue(float32(5.0)) + cotangent := FromValue(float32(1.0)) + + _, grads, err := VJP(fn, []*Array{x, y}, []*Array{cotangent}) + if err != nil { + t.Fatalf("VJP failed: %v", err) + } + + Materialize(grads...) + + if math.Abs(grads[0].Float()-1.0) > 1e-5 { + t.Errorf("dx = %f, want 1.0", grads[0].Float()) + } + if math.Abs(grads[1].Float()-1.0) > 1e-5 { + t.Errorf("dy = %f, want 1.0", grads[1].Float()) + } +} + +func TestVJP_MatmulGrad(t *testing.T) { + // f(W) = sum(W @ x) — gradient of sum(matmul) w.r.t. 
W + // For W=[2,2], x=[2,1]: dL/dW = ones @ x^T + x := FromValues([]float32{1.0, 2.0}, 2, 1) + w := FromValues([]float32{1.0, 0.0, 0.0, 1.0}, 2, 2) // identity + + fn := func(inputs []*Array) []*Array { + result := Matmul(inputs[0], x) + return []*Array{SumAll(result)} + } + + cotangent := FromValue(float32(1.0)) + + outputs, grads, err := VJP(fn, []*Array{w}, []*Array{cotangent}) + if err != nil { + t.Fatalf("VJP failed: %v", err) + } + + Materialize(outputs[0], grads[0]) + + // W @ x with W=I, x=[1,2]^T gives [1,2]^T, sum=3 + got := outputs[0].Float() + if math.Abs(got-3.0) > 1e-5 { + t.Errorf("output = %f, want 3.0", got) + } + + // Gradient of sum(W@x) w.r.t. W is outer product: ones @ x^T + // = [[1,2],[1,2]] + gradFloats := grads[0].Floats() + expected := []float32{1.0, 2.0, 1.0, 2.0} + for i, exp := range expected { + if math.Abs(float64(gradFloats[i]-exp)) > 1e-5 { + t.Errorf("grad[%d] = %f, want %f", i, gradFloats[i], exp) + } + } +} + +func TestJVP_SimpleSquare(t *testing.T) { + // f(x) = x^2, JVP with tangent v: df = 2x * v + // At x=3, v=1: df = 6 + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} + } + + x := FromValue(float32(3.0)) + tangent := FromValue(float32(1.0)) + + outputs, jvps, err := JVP(fn, []*Array{x}, []*Array{tangent}) + if err != nil { + t.Fatalf("JVP failed: %v", err) + } + + Materialize(outputs[0], jvps[0]) + + got := outputs[0].Float() + if math.Abs(got-9.0) > 1e-5 { + t.Errorf("output = %f, want 9.0", got) + } + + jvp := jvps[0].Float() + if math.Abs(jvp-6.0) > 1e-5 { + t.Errorf("jvp = %f, want 6.0", jvp) + } +} + +func TestValueAndGrad_Quadratic(t *testing.T) { + // f(x) = x^2 + 2x + 1 = (x+1)^2 + // f'(x) = 2x + 2 + // At x=3: f(3) = 16, f'(3) = 8 + fn := func(inputs []*Array) []*Array { + x := inputs[0] + x2 := Mul(x, x) + two_x := MulScalar(x, 2.0) + one := FromValue(float32(1.0)) + return []*Array{Add(Add(x2, two_x), one)} + } + + grad := ValueAndGrad(fn, 0) + defer grad.Free() + + x := 
FromValue(float32(3.0)) + values, grads, err := grad.Apply(x) + if err != nil { + t.Fatalf("ValueAndGrad failed: %v", err) + } + + Materialize(values[0], grads[0]) + + val := values[0].Float() + if math.Abs(val-16.0) > 1e-5 { + t.Errorf("value = %f, want 16.0", val) + } + + g := grads[0].Float() + if math.Abs(g-8.0) > 1e-5 { + t.Errorf("grad = %f, want 8.0", g) + } +} + +func TestValueAndGrad_MultiArg(t *testing.T) { + // f(x, y) = x*y, df/dx = y, df/dy = x + // At x=3, y=4: f=12, dx=4, dy=3 + fn := func(inputs []*Array) []*Array { + return []*Array{Mul(inputs[0], inputs[1])} + } + + // Differentiate w.r.t. both arguments + grad := ValueAndGrad(fn, 0, 1) + defer grad.Free() + + x := FromValue(float32(3.0)) + y := FromValue(float32(4.0)) + values, grads, err := grad.Apply(x, y) + if err != nil { + t.Fatalf("ValueAndGrad failed: %v", err) + } + + Materialize(values[0], grads[0], grads[1]) + + val := values[0].Float() + if math.Abs(val-12.0) > 1e-5 { + t.Errorf("value = %f, want 12.0", val) + } + + dx := grads[0].Float() + if math.Abs(dx-4.0) > 1e-5 { + t.Errorf("dx = %f, want 4.0 (y)", dx) + } + + dy := grads[1].Float() + if math.Abs(dy-3.0) > 1e-5 { + t.Errorf("dy = %f, want 3.0 (x)", dy) + } +} + +func TestValueAndGrad_Reusable(t *testing.T) { + // Verify GradFn can be called multiple times + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} // x^2, grad = 2x + } + + grad := ValueAndGrad(fn) + defer grad.Free() + + for _, tc := range []struct { + x float32 + want float64 // expected gradient + }{ + {2.0, 4.0}, + {5.0, 10.0}, + {-3.0, -6.0}, + {0.0, 0.0}, + } { + x := FromValue(tc.x) + _, grads, err := grad.Apply(x) + if err != nil { + t.Fatalf("Apply failed for x=%f: %v", tc.x, err) + } + Materialize(grads[0]) + + g := grads[0].Float() + if math.Abs(g-tc.want) > 1e-5 { + t.Errorf("x=%f: grad = %f, want %f", tc.x, g, tc.want) + } + } +} + +func TestCrossEntropyLoss(t *testing.T) { + // Simple 3-class classification + // logits = 
[1.0, 2.0, 3.0], target = 2 (class index) + // Manual: logsumexp([1,2,3]) = 3 + log(exp(-2)+exp(-1)+1) + // = 3 + log(0.1353 + 0.3679 + 1.0) = 3 + log(1.5032) = 3.4076 + // loss = 3.4076 - 3.0 = 0.4076 + logits := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) // [1, 3] + targets := FromValues([]int32{2}, 1) // [1] + + loss := CrossEntropyLoss(logits, targets) + Materialize(loss) + + got := loss.Float() + expected := 0.4076 + if math.Abs(got-expected) > 0.01 { + t.Errorf("CrossEntropyLoss = %f, want ~%f", got, expected) + } +} + +func TestMSELoss(t *testing.T) { + pred := FromValues([]float32{1.0, 2.0, 3.0}, 3) + target := FromValues([]float32{1.5, 2.5, 3.5}, 3) + + loss := MSELoss(pred, target) + Materialize(loss) + + // MSE = mean((0.5)^2, (0.5)^2, (0.5)^2) = mean(0.25, 0.25, 0.25) = 0.25 + got := loss.Float() + if math.Abs(got-0.25) > 1e-5 { + t.Errorf("MSELoss = %f, want 0.25", got) + } +} + +func TestLogSumExp(t *testing.T) { + // logsumexp([1, 2, 3]) along axis -1 + a := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) + result := LogSumExp(a, -1, false) + Materialize(result) + + // = 3 + log(exp(-2) + exp(-1) + 1) = 3 + log(1.5032) ≈ 3.4076 + got := result.Float() + expected := 3.4076 + if math.Abs(got-expected) > 0.01 { + t.Errorf("LogSumExp = %f, want ~%f", got, expected) + } +} + +func TestOnesLike(t *testing.T) { + a := FromValues([]float32{1.0, 2.0, 3.0}, 3) + ones := OnesLike(a) + Materialize(ones) + + floats := ones.Floats() + for i, f := range floats { + if f != 1.0 { + t.Errorf("OnesLike[%d] = %f, want 1.0", i, f) + } + } +} + +func TestCheckpoint(t *testing.T) { + // Checkpoint should produce the same result as the original function + fn := func(inputs []*Array) []*Array { + x := inputs[0] + return []*Array{Mul(x, x)} + } + + cpFn := Checkpoint(fn) + + x := FromValue(float32(5.0)) + result := cpFn([]*Array{x}) + Materialize(result[0]) + + got := result[0].Float() + if math.Abs(got-25.0) > 1e-5 { + t.Errorf("Checkpoint result = %f, want 25.0", got) + } +} + 
+func TestSumAll(t *testing.T) { + a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2) + result := SumAll(a) + Materialize(result) + + got := result.Float() + if math.Abs(got-10.0) > 1e-5 { + t.Errorf("SumAll = %f, want 10.0", got) + } +} + +func TestMeanAll(t *testing.T) { + a := FromValues([]float32{2.0, 4.0, 6.0, 8.0}, 2, 2) + result := MeanAll(a) + Materialize(result) + + got := result.Float() + if math.Abs(got-5.0) > 1e-5 { + t.Errorf("MeanAll = %f, want 5.0", got) + } +} diff --git a/io.go b/io.go new file mode 100644 index 0000000..7e35773 --- /dev/null +++ b/io.go @@ -0,0 +1,63 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include +#include "mlx/c/mlx.h" +*/ +import "C" + +import ( + "iter" + "runtime" + "unsafe" +) + +// LoadSafetensors loads tensors from a .safetensors file, returning an iterator +// over (name, array) pairs. Tensors are loaded lazily on the CPU stream. +func LoadSafetensors(path string) iter.Seq2[string, *Array] { + Init() + return func(yield func(string, *Array) bool) { + string2array := C.mlx_map_string_to_array_new() + defer C.mlx_map_string_to_array_free(string2array) + + string2string := C.mlx_map_string_to_string_new() + defer C.mlx_map_string_to_string_free(string2string) + + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + cpu := C.mlx_default_cpu_stream_new() + defer C.mlx_stream_free(cpu) + + C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu) + + it := C.mlx_map_string_to_array_iterator_new(string2array) + defer C.mlx_map_string_to_array_iterator_free(it) + + for { + var key *C.char + value := C.mlx_array_new() + if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 { + break + } + + name := C.GoString(key) + arr := &Array{ctx: value, name: name} + runtime.SetFinalizer(arr, finalizeArray) + if !yield(name, arr) { + break + } + } + } +} + +// LoadAllSafetensors loads all tensors from a .safetensors file into a map. 
+func LoadAllSafetensors(path string) map[string]*Array { + tensors := make(map[string]*Array) + for name, arr := range LoadSafetensors(path) { + tensors[name] = arr + } + return tensors +} diff --git a/lora.go b/lora.go new file mode 100644 index 0000000..ac7107e --- /dev/null +++ b/lora.go @@ -0,0 +1,237 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include +#include "mlx/c/mlx.h" +*/ +import "C" + +import ( + "fmt" + "math" + "sort" + "unsafe" +) + +// LoRALinear wraps a frozen Linear layer with low-rank trainable adapters. +// +// Forward: base(x) + scale * (x @ A^T) @ B^T +// +// A is [rank, in_features] — initialised with Kaiming normal +// B is [out_features, rank] — initialised to zero +// Scale = alpha / rank +// +// Only A and B are trainable. The base weights stay frozen. +type LoRALinear struct { + Base *Linear // Frozen base weights (may be quantized) + A *Array // [rank, in_features] — trainable + B *Array // [out_features, rank] — trainable + Scale float32 // alpha / rank + Rank int + Alpha float32 +} + +// NewLoRALinear wraps an existing Linear layer with LoRA adapters. +// rank: decomposition rank (typically 4, 8, or 16) +// alpha: scaling factor (typically 2*rank) +func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear { + // Determine dimensions from the base weight. + // Weight shape is [out_features, in_features] for standard linear, + // or quantized shape which we handle via the base layer. + var inFeatures, outFeatures int32 + + if base.Scales != nil { + // Quantized: weight is packed. Compute dims from scales. 
+ // scales shape is [out_features, in_features / group_size] + scaleShape := base.Scales.Shape() + outFeatures = scaleShape[0] + inFeatures = scaleShape[1] * int32(base.GroupSize) + } else { + wShape := base.Weight.Shape() + outFeatures = wShape[0] + inFeatures = wShape[1] + } + + // A: Kaiming normal initialisation — N(0, 1/sqrt(in_features)) + stddev := float32(1.0 / math.Sqrt(float64(inFeatures))) + a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32) + + // B: zero initialisation — LoRA starts as identity (no change to base) + b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32) + + Materialize(a, b) + + return &LoRALinear{ + Base: base, + A: a, + B: b, + Scale: alpha / float32(rank), + Rank: rank, + Alpha: alpha, + } +} + +// Forward computes: base(x) + scale * (x @ A^T) @ B^T +// Calls baseForward on the underlying Linear to avoid infinite recursion +// when the Linear's Forward method dispatches through LoRA. +func (l *LoRALinear) Forward(x *Array) *Array { + baseOut := l.Base.baseForward(x) + + // LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out] + loraOut := Matmul(x, Transpose(l.A)) + loraOut = Matmul(loraOut, Transpose(l.B)) + loraOut = MulScalar(loraOut, l.Scale) + + return Add(baseOut, loraOut) +} + +// TrainableParams returns the LoRA A and B arrays for gradient computation. +func (l *LoRALinear) TrainableParams() []*Array { + return []*Array{l.A, l.B} +} + +// SetParams updates the LoRA A and B arrays (used by optimiser after gradient step). +func (l *LoRALinear) SetParams(a, b *Array) { + l.A = a + l.B = b +} + +// ParamCount returns the number of trainable parameters. +func (l *LoRALinear) ParamCount() int { + aShape := l.A.Shape() + bShape := l.B.Shape() + return int(aShape[0]*aShape[1] + bShape[0]*bShape[1]) +} + +// --- LoRA Application to Models --- + +// LoRAConfig specifies which layers to apply LoRA to and with what parameters. 
+type LoRAConfig struct { + Rank int // Decomposition rank (default 8) + Alpha float32 // Scaling factor (default 16) + TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj) +} + +// DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning. +func DefaultLoRAConfig() LoRAConfig { + return LoRAConfig{ + Rank: 8, + Alpha: 16, + TargetKeys: []string{"q_proj", "v_proj"}, + } +} + +// LoRAAdapter holds all LoRA layers applied to a model. +type LoRAAdapter struct { + Layers map[string]*LoRALinear // keyed by weight path prefix + Config LoRAConfig +} + +// TotalParams returns the total number of trainable parameters across all LoRA layers. +func (a *LoRAAdapter) TotalParams() int { + total := 0 + for _, l := range a.Layers { + total += l.ParamCount() + } + return total +} + +// SortedNames returns layer names in deterministic sorted order. +func (a *LoRAAdapter) SortedNames() []string { + names := make([]string, 0, len(a.Layers)) + for name := range a.Layers { + names = append(names, name) + } + sort.Strings(names) + return names +} + +// AllTrainableParams returns all trainable arrays (A and B from every layer), +// in a deterministic order sorted by layer name. +func (a *LoRAAdapter) AllTrainableParams() []*Array { + names := a.SortedNames() + params := make([]*Array, 0, len(names)*2) + for _, name := range names { + l := a.Layers[name] + params = append(params, l.A, l.B) + } + return params +} + +// SetAllParams updates all LoRA A and B arrays from a flat slice, +// in the same deterministic order as AllTrainableParams. +func (a *LoRAAdapter) SetAllParams(params []*Array) { + names := a.SortedNames() + for i, name := range names { + l := a.Layers[name] + l.A = params[i*2] + l.B = params[i*2+1] + } +} + +// Save writes the LoRA adapter weights to a safetensors file. +// Only saves the A and B matrices — not the frozen base weights. 
+func (a *LoRAAdapter) Save(path string) error { + weights := make(map[string]*Array) + for name, l := range a.Layers { + weights[name+".lora_a"] = l.A + weights[name+".lora_b"] = l.B + } + return SaveSafetensors(path, weights) +} + +// --- Random Normal --- + +// RandomNormal generates normal (Gaussian) random values with given mean and stddev. +func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array { + Init() + out := New("RANDOM_NORMAL") + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + key := C.mlx_array_new() + defer C.mlx_array_free(key) + C.mlx_random_normal( + &out.ctx, + &cShape[0], C.size_t(len(cShape)), + C.mlx_dtype(dtype), + C.float(mean), C.float(stddev), + key, // null key = use default RNG + DefaultStream().ctx, + ) + return out +} + +// --- SaveSafetensors --- + +// SaveSafetensors saves a map of named arrays to a .safetensors file. +func SaveSafetensors(path string, weights map[string]*Array) error { + Init() + + // Build the map + cMap := C.mlx_map_string_to_array_new() + defer C.mlx_map_string_to_array_free(cMap) + + for name, arr := range weights { + cName := C.CString(name) + C.mlx_map_string_to_array_insert(cMap, cName, arr.ctx) + C.free(unsafe.Pointer(cName)) + } + + // Empty metadata + cMeta := C.mlx_map_string_to_string_new() + defer C.mlx_map_string_to_string_free(cMeta) + + cPath := C.CString(path) + defer C.free(unsafe.Pointer(cPath)) + + rc := C.mlx_save_safetensors(cPath, cMap, cMeta) + if rc != 0 { + checkError() + return fmt.Errorf("mlx: save safetensors failed: %s", path) + } + return nil +} diff --git a/lora_test.go b/lora_test.go new file mode 100644 index 0000000..729dffa --- /dev/null +++ b/lora_test.go @@ -0,0 +1,316 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "os" + "testing" +) + +func TestNewLoRALinear(t *testing.T) { + // Create a simple base linear layer: [4, 8] weight + w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32) + 
Materialize(w) + base := NewLinear(w, nil) + + lora := NewLoRALinear(base, 4, 8.0) // rank=4, alpha=8 + + // Check dimensions + aShape := lora.A.Shape() + bShape := lora.B.Shape() + + if aShape[0] != 4 || aShape[1] != 8 { + t.Errorf("A shape = %v, want [4, 8]", aShape) + } + if bShape[0] != 4 || bShape[1] != 4 { + t.Errorf("B shape = %v, want [4, 4]", bShape) + } + + // Scale should be alpha/rank = 8/4 = 2 + if math.Abs(float64(lora.Scale)-2.0) > 1e-5 { + t.Errorf("Scale = %f, want 2.0", lora.Scale) + } + + // B should be all zeros (LoRA starts as identity) + Materialize(lora.B) + bFloats := lora.B.Floats() + for i, v := range bFloats { + if v != 0 { + t.Errorf("B[%d] = %f, want 0", i, v) + } + } +} + +func TestLoRALinear_ForwardMatchesBase(t *testing.T) { + // With B=0, LoRA forward should equal base forward + w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32) + Materialize(w) + base := NewLinear(w, nil) + + lora := NewLoRALinear(base, 4, 8.0) + + // Random input [1, 3, 8] + x := RandomNormal(0, 1, []int32{1, 3, 8}, DTypeFloat32) + Materialize(x) + + baseOut := base.Forward(x) + loraOut := lora.Forward(x) + Materialize(baseOut, loraOut) + + // Should be identical since B is zero + baseFloats := baseOut.Floats() + loraFloats := loraOut.Floats() + + if len(baseFloats) != len(loraFloats) { + t.Fatalf("output sizes differ: base=%d, lora=%d", len(baseFloats), len(loraFloats)) + } + + for i := range baseFloats { + diff := math.Abs(float64(baseFloats[i] - loraFloats[i])) + if diff > 1e-4 { + t.Errorf("output[%d] differs: base=%f, lora=%f", i, baseFloats[i], loraFloats[i]) + } + } +} + +func TestLoRALinear_ForwardWithAdapter(t *testing.T) { + // Set A and B to known values and verify output changes + w := Zeros([]int32{4, 8}, DTypeFloat32) + Materialize(w) + base := NewLinear(w, nil) + + lora := NewLoRALinear(base, 2, 4.0) // rank=2, alpha=4, scale=2 + + // Set A to identity-like: [[1,0,0,...], [0,1,0,...]] + a := Zeros([]int32{2, 8}, DTypeFloat32) + // Set B to ones: 
[[1,1], [1,1], [1,1], [1,1]] + b := FromValues([]float32{ + 1, 1, + 1, 1, + 1, 1, + 1, 1, + }, 4, 2) + Materialize(a, b) + lora.A = a + lora.B = b + + // With base=0, A=0, output should also be 0 (scale * x@0@B^T = 0) + x := FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 8) + result := lora.Forward(x) + Materialize(result) + + // base(x) = 0 (zero weights), lora = scale * (x @ A^T) @ B^T + // A is zeros, so x @ A^T = [0, 0], then @ B^T = [0,0,0,0] + for _, v := range result.Floats() { + if v != 0 { + t.Errorf("expected 0 with zero A, got %f", v) + } + } +} + +func TestLoRALinear_ParamCount(t *testing.T) { + w := RandomNormal(0, 0.01, []int32{64, 128}, DTypeFloat32) + Materialize(w) + base := NewLinear(w, nil) + + lora := NewLoRALinear(base, 8, 16.0) // rank=8 + // A: [8, 128] = 1024, B: [64, 8] = 512, total = 1536 + expected := 8*128 + 64*8 + if lora.ParamCount() != expected { + t.Errorf("ParamCount = %d, want %d", lora.ParamCount(), expected) + } +} + +func TestLoRALinear_TrainableParams(t *testing.T) { + w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32) + Materialize(w) + base := NewLinear(w, nil) + + lora := NewLoRALinear(base, 4, 8.0) + params := lora.TrainableParams() + + if len(params) != 2 { + t.Fatalf("TrainableParams returned %d arrays, want 2", len(params)) + } + + // First is A, second is B + if params[0].Shape()[0] != 4 || params[0].Shape()[1] != 8 { + t.Errorf("param[0] (A) shape = %v, want [4, 8]", params[0].Shape()) + } + if params[1].Shape()[0] != 4 || params[1].Shape()[1] != 4 { + t.Errorf("param[1] (B) shape = %v, want [4, 4]", params[1].Shape()) + } +} + +func TestLoRALinear_GradientFlows(t *testing.T) { + // Verify that gradients flow through the LoRA path + w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32) + Materialize(w) + base := NewLinear(w, nil) + + lora := NewLoRALinear(base, 4, 8.0) + x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32) + Materialize(x) + + // Loss function: sum of LoRA output (differentiating w.r.t. 
A and B) + lossFn := func(inputs []*Array) []*Array { + lora.A = inputs[0] + lora.B = inputs[1] + out := lora.Forward(x) + return []*Array{SumAll(out)} + } + + grad := ValueAndGrad(lossFn, 0, 1) // grad w.r.t. A and B + defer grad.Free() + + values, grads, err := grad.Apply(lora.A, lora.B) + if err != nil { + t.Fatalf("ValueAndGrad failed: %v", err) + } + + Materialize(append(values, grads...)...) + + // Loss should be a scalar + loss := values[0].Float() + t.Logf("loss = %f", loss) + + // Gradients should be non-zero (A has random init, B is zero but gets grad) + gradA := grads[0] + gradB := grads[1] + + aGradFloats := gradA.Floats() + bGradFloats := gradB.Floats() + + hasNonZeroA := false + for _, v := range aGradFloats { + if v != 0 { + hasNonZeroA = true + break + } + } + + hasNonZeroB := false + for _, v := range bGradFloats { + if v != 0 { + hasNonZeroB = true + break + } + } + + // A gradient might be zero if B is zero (since dL/dA depends on B) + // But B gradient should be non-zero since A is random + if !hasNonZeroB { + t.Error("gradient for B is all zeros — gradients not flowing") + } + t.Logf("gradA has non-zero: %v, gradB has non-zero: %v", hasNonZeroA, hasNonZeroB) +} + +func TestRandomNormal(t *testing.T) { + arr := RandomNormal(0, 1, []int32{100}, DTypeFloat32) + Materialize(arr) + + floats := arr.Floats() + if len(floats) != 100 { + t.Fatalf("RandomNormal returned %d elements, want 100", len(floats)) + } + + // Check rough statistics: mean should be near 0, values should have spread + var sum float64 + for _, f := range floats { + sum += float64(f) + } + mean := sum / 100 + if math.Abs(mean) > 0.5 { // generous tolerance for 100 samples + t.Errorf("mean = %f, expected near 0", mean) + } +} + +func TestSaveSafetensors(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4}, 2, 2) + b := FromValues([]float32{5, 6, 7, 8, 9, 10}, 3, 2) + Materialize(a, b) + + path := t.TempDir() + "/test.safetensors" + err := SaveSafetensors(path, map[string]*Array{ + 
"layer.lora_a": a, + "layer.lora_b": b, + }) + if err != nil { + t.Fatalf("SaveSafetensors failed: %v", err) + } + + // Verify file exists + info, err := os.Stat(path) + if err != nil { + t.Fatalf("saved file not found: %v", err) + } + if info.Size() == 0 { + t.Error("saved file is empty") + } + + // Load it back + loaded := LoadAllSafetensors(path) + Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"]) + + aLoaded := loaded["layer.lora_a"].Floats() + bLoaded := loaded["layer.lora_b"].Floats() + + expectedA := []float32{1, 2, 3, 4} + expectedB := []float32{5, 6, 7, 8, 9, 10} + + for i, v := range expectedA { + if aLoaded[i] != v { + t.Errorf("loaded A[%d] = %f, want %f", i, aLoaded[i], v) + } + } + for i, v := range expectedB { + if bLoaded[i] != v { + t.Errorf("loaded B[%d] = %f, want %f", i, bLoaded[i], v) + } + } +} + +func TestLoRAAdapter_Save(t *testing.T) { + w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32) + Materialize(w) + base := NewLinear(w, nil) + + adapter := &LoRAAdapter{ + Layers: map[string]*LoRALinear{ + "model.layers.0.self_attn.q_proj": NewLoRALinear(base, 4, 8.0), + }, + Config: DefaultLoRAConfig(), + } + + path := t.TempDir() + "/adapter.safetensors" + err := adapter.Save(path) + if err != nil { + t.Fatalf("Adapter.Save failed: %v", err) + } + + // Load and verify + loaded := LoadAllSafetensors(path) + aKey := "model.layers.0.self_attn.q_proj.lora_a" + bKey := "model.layers.0.self_attn.q_proj.lora_b" + + if _, ok := loaded[aKey]; !ok { + t.Errorf("missing key %s in saved adapter", aKey) + } + if _, ok := loaded[bKey]; !ok { + t.Errorf("missing key %s in saved adapter", bKey) + } +} + +func TestDefaultLoRAConfig(t *testing.T) { + cfg := DefaultLoRAConfig() + if cfg.Rank != 8 { + t.Errorf("Rank = %d, want 8", cfg.Rank) + } + if cfg.Alpha != 16 { + t.Errorf("Alpha = %f, want 16", cfg.Alpha) + } + if len(cfg.TargetKeys) != 2 { + t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", cfg.TargetKeys) + } +} diff --git a/mlx.go b/mlx.go new 
file mode 100644
index 0000000..470e1f7
--- /dev/null
+++ b/mlx.go
@@ -0,0 +1,115 @@
+//go:build darwin && arm64
+
+// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
+//
+// Build mlx-c before use:
+//
+//	cd pkg/mlx && go generate ./...
+//
+// Build (MLX is auto-enabled on darwin/arm64):
+//
+//	go build -o core .
+package mlx
+
+//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
+//go:generate cmake --build build --parallel
+//go:generate cmake --install build
+
+/*
+#cgo CXXFLAGS: -std=c++17
+#cgo CFLAGS: -mmacosx-version-min=26.0
+#cgo CPPFLAGS: -I${SRCDIR}/dist/include
+#cgo LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx
+#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
+#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "mlx/c/mlx.h"
+
+// Most recent MLX error message, copied into a fixed buffer. The handler
+// must copy: mlx-c owns the msg pointer and gives no guarantee it stays
+// valid after the callback returns, so retaining msg directly (as the
+// previous version did) risks get_last_error() reading a dangling pointer.
+static char last_mlx_error_buf[1024];
+static const char *last_mlx_error = NULL;
+
+static void mlx_go_error_handler(const char *msg, void *data) {
+	fprintf(stderr, "MLX ERROR: %s\n", msg);
+	strncpy(last_mlx_error_buf, msg, sizeof(last_mlx_error_buf) - 1);
+	last_mlx_error_buf[sizeof(last_mlx_error_buf) - 1] = '\0';
+	last_mlx_error = last_mlx_error_buf;
+}
+
+static void set_error_handler() {
+	mlx_set_error_handler(&mlx_go_error_handler, NULL, NULL);
+}
+
+static const char* get_last_error() {
+	return last_mlx_error;
+}
+*/
+import "C"
+
+import (
+	"log/slog"
+	"sync"
+)
+
+var initOnce sync.Once
+
+// Init sets up the MLX error handler. Called automatically on first use;
+// safe for concurrent callers via sync.Once.
+func Init() {
+	initOnce.Do(func() {
+		C.set_error_handler()
+		slog.Debug("mlx: initialized with Metal backend")
+	})
+}
+
+// checkError logs the last MLX error if any occurred.
+func checkError() {
+	if msg := C.get_last_error(); msg != nil {
+		slog.Error("mlx", "error", C.GoString(msg))
+	}
+}
+
+// Materialize synchronously evaluates arrays, computing their values on the GPU.
+// This is the MLX equivalent of forcing lazy computation to complete.
+func Materialize(outputs ...*Array) {
+	doMaterialize(outputs, false)
+}
+
+// MaterializeAsync queues arrays for asynchronous GPU evaluation.
+func MaterializeAsync(outputs ...*Array) {
+	doMaterialize(outputs, true)
+}
+
+// doMaterialize is the shared implementation behind Materialize and
+// MaterializeAsync: it gathers the valid arrays into an mlx vector and
+// triggers either a blocking or an async evaluation.
+// nil and invalid arrays are silently skipped so callers may pass
+// optional outputs without pre-filtering.
+// NOTE(review): eval errors are only surfaced through the C error
+// handler; checkError is not called here — confirm that is intentional.
+func doMaterialize(outputs []*Array, async bool) {
+	Init()
+	vector := C.mlx_vector_array_new()
+	defer C.mlx_vector_array_free(vector)
+
+	for _, output := range outputs {
+		if output != nil && output.Valid() {
+			C.mlx_vector_array_append_value(vector, output.ctx)
+		}
+	}
+
+	if async {
+		C.mlx_async_eval(vector)
+	} else {
+		C.mlx_eval(vector)
+	}
+}
+
+// Collect gathers all valid arrays from a variadic list for batch Materialize.
+// It returns a new slice containing only the non-nil, valid entries.
+func Collect(arrays ...*Array) []*Array {
+	var out []*Array
+	for _, a := range arrays {
+		if a != nil && a.Valid() {
+			out = append(out, a)
+		}
+	}
+	return out
+}
+
+// MetalAvailable reports whether Metal GPU is available.
+func MetalAvailable() bool {
+	Init()
+	var available C.bool
+	C.mlx_metal_is_available(&available)
+	return bool(available)
+}
diff --git a/mlx_stub.go b/mlx_stub.go
new file mode 100644
index 0000000..281c8cb
--- /dev/null
+++ b/mlx_stub.go
@@ -0,0 +1,10 @@
+//go:build !(darwin && arm64)
+
+// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
+// This stub file is used on non-darwin/non-arm64 platforms or when the
+// mlx build tag is not set. All operations report MLX as unavailable.
+package mlx
+
+// MetalAvailable reports whether Metal GPU is available.
+// Always returns false on non-Apple Silicon platforms.
+func MetalAvailable() bool { return false }
diff --git a/model/gemma3.go b/model/gemma3.go
new file mode 100644
index 0000000..2d0d770
--- /dev/null
+++ b/model/gemma3.go
@@ -0,0 +1,448 @@
+//go:build darwin && arm64
+
+// Package model provides transformer model architectures for MLX inference.
+package model
+
+import (
+	"encoding/json"
+	"fmt"
+	"log/slog"
+	"math"
+	"os"
+	"path/filepath"
+
+	"forge.lthn.ai/core/go-mlx"
+	"forge.lthn.ai/core/go-mlx/cache"
+	"forge.lthn.ai/core/go-mlx/tokenizer"
+)
+
+// TextConfig holds Gemma 3 text model configuration.
+type TextConfig struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + RopeLocalBaseFreq float32 `json:"rope_local_base_freq"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + SlidingWindow int32 `json:"sliding_window"` + SlidingWindowPattern int32 `json:"sliding_window_pattern"` + + Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level + Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim) +} + +// GemmaModel is the Gemma 3 text model. +type GemmaModel struct { + EmbedTokens *mlx.Embedding + Layers []*DecoderLayer + Norm *mlx.RMSNormModule + Output *mlx.Linear // Tied to EmbedTokens + + // Precomputed (1 + weight) for Gemma-style RMSNorm + NormScaled *mlx.Array + + Tok *tokenizer.Tokenizer + Cfg *TextConfig +} + +// DecoderLayer is a single transformer block. +type DecoderLayer struct { + InputNorm *mlx.RMSNormModule + Attention *Attention + PostAttnNorm *mlx.RMSNormModule + PreFFNorm *mlx.RMSNormModule + MLP *MLP + PostFFNorm *mlx.RMSNormModule + + // Precomputed scaled weights + InputNormScaled *mlx.Array + PostAttnNormScaled *mlx.Array + PreFFNormScaled *mlx.Array + PostFFNormScaled *mlx.Array + + IsSliding bool + LayerIdx int32 +} + +// Attention implements Gemma 3 attention with Q/K normalization. +type Attention struct { + QProj *mlx.Linear + KProj *mlx.Linear + VProj *mlx.Linear + OProj *mlx.Linear + QNorm *mlx.RMSNormModule + KNorm *mlx.RMSNormModule + + QNormScaled *mlx.Array + KNormScaled *mlx.Array +} + +// MLP is the feed-forward network. 
+type MLP struct {
+	GateProj *mlx.Linear
+	UpProj   *mlx.Linear
+	DownProj *mlx.Linear
+}
+
+// compiledGELU is a singleton for the compiled GELU function.
+var compiledGELU *mlx.CompiledFunc
+
+// getCompiledGELU lazily builds (once) and returns the compiled GELU.
+// NOTE(review): the check-then-set on compiledGELU is not goroutine-safe;
+// two concurrent first calls could both compile (and one CompiledFunc
+// would leak). Safe only if inference runs on a single goroutine —
+// confirm, or guard with sync.Once.
+func getCompiledGELU() *mlx.CompiledFunc {
+	if compiledGELU == nil {
+		compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
+			return []*mlx.Array{geluApprox(inputs[0])}
+		}, true)
+	}
+	return compiledGELU
+}
+
+// geluApprox computes GELU using the tanh approximation:
+// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
+func geluApprox(x *mlx.Array) *mlx.Array {
+	const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
+	const coeff = 0.044715                 // cubic-term coefficient of the tanh approximation
+
+	x3 := mlx.Mul(mlx.Mul(x, x), x)
+	inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
+	scaled := mlx.MulScalar(inner, sqrt2OverPi)
+	t := mlx.Tanh(scaled)
+	onePlusT := mlx.AddScalar(t, 1.0)
+	return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
+}
+
+// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
+// It first tries the multimodal wrapper layout (config under "text_config"),
+// falling back to a flat top-level config when the wrapper is empty, then
+// fills in fallback values for any fields absent from config.json.
+func parseConfig(data []byte) (*TextConfig, error) {
+	// Try parsing text_config from multimodal wrapper
+	var wrapper struct {
+		TextConfig   TextConfig           `json:"text_config"`
+		ModelType    string               `json:"model_type"`
+		Quantization *QuantizationConfig  `json:"quantization"`
+	}
+	if err := json.Unmarshal(data, &wrapper); err != nil {
+		return nil, err
+	}
+
+	cfg := wrapper.TextConfig
+
+	// If text_config was empty, try top-level.
+	// NumHiddenLayers == 0 is the sentinel: a usable config always sets it.
+	if cfg.NumHiddenLayers == 0 {
+		if err := json.Unmarshal(data, &cfg); err != nil {
+			return nil, err
+		}
+	}
+
+	// Quantization is always top-level
+	cfg.Quantization = wrapper.Quantization
+
+	// Compute scale (head_dim may be inferred later from weights if not in config)
+	if cfg.HeadDim > 0 {
+		cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
+	}
+	// Fallback defaults when absent from config.json.
+	if cfg.RopeTheta == 0 {
+		cfg.RopeTheta = 1000000
+	}
+	if cfg.RopeLocalBaseFreq == 0 {
+		cfg.RopeLocalBaseFreq = 10000
+	}
+	if cfg.RMSNormEps == 0 {
+		cfg.RMSNormEps = 1e-6
+	}
+	if cfg.SlidingWindowPattern == 0 {
cfg.SlidingWindowPattern = 6 + } + if cfg.VocabSize == 0 { + cfg.VocabSize = 262208 // Gemma 3 default + } + + return &cfg, nil +} + +// LoadGemma3 loads a Gemma 3 text model from a directory. +func LoadGemma3(modelPath string) (*GemmaModel, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("gemma3: load config: %w", err) + } + + cfg, err := parseConfig(data) + if err != nil { + return nil, fmt.Errorf("gemma3: parse config: %w", err) + } + + // Load tokenizer + tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + if err != nil { + return nil, fmt.Errorf("gemma3: load tokenizer: %w", err) + } + + // Load weights from all safetensors files + weights := make(map[string]*mlx.Array) + matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors")) + for _, path := range matches { + for name, arr := range mlx.LoadSafetensors(path) { + weights[name] = arr + } + } + + // Helper to resolve weight with language_model. prefix fallback + w := func(name string) *mlx.Array { return resolveWeight(weights, name) } + + // Infer head_dim from q_proj weight shape when not in config. + // Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads. 
+ if cfg.HeadDim == 0 { + qWeight := w("model.layers.0.self_attn.q_proj.weight") + if qWeight != nil { + qShape := qWeight.Shape() + if len(qShape) > 0 { + cfg.HeadDim = qShape[0] / cfg.NumAttentionHeads + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + slog.Info("mlx: inferred head_dim from q_proj weight", "head_dim", cfg.HeadDim) + } + } + } + + // Helper to create linear layer (quantized or dense) + q := cfg.Quantization + if q != nil { + slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) + } + linear := func(prefix string) *mlx.Linear { + weight := w(prefix + ".weight") + scales := w(prefix + ".scales") + biases := w(prefix + ".biases") + if scales != nil && q != nil { + return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits) + } + return mlx.NewLinear(weight, nil) + } + + // Create embedding (quantized or dense) + embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} + if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { + embed.Scales = embedScales + embed.Biases = w("model.embed_tokens.biases") + embed.GroupSize = q.GroupSize + embed.Bits = q.Bits + } + + m := &GemmaModel{ + EmbedTokens: embed, + Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), + Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, + Tok: tok, + Cfg: cfg, + } + + // Initialize layers + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + prefix := fmt.Sprintf("model.layers.%d", i) + m.Layers[i] = &DecoderLayer{ + InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")}, + PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")}, + PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")}, + PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")}, + Attention: &Attention{ + QProj: linear(prefix + ".self_attn.q_proj"), + KProj: linear(prefix + 
".self_attn.k_proj"), + VProj: linear(prefix + ".self_attn.v_proj"), + OProj: linear(prefix + ".self_attn.o_proj"), + QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")}, + KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")}, + }, + MLP: &MLP{ + GateProj: linear(prefix + ".mlp.gate_proj"), + UpProj: linear(prefix + ".mlp.up_proj"), + DownProj: linear(prefix + ".mlp.down_proj"), + }, + LayerIdx: i, + IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern), + } + } + + // Output head — check for separate lm_head first, else tie to embeddings + lmHeadWeight := w("lm_head.weight") + if lmHeadWeight != nil { + lmHeadScales := w("lm_head.scales") + if lmHeadScales != nil && q != nil { + m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) + } else { + m.Output = mlx.NewLinear(lmHeadWeight, nil) + } + } else { + // Tied embeddings — reuse embed_tokens weights (with quantization if present) + m.Output = m.EmbedTokens.AsLinear() + } + + // Materialize all weights + var allArrays []*mlx.Array + for _, a := range weights { + allArrays = append(allArrays, a) + } + mlx.Materialize(allArrays...) 
+
+	// Precompute (1 + weight) for Gemma-style RMSNorm
+	precomputeScaledWeights(m)
+
+	return m, nil
+}
+
+// precomputeScaledWeights computes (1 + w) for every Gemma-style RMSNorm
+// weight in the model — the final norm plus six norms per decoder layer —
+// stores the results on the model, and evaluates them all in one batch.
+func precomputeScaledWeights(m *GemmaModel) {
+	// Accumulate every scaled tensor so a single Materialize call
+	// evaluates the whole batch on the GPU.
+	scaled := make([]*mlx.Array, 0, 1+6*len(m.Layers))
+	shift := func(w *mlx.Array) *mlx.Array {
+		s := mlx.AddScalar(w, 1.0)
+		scaled = append(scaled, s)
+		return s
+	}
+
+	m.NormScaled = shift(m.Norm.Weight)
+	for _, layer := range m.Layers {
+		layer.InputNormScaled = shift(layer.InputNorm.Weight)
+		layer.PostAttnNormScaled = shift(layer.PostAttnNorm.Weight)
+		layer.PreFFNormScaled = shift(layer.PreFFNorm.Weight)
+		layer.PostFFNormScaled = shift(layer.PostFFNorm.Weight)
+		layer.Attention.QNormScaled = shift(layer.Attention.QNorm.Weight)
+		layer.Attention.KNormScaled = shift(layer.Attention.KNorm.Weight)
+	}
+
+	mlx.Materialize(scaled...)
+}
+
+// isLayerSliding reports whether layer layerIdx uses sliding-window
+// attention: with a positive pattern, every pattern-th layer (1-based)
+// is global and all others slide; a non-positive pattern disables sliding.
+func isLayerSliding(layerIdx, pattern int32) bool {
+	return pattern > 0 && (layerIdx+1)%pattern != 0
+}
+
+// Forward runs the text model forward pass.
+func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + shape := tokens.Shape() + B, L := shape[0], shape[1] + + h := m.EmbedTokens.Forward(tokens) + h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize)))) + + for i, layer := range m.Layers { + h = layer.forward(h, caches[i], B, L, m.Cfg) + } + + return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)) +} + +func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { + normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) + attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg) + attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) + h := mlx.Add(x, attnOut) + + normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) + mlpOut := l.MLP.forward(normed) + mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) + return mlx.Add(h, mlpOut) +} + +func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + // Reshape to [B, num_heads, L, head_dim] + q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) + k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + + // Q/K normalization + q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps) + k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps) + + // RoPE with appropriate theta + ropeTheta := cfg.RopeTheta + if isSliding { + 
ropeTheta = cfg.RopeLocalBaseFreq + } + q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + + // Update cache + k, v = c.Update(k, v, int(L)) + + // GQA: repeat K/V heads + repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads + if repeatFactor > 1 { + k = mlx.RepeatKV(k, repeatFactor) + v = mlx.RepeatKV(v, repeatFactor) + } + + // Scaled dot-product attention + out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +func (m *MLP) forward(x *mlx.Array) *mlx.Array { + gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0] + return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) +} + +// NewCache creates per-layer caches for generation. +func (m *GemmaModel) NewCache() []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i := range caches { + if m.Layers[i].IsSliding { + caches[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow)) + } else { + caches[i] = cache.NewKVCache() + } + } + return caches +} + +// NumLayers returns the number of transformer layers. +func (m *GemmaModel) NumLayers() int { return len(m.Layers) } + +// Tokenizer returns the model's tokenizer. +func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok } + +// ModelType returns the architecture identifier. +func (m *GemmaModel) ModelType() string { return "gemma3" } + +// ApplyLoRA wraps target projection layers with LoRA adapters. 
+func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { + adapter := &mlx.LoRAAdapter{ + Layers: make(map[string]*mlx.LoRALinear), + Config: cfg, + } + + for i, layer := range m.Layers { + prefix := fmt.Sprintf("model.layers.%d.self_attn", i) + for _, target := range cfg.TargetKeys { + var proj *mlx.Linear + switch target { + case "q_proj": + proj = layer.Attention.QProj + case "k_proj": + proj = layer.Attention.KProj + case "v_proj": + proj = layer.Attention.VProj + case "o_proj": + proj = layer.Attention.OProj + } + if proj != nil { + lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha) + proj.LoRA = lora + adapter.Layers[prefix+"."+target] = lora + } + } + } + + return adapter +} diff --git a/model/model.go b/model/model.go new file mode 100644 index 0000000..408d8d1 --- /dev/null +++ b/model/model.go @@ -0,0 +1,78 @@ +//go:build darwin && arm64 + +// Package model provides transformer model architectures for MLX inference. +package model + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "forge.lthn.ai/core/go-mlx" + "forge.lthn.ai/core/go-mlx/cache" + "forge.lthn.ai/core/go-mlx/tokenizer" +) + +// Model is the common interface for all transformer model architectures. +type Model interface { + // Forward runs the model forward pass on token IDs with KV caches. + Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array + + // NewCache creates per-layer KV caches for generation. + NewCache() []cache.Cache + + // NumLayers returns the number of transformer layers. + NumLayers() int + + // Tokenizer returns the model's tokenizer. + Tokenizer() *tokenizer.Tokenizer + + // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3"). + ModelType() string + + // ApplyLoRA wraps target projection layers with LoRA adapters for training. + // Returns the adapter which holds references to all LoRA layers. + ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter +} + +// QuantizationConfig holds quantization parameters from config.json. 
+type QuantizationConfig struct { + GroupSize int `json:"group_size"` + Bits int `json:"bits"` +} + +// resolveWeight looks up a weight with optional "language_model." prefix. +func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array { + if w, ok := weights[name]; ok { + return w + } + if w, ok := weights["language_model."+name]; ok { + return w + } + return nil +} + +// LoadModel auto-detects the model architecture from config.json and loads it. +func LoadModel(modelPath string) (Model, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("model: load config: %w", err) + } + + var probe struct { + ModelType string `json:"model_type"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return nil, fmt.Errorf("model: parse model_type: %w", err) + } + + switch probe.ModelType { + case "qwen3": + return LoadQwen3(modelPath) + case "gemma3", "gemma2": + return LoadGemma3(modelPath) + default: + return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType) + } +} diff --git a/model/qwen3.go b/model/qwen3.go new file mode 100644 index 0000000..0a7ca13 --- /dev/null +++ b/model/qwen3.go @@ -0,0 +1,337 @@ +//go:build darwin && arm64 + +package model + +import ( + "encoding/json" + "fmt" + "log/slog" + "math" + "os" + "path/filepath" + + "forge.lthn.ai/core/go-mlx" + "forge.lthn.ai/core/go-mlx/cache" + "forge.lthn.ai/core/go-mlx/tokenizer" +) + +// Qwen3Config holds Qwen 3 model configuration. 
+type Qwen3Config struct { + HiddenSize int32 `json:"hidden_size"` + NumHiddenLayers int32 `json:"num_hidden_layers"` + IntermediateSize int32 `json:"intermediate_size"` + NumAttentionHeads int32 `json:"num_attention_heads"` + NumKeyValueHeads int32 `json:"num_key_value_heads"` + HeadDim int32 `json:"head_dim"` + VocabSize int32 `json:"vocab_size"` + RMSNormEps float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + MaxPositionEmbeddings int32 `json:"max_position_embeddings"` + + Quantization *QuantizationConfig `json:"-"` + Scale float32 `json:"-"` // 1/sqrt(head_dim) +} + +// Qwen3Model is the Qwen 3 text model. +type Qwen3Model struct { + EmbedTokens *mlx.Embedding + Layers []*Qwen3DecoderLayer + Norm *mlx.RMSNormModule + Output *mlx.Linear + + Tok *tokenizer.Tokenizer + Cfg *Qwen3Config +} + +// Qwen3DecoderLayer is a single transformer block. +// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add. +type Qwen3DecoderLayer struct { + InputNorm *mlx.RMSNormModule // Pre-attention norm + PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm) + Attention *Qwen3Attention + MLP *Qwen3MLP +} + +// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization. +type Qwen3Attention struct { + QProj *mlx.Linear + KProj *mlx.Linear + VProj *mlx.Linear + OProj *mlx.Linear + QNorm *mlx.RMSNormModule + KNorm *mlx.RMSNormModule +} + +// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)). 
+type Qwen3MLP struct { + GateProj *mlx.Linear + UpProj *mlx.Linear + DownProj *mlx.Linear +} + +func parseQwen3Config(data []byte) (*Qwen3Config, error) { + var cfg Qwen3Config + if err := json.Unmarshal(data, &cfg); err != nil { + return nil, err + } + + // Top-level quantization + var wrapper struct { + Quantization *QuantizationConfig `json:"quantization"` + } + json.Unmarshal(data, &wrapper) + cfg.Quantization = wrapper.Quantization + + // Compute scale + if cfg.HeadDim == 0 { + cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads + } + cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) + + // Defaults + if cfg.RopeTheta == 0 { + cfg.RopeTheta = 1000000 + } + if cfg.RMSNormEps == 0 { + cfg.RMSNormEps = 1e-6 + } + if cfg.VocabSize == 0 { + cfg.VocabSize = 151936 + } + + return &cfg, nil +} + +// LoadQwen3 loads a Qwen 3 model from a safetensors directory. +func LoadQwen3(modelPath string) (*Qwen3Model, error) { + data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) + if err != nil { + return nil, fmt.Errorf("qwen3: load config: %w", err) + } + + cfg, err := parseQwen3Config(data) + if err != nil { + return nil, fmt.Errorf("qwen3: parse config: %w", err) + } + + tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) + if err != nil { + return nil, fmt.Errorf("qwen3: load tokenizer: %w", err) + } + + // Load weights from all safetensors files + weights := make(map[string]*mlx.Array) + matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors")) + for _, path := range matches { + for name, arr := range mlx.LoadSafetensors(path) { + weights[name] = arr + } + } + + w := func(name string) *mlx.Array { return resolveWeight(weights, name) } + + // Quantization setup + q := cfg.Quantization + if q != nil { + slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) + } + linear := func(prefix string) *mlx.Linear { + weight := w(prefix + ".weight") + scales := w(prefix + ".scales") + biases := 
w(prefix + ".biases") + bias := w(prefix + ".bias") + if scales != nil && q != nil { + return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits) + } + return mlx.NewLinear(weight, bias) + } + + // Embedding + embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} + if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { + embed.Scales = embedScales + embed.Biases = w("model.embed_tokens.biases") + embed.GroupSize = q.GroupSize + embed.Bits = q.Bits + } + + m := &Qwen3Model{ + EmbedTokens: embed, + Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers), + Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, + Tok: tok, + Cfg: cfg, + } + + for i := int32(0); i < cfg.NumHiddenLayers; i++ { + p := fmt.Sprintf("model.layers.%d", i) + m.Layers[i] = &Qwen3DecoderLayer{ + InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")}, + PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")}, + Attention: &Qwen3Attention{ + QProj: linear(p + ".self_attn.q_proj"), + KProj: linear(p + ".self_attn.k_proj"), + VProj: linear(p + ".self_attn.v_proj"), + OProj: linear(p + ".self_attn.o_proj"), + QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")}, + KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")}, + }, + MLP: &Qwen3MLP{ + GateProj: linear(p + ".mlp.gate_proj"), + UpProj: linear(p + ".mlp.up_proj"), + DownProj: linear(p + ".mlp.down_proj"), + }, + } + } + + // Output head — Qwen 3 has tie_word_embeddings=false, so lm_head is separate + lmHeadWeight := w("lm_head.weight") + if lmHeadWeight != nil { + lmHeadScales := w("lm_head.scales") + if lmHeadScales != nil && q != nil { + m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) + } else { + m.Output = mlx.NewLinear(lmHeadWeight, nil) + } + } else { + m.Output = m.EmbedTokens.AsLinear() + } + + // Materialise all weights onto Metal + var 
allArrays []*mlx.Array + for _, a := range weights { + allArrays = append(allArrays, a) + } + mlx.Materialize(allArrays...) + + slog.Info("qwen3: model loaded", + "layers", cfg.NumHiddenLayers, + "hidden", cfg.HiddenSize, + "heads", cfg.NumAttentionHeads, + "kv_heads", cfg.NumKeyValueHeads, + "head_dim", cfg.HeadDim, + "vocab", cfg.VocabSize, + ) + + return m, nil +} + +// Forward runs the Qwen 3 forward pass. +// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size). +func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { + shape := tokens.Shape() + B, L := shape[0], shape[1] + + h := m.EmbedTokens.Forward(tokens) + + for i, layer := range m.Layers { + h = layer.forward(h, caches[i], B, L, m.Cfg) + } + + return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps)) +} + +func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { + // Pre-attention norm → attention → residual add + normed := l.InputNorm.Forward(x, cfg.RMSNormEps) + attnOut := l.Attention.forward(normed, c, B, L, cfg) + h := mlx.Add(x, attnOut) + + // Pre-MLP norm → MLP → residual add + normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps) + mlpOut := l.MLP.forward(normed) + return mlx.Add(h, mlpOut) +} + +func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { + q := a.QProj.Forward(x) + k := a.KProj.Forward(x) + v := a.VProj.Forward(x) + + // Reshape to [B, num_heads, L, head_dim] + q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) + k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + []int64{int64(L 
* cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + + // Q/K RMS normalization (Qwen 3 has this) + q = a.QNorm.Forward(q, cfg.RMSNormEps) + k = a.KNorm.Forward(k, cfg.RMSNormEps) + + // RoPE — single theta for all layers (no sliding window) + q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + + // Update KV cache + k, v = c.Update(k, v, int(L)) + + // GQA: repeat K/V heads to match Q heads + repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads + if repeatFactor > 1 { + k = mlx.RepeatKV(k, repeatFactor) + v = mlx.RepeatKV(v, repeatFactor) + } + + // Scaled dot-product attention + out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) + return a.OProj.Forward(out) +} + +// forward computes SwiGLU: down(silu(gate(x)) * up(x)). +func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array { + gate := mlx.SiLU(m.GateProj.Forward(x)) + return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) +} + +// NewCache creates per-layer KV caches. Qwen 3 uses global attention only. +func (m *Qwen3Model) NewCache() []cache.Cache { + caches := make([]cache.Cache, len(m.Layers)) + for i := range caches { + caches[i] = cache.NewKVCache() + } + return caches +} + +// NumLayers returns the number of transformer layers. +func (m *Qwen3Model) NumLayers() int { return len(m.Layers) } + +// Tokenizer returns the model's tokenizer. +func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok } + +// ModelType returns the architecture identifier. +func (m *Qwen3Model) ModelType() string { return "qwen3" } + +// ApplyLoRA wraps target projection layers with LoRA adapters. 
+func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { + adapter := &mlx.LoRAAdapter{ + Layers: make(map[string]*mlx.LoRALinear), + Config: cfg, + } + + for i, layer := range m.Layers { + prefix := fmt.Sprintf("model.layers.%d.self_attn", i) + for _, target := range cfg.TargetKeys { + var proj *mlx.Linear + switch target { + case "q_proj": + proj = layer.Attention.QProj + case "k_proj": + proj = layer.Attention.KProj + case "v_proj": + proj = layer.Attention.VProj + case "o_proj": + proj = layer.Attention.OProj + } + if proj != nil { + lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha) + proj.LoRA = lora + adapter.Layers[prefix+"."+target] = lora + } + } + } + + return adapter +} diff --git a/nn.go b/nn.go new file mode 100644 index 0000000..4de9db9 --- /dev/null +++ b/nn.go @@ -0,0 +1,115 @@ +//go:build darwin && arm64 + +package mlx + +// Linear is a fully-connected layer: y = x @ W.T + bias. +// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul. +// Set LoRA to inject a low-rank adapter (training only). +type Linear struct { + Weight *Array `weight:"weight"` + Scales *Array `weight:"scales"` + Biases *Array `weight:"biases"` + Bias *Array `weight:"bias"` + GroupSize int + Bits int + + LoRA *LoRALinear // Optional LoRA adapter — if set, Forward routes through it +} + +// NewLinear creates a dense Linear layer with optional bias. +func NewLinear(weight, bias *Array) *Linear { + return &Linear{Weight: weight, Bias: bias} +} + +// NewQuantizedLinear creates a quantized Linear layer. +func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear { + return &Linear{ + Weight: weight, + Scales: scales, + Biases: biases, + Bias: bias, + GroupSize: groupSize, + Bits: bits, + } +} + +// Forward computes the linear transformation. +// If a LoRA adapter is attached, routes through it instead (base + low-rank delta). +// Uses QuantizedMatmul when quantization parameters are present. 
+func (l *Linear) Forward(x *Array) *Array { + if l.LoRA != nil { + return l.LoRA.Forward(x) + } + return l.baseForward(x) +} + +// baseForward is the raw linear transformation without LoRA. +// Used internally by LoRALinear to avoid infinite recursion. +func (l *Linear) baseForward(x *Array) *Array { + var out *Array + if l.Scales != nil { + out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits) + } else { + out = Matmul(x, Transpose(l.Weight)) + } + if l.Bias != nil && l.Bias.Valid() { + out = Add(out, l.Bias) + } + return out +} + +// Embedding is a lookup table for token embeddings. +// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup. +type Embedding struct { + Weight *Array `weight:"weight"` + Scales *Array `weight:"scales"` + Biases *Array `weight:"biases"` + GroupSize int + Bits int +} + +// Forward looks up embeddings for the given token indices. +func (e *Embedding) Forward(indices *Array) *Array { + if e.Scales != nil { + w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits) + return Take(w, indices, 0) + } + return Take(e.Weight, indices, 0) +} + +// AsLinear returns a Linear layer using the embedding weights (for tied output). +func (e *Embedding) AsLinear() *Linear { + return &Linear{ + Weight: e.Weight, + Scales: e.Scales, + Biases: e.Biases, + GroupSize: e.GroupSize, + Bits: e.Bits, + } +} + +// RMSNormModule is an RMS normalization layer wrapping the fused kernel. +type RMSNormModule struct { + Weight *Array `weight:"weight"` +} + +// Forward applies RMS normalization. +func (r *RMSNormModule) Forward(x *Array, eps float32) *Array { + return RMSNorm(x, r.Weight, eps) +} + +// RepeatKV repeats key/value heads for grouped-query attention. 
+// Input shape: [B, num_kv_heads, L, D] +// Output shape: [B, num_kv_heads * factor, L, D] +func RepeatKV(x *Array, factor int32) *Array { + if factor <= 1 { + return x + } + shape := x.Shape() + B, H, L, D := shape[0], shape[1], shape[2], shape[3] + + // Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D] + expanded := ExpandDims(x, 2) + expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D}) + return Reshape(expanded, B, H*factor, L, D) +} diff --git a/ops.go b/ops.go new file mode 100644 index 0000000..99f953b --- /dev/null +++ b/ops.go @@ -0,0 +1,353 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include +#include "mlx/c/mlx.h" +*/ +import "C" + +import "unsafe" + +// --- Element-wise arithmetic --- + +// Add returns element-wise a + b. +func Add(a, b *Array) *Array { + out := New("ADD", a, b) + C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// AddScalar returns a + scalar (broadcast). +func AddScalar(a *Array, s float32) *Array { + scalar := FromValue(s) + return Add(a, scalar) +} + +// Mul returns element-wise a * b. +func Mul(a, b *Array) *Array { + out := New("MUL", a, b) + C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// MulScalar returns a * scalar (broadcast). +func MulScalar(a *Array, s float32) *Array { + scalar := FromValue(s) + return Mul(a, scalar) +} + +// Divide returns element-wise a / b. +func Divide(a, b *Array) *Array { + out := New("DIV", a, b) + C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Subtract returns element-wise a - b. +func Subtract(a, b *Array) *Array { + out := New("SUB", a, b) + C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Negative returns element-wise -a. +func Negative(a *Array) *Array { + out := New("NEG", a) + C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// --- Math functions --- + +// Exp returns element-wise exp(a). 
+func Exp(a *Array) *Array { + out := New("EXP", a) + C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// Sigmoid returns element-wise 1/(1+exp(-a)). +func Sigmoid(a *Array) *Array { + out := New("SIGMOID", a) + C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// SiLU returns element-wise x * sigmoid(x) (Swish activation). +func SiLU(a *Array) *Array { + return Mul(a, Sigmoid(a)) +} + +// Tanh returns element-wise tanh(a). +func Tanh(a *Array) *Array { + out := New("TANH", a) + C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// Sqrt returns element-wise sqrt(a). +func Sqrt(a *Array) *Array { + out := New("SQRT", a) + C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// Rsqrt returns element-wise 1/sqrt(a). +func Rsqrt(a *Array) *Array { + out := New("RSQRT", a) + C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// Reciprocal returns element-wise 1/a. +func Reciprocal(a *Array) *Array { + out := New("RECIPROCAL", a) + C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// Square returns element-wise a^2. +func Square(a *Array) *Array { + out := New("SQUARE", a) + C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + +// Power returns element-wise a^b. +func Power(a, b *Array) *Array { + out := New("POWER", a, b) + C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Maximum returns element-wise max(a, b). +func Maximum(a, b *Array) *Array { + out := New("MAX", a, b) + C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Minimum returns element-wise min(a, b). +func Minimum(a, b *Array) *Array { + out := New("MIN", a, b) + C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// --- Matrix operations --- + +// Matmul returns the matrix product of a and b. 
+func Matmul(a, b *Array) *Array { + out := New("MATMUL", a, b) + C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// QuantizedMatmul performs quantized matrix multiplication. +func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array { + out := New("QMATMUL", x, w, scales, biases) + gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)} + b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)} + mode := C.CString("affine") + defer C.free(unsafe.Pointer(mode)) + C.mlx_quantized_matmul( + &out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx, + C._Bool(transpose), gs, b, mode, + DefaultStream().ctx, + ) + return out +} + +// --- Reductions --- + +// Softmax returns softmax along the last axis. +func Softmax(a *Array) *Array { + out := New("SOFTMAX", a) + axis := []C.int{C.int(-1)} + C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx) + return out +} + +// Argmax returns the index of the maximum value along an axis. +func Argmax(a *Array, axis int, keepDims bool) *Array { + out := New("ARGMAX", a) + C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) + return out +} + +// TopK returns the top k values along the last axis. +func TopK(a *Array, k int) *Array { + out := New("TOPK", a) + C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx) + return out +} + +// Sum reduces by summation along the given axis. +func Sum(a *Array, axis int, keepDims bool) *Array { + out := New("SUM", a) + axes := []C.int{C.int(axis)} + C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx) + return out +} + +// Mean reduces by averaging along the given axis. 
+func Mean(a *Array, axis int, keepDims bool) *Array { + out := New("MEAN", a) + axes := []C.int{C.int(axis)} + C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx) + return out +} + +// --- Shape operations --- + +// Reshape changes the shape of an array. +func Reshape(a *Array, shape ...int32) *Array { + out := New("RESHAPE", a) + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx) + return out +} + +// Transpose permutes dimensions. If no axes given, reverses all dims. +func Transpose(a *Array, axes ...int) *Array { + out := New("TRANSPOSE", a) + if len(axes) == 0 { + C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx) + } else { + cAxes := make([]C.int, len(axes)) + for i, ax := range axes { + cAxes[i] = C.int(ax) + } + C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx) + } + return out +} + +// ExpandDims inserts a new axis at the given position. +func ExpandDims(a *Array, axis int) *Array { + out := New("EXPAND_DIMS", a) + C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +// Squeeze removes dimensions of size 1. +func Squeeze(a *Array, axes ...int) *Array { + out := New("SQUEEZE", a) + cAxes := make([]C.int, len(axes)) + for i, ax := range axes { + cAxes[i] = C.int(ax) + } + C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx) + return out +} + +// Concatenate joins arrays along the given axis. +func Concatenate(arrays []*Array, axis int) *Array { + vector := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vector) + + inputs := make([]*Array, len(arrays)) + for i, a := range arrays { + C.mlx_vector_array_append_value(vector, a.ctx) + inputs[i] = a + } + + out := New("CONCAT", inputs...) 
+ C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) + return out +} + +// BroadcastTo broadcasts an array to the given shape. +func BroadcastTo(a *Array, shape []int32) *Array { + out := New("BROADCAST", a) + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx) + return out +} + +// AsType casts an array to a different dtype. +func AsType(a *Array, dtype DType) *Array { + out := New("ASTYPE", a) + C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx) + return out +} + +// AsStrided creates a view with custom strides. +func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array { + out := New("AS_STRIDED", a) + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + cStrides := make([]C.int64_t, len(strides)) + for i, s := range strides { + cStrides[i] = C.int64_t(s) + } + C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx) + return out +} + +// Take gathers elements from a along axis using indices. +func Take(a, indices *Array, axis int) *Array { + out := New("TAKE", a, indices) + C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +// Where selects elements from a or b based on condition. +func Where(condition, a, b *Array) *Array { + out := New("WHERE", condition, a, b) + C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx) + return out +} + +// Argpartition partially sorts and returns indices for top-k selection. +func Argpartition(a *Array, kth, axis int) *Array { + out := New("ARGPARTITION", a) + C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx) + return out +} + +// Dequantize restores a quantized array to full precision. 
+func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array { + out := New("DEQUANTIZE", w, scales, biases) + gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)} + b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)} + mode := C.CString("affine") + defer C.free(unsafe.Pointer(mode)) + noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)} + C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx) + return out +} + +// PutAlongAxis places values into array at indices along axis. +func PutAlongAxis(a, indices, values *Array, axis int) *Array { + out := New("PUT_ALONG_AXIS", a, indices, values) + // Use scatter approach: src[indices] = values + C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +// TakeAlongAxis gathers elements from a along axis using indices. +// Unlike Take, this uses the same number of dimensions for indices and input. +func TakeAlongAxis(a, indices *Array, axis int) *Array { + out := New("TAKE_ALONG_AXIS", a, indices) + C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) + return out +} + +// LogSumExp computes log(sum(exp(a))) along the given axis. +// Numerically stable reduction for cross-entropy loss. +func LogSumExp(a *Array, axis int, keepDims bool) *Array { + out := New("LOGSUMEXP", a) + C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) + return out +} diff --git a/optim.go b/optim.go new file mode 100644 index 0000000..f54bcf5 --- /dev/null +++ b/optim.go @@ -0,0 +1,106 @@ +//go:build darwin && arm64 + +package mlx + +import "math" + +// AdamW implements the AdamW optimiser (Adam with decoupled weight decay). 
+// +// Update rule per parameter: +// +// m = beta1 * m + (1 - beta1) * grad +// v = beta2 * v + (1 - beta2) * grad^2 +// m_hat = m / (1 - beta1^t) +// v_hat = v / (1 - beta2^t) +// param = param * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps) +type AdamW struct { + LR float64 // Learning rate (default 1e-5) + Beta1 float64 // First moment decay (default 0.9) + Beta2 float64 // Second moment decay (default 0.999) + Eps float64 // Numerical stability (default 1e-8) + WeightDecay float64 // Decoupled weight decay (default 0.01) + + step int // Number of updates performed + m []*Array // First moment estimates (positional, parallel to params) + v []*Array // Second moment estimates (positional, parallel to params) +} + +// NewAdamW creates an AdamW optimiser with default hyperparameters. +func NewAdamW(lr float64) *AdamW { + return &AdamW{ + LR: lr, + Beta1: 0.9, + Beta2: 0.999, + Eps: 1e-8, + WeightDecay: 0.01, + } +} + +// Step performs one optimisation step: updates params using gradients. +// params and grads must be parallel slices of the same length. +// Returns the updated parameter arrays (params are replaced in-place). 
+func (o *AdamW) Step(params []*Array, grads []*Array) []*Array { + o.step++ + + // Bias correction factors + bc1 := 1.0 - math.Pow(o.Beta1, float64(o.step)) + bc2 := 1.0 - math.Pow(o.Beta2, float64(o.step)) + + updated := make([]*Array, len(params)) + + // Grow moment slices if needed (first call or param count increased) + for len(o.m) < len(params) { + o.m = append(o.m, nil) + o.v = append(o.v, nil) + } + + for i, param := range params { + grad := grads[i] + + // Initialise moments on first use + if o.m[i] == nil { + shape := param.Shape() + o.m[i] = Zeros(shape, param.Dtype()) + o.v[i] = Zeros(shape, param.Dtype()) + } + + // m = beta1 * m + (1 - beta1) * grad + m := Add( + MulScalar(o.m[i], float32(o.Beta1)), + MulScalar(grad, float32(1.0-o.Beta1)), + ) + + // v = beta2 * v + (1 - beta2) * grad^2 + v := Add( + MulScalar(o.v[i], float32(o.Beta2)), + MulScalar(Square(grad), float32(1.0-o.Beta2)), + ) + + // Bias-corrected estimates + mHat := MulScalar(m, float32(1.0/bc1)) + vHat := MulScalar(v, float32(1.0/bc2)) + + // Weight decay: param = param * (1 - lr * weight_decay) + decayed := MulScalar(param, float32(1.0-o.LR*o.WeightDecay)) + + // Update: param = decayed - lr * m_hat / (sqrt(v_hat) + eps) + denom := AddScalar(Sqrt(vHat), float32(o.Eps)) + step := MulScalar(Divide(mHat, denom), float32(o.LR)) + newParam := Subtract(decayed, step) + + // Store updated moments + o.m[i] = m + o.v[i] = v + + updated[i] = newParam + } + + return updated +} + +// Reset clears the optimiser state (moments and step counter). 
+func (o *AdamW) Reset() { + o.step = 0 + o.m = nil + o.v = nil +} diff --git a/optim_test.go b/optim_test.go new file mode 100644 index 0000000..b2df911 --- /dev/null +++ b/optim_test.go @@ -0,0 +1,175 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "testing" +) + +func TestAdamW_BasicStep(t *testing.T) { + // Simple test: minimise f(x) = x^2, starting at x=10 + x := FromValue(float32(10.0)) + Materialize(x) + + opt := NewAdamW(0.1) + + for i := 0; i < 300; i++ { + // Gradient of x^2 is 2x + lossFn := func(inputs []*Array) []*Array { + p := inputs[0] + return []*Array{Mul(p, p)} + } + + grad := ValueAndGrad(lossFn) + _, grads, err := grad.Apply(x) + grad.Free() + if err != nil { + t.Fatalf("step %d: grad failed: %v", i, err) + } + + updated := opt.Step([]*Array{x}, grads) + x = updated[0] + Materialize(x) + } + + final := x.Float() + if math.Abs(final) > 0.5 { + t.Errorf("after 300 steps, x = %f, want near 0", final) + } + t.Logf("final x = %f (started at 10.0)", final) +} + +func TestAdamW_MultiParam(t *testing.T) { + // Minimise f(x, y) = x^2 + y^2 + x := FromValue(float32(5.0)) + y := FromValue(float32(-3.0)) + Materialize(x, y) + + opt := NewAdamW(0.1) + + for i := 0; i < 100; i++ { + lossFn := func(inputs []*Array) []*Array { + return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))} + } + + grad := ValueAndGrad(lossFn, 0, 1) + _, grads, err := grad.Apply(x, y) + grad.Free() + if err != nil { + t.Fatalf("step %d failed: %v", i, err) + } + + updated := opt.Step([]*Array{x, y}, grads) + x = updated[0] + y = updated[1] + Materialize(x, y) + } + + xFinal := x.Float() + yFinal := y.Float() + if math.Abs(xFinal) > 0.1 || math.Abs(yFinal) > 0.1 { + t.Errorf("x=%f, y=%f, want both near 0", xFinal, yFinal) + } + t.Logf("final x=%f, y=%f", xFinal, yFinal) +} + +func TestAdamW_WeightDecay(t *testing.T) { + // With large weight decay and zero gradient, param should decay toward 0 + x := FromValue(float32(10.0)) + Materialize(x) + + 
opt := NewAdamW(0.01) + opt.WeightDecay = 0.5 // aggressive decay + + zeroGrad := FromValue(float32(0.0)) + Materialize(zeroGrad) + + for i := 0; i < 10; i++ { + updated := opt.Step([]*Array{x}, []*Array{zeroGrad}) + x = updated[0] + Materialize(x) + } + + final := x.Float() + if final >= 10.0 { + t.Errorf("x = %f, should have decayed from 10.0", final) + } + if final <= 0 { + t.Errorf("x = %f, decayed too much", final) + } + t.Logf("after 10 steps with weight_decay=0.5: x = %f (started at 10.0)", final) +} + +func TestAdamW_Reset(t *testing.T) { + opt := NewAdamW(0.01) + + x := FromValue(float32(5.0)) + grad := FromValue(float32(1.0)) + Materialize(x, grad) + + opt.Step([]*Array{x}, []*Array{grad}) + if opt.step != 1 { + t.Errorf("step = %d, want 1", opt.step) + } + + opt.Reset() + if opt.step != 0 { + t.Errorf("after reset, step = %d, want 0", opt.step) + } + if opt.m != nil { + t.Error("after reset, moments should be nil") + } +} + +func TestAdamW_WithLoRA(t *testing.T) { + // End-to-end: create LoRA layer, compute gradients, update with AdamW + w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32) + Materialize(w) + base := NewLinear(w, nil) + + lora := NewLoRALinear(base, 4, 8.0) + opt := NewAdamW(0.001) + + x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32) + target := RandomNormal(0, 1, []int32{1, 2, 4}, DTypeFloat32) + Materialize(x, target) + + var initialLoss, finalLoss float64 + + for step := 0; step < 50; step++ { + lossFn := func(inputs []*Array) []*Array { + lora.A = inputs[0] + lora.B = inputs[1] + pred := lora.Forward(x) + return []*Array{MSELoss(pred, target)} + } + + grad := ValueAndGrad(lossFn, 0, 1) + values, grads, err := grad.Apply(lora.A, lora.B) + grad.Free() + if err != nil { + t.Fatalf("step %d failed: %v", step, err) + } + + Materialize(append(values, grads...)...) 
+ + loss := values[0].Float() + if step == 0 { + initialLoss = loss + } + if step == 49 { + finalLoss = loss + } + + updated := opt.Step([]*Array{lora.A, lora.B}, grads) + lora.A = updated[0] + lora.B = updated[1] + Materialize(lora.A, lora.B) + } + + t.Logf("loss: %.6f -> %.6f", initialLoss, finalLoss) + if finalLoss >= initialLoss { + t.Errorf("loss did not decrease: %f -> %f", initialLoss, finalLoss) + } +} diff --git a/random.go b/random.go new file mode 100644 index 0000000..f7e09b2 --- /dev/null +++ b/random.go @@ -0,0 +1,46 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include "mlx/c/mlx.h" +*/ +import "C" + +// RandomCategorical samples from a categorical distribution defined by logprobs. +// Returns indices sampled according to the log-probability distribution along the last axis. +func RandomCategorical(logprobs *Array) *Array { + out := New("RANDOM_CATEGORICAL", logprobs) + key := C.mlx_array_new() + defer C.mlx_array_free(key) + C.mlx_random_categorical( + &out.ctx, + logprobs.ctx, + C.int(-1), // axis + key, // null key = use default RNG + DefaultStream().ctx, + ) + return out +} + +// RandomUniform generates uniform random values in [low, high). +func RandomUniform(low, high float32, shape []int32, dtype DType) *Array { + out := New("RANDOM_UNIFORM") + cShape := make([]C.int, len(shape)) + for i, s := range shape { + cShape[i] = C.int(s) + } + lo := FromValue(low) + hi := FromValue(high) + key := C.mlx_array_new() + defer C.mlx_array_free(key) + C.mlx_random_uniform( + &out.ctx, + lo.ctx, hi.ctx, + &cShape[0], C.size_t(len(cShape)), + C.mlx_dtype(dtype), + key, + DefaultStream().ctx, + ) + return out +} diff --git a/sample/sample.go b/sample/sample.go new file mode 100644 index 0000000..4a6b599 --- /dev/null +++ b/sample/sample.go @@ -0,0 +1,90 @@ +//go:build darwin && arm64 + +// Package sample provides composable token sampling strategies. 
+package sample + +import ( + "math" + + "forge.lthn.ai/core/go-mlx" +) + +// Sampler transforms logits into a sampled token index. +type Sampler interface { + Sample(logits *mlx.Array) *mlx.Array +} + +// New creates a composable sampler chain from the given parameters. +// Order: TopP -> MinP -> TopK -> Temperature -> categorical sample. +func New(temp, topP, minP float32, topK int) Sampler { + if temp == 0 { + return greedy{} + } + + var samplers []Sampler + if topP > 0 && topP < 1 { + samplers = append(samplers, TopP(topP)) + } + if minP > 0 { + samplers = append(samplers, MinPSampler(minP)) + } + if topK > 0 { + samplers = append(samplers, TopKSampler(topK)) + } + samplers = append(samplers, Temperature(temp)) + return chain(samplers) +} + +// chain applies a sequence of samplers, then samples from the result. +type chain []Sampler + +func (c chain) Sample(logits *mlx.Array) *mlx.Array { + for _, s := range c { + logits = s.Sample(logits) + } + // Final categorical sample from log-probabilities + return mlx.RandomCategorical(logits) +} + +// greedy returns the argmax token. +type greedy struct{} + +func (greedy) Sample(logits *mlx.Array) *mlx.Array { + return mlx.Argmax(logits, -1, false) +} + +// Temperature scales logits by 1/temp. +type Temperature float32 + +func (t Temperature) Sample(logits *mlx.Array) *mlx.Array { + return mlx.MulScalar(logits, 1.0/float32(t)) +} + +// TopKSampler masks all but the top-k logits. +type TopKSampler int + +func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array { + neg := mlx.Negative(logits) + mask := mlx.Argpartition(neg, int(k)-1, -1) + // Slice the indices beyond top-k + mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1))) + return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1) +} + +// TopP implements nucleus sampling (cumulative probability threshold). 
+type TopP float32 + +func (p TopP) Sample(logits *mlx.Array) *mlx.Array { + // TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly. + // For now, pass through. TopK + Temperature covers most use cases. + return logits +} + +// MinPSampler masks tokens below min_p * max_prob. +type MinPSampler float32 + +func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array { + // For now, pass through — MinP is an optimization over TopP. + // Full implementation requires finding max prob and masking below threshold. + return logits +} diff --git a/slice.go b/slice.go new file mode 100644 index 0000000..5bb7a66 --- /dev/null +++ b/slice.go @@ -0,0 +1,63 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include "mlx/c/mlx.h" +*/ +import "C" + +// Slice extracts a sub-array using start and end indices for each dimension. +// starts and ends must have the same length as the array's dimensions. +func Slice(a *Array, starts, ends []int32) *Array { + out := New("SLICE", a) + cStarts := make([]C.int, len(starts)) + cEnds := make([]C.int, len(ends)) + for i := range starts { + cStarts[i] = C.int(starts[i]) + cEnds[i] = C.int(ends[i]) + } + strides := make([]C.int, len(starts)) + for i := range strides { + strides[i] = 1 + } + C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx) + return out +} + +// SliceAxis extracts a sub-array along a single axis. +func SliceAxis(a *Array, axis int, start, end int32) *Array { + // Build full slice parameters + ndim := a.NumDims() + starts := make([]int32, ndim) + ends := make([]int32, ndim) + for i := 0; i < ndim; i++ { + starts[i] = 0 + ends[i] = int32(a.Dim(i)) + } + ax := axis + if ax < 0 { + ax = ndim + ax + } + starts[ax] = start + ends[ax] = end + return Slice(a, starts, ends) +} + +// SliceUpdateInplace updates a slice of the array in-place. +// This is critical for KV cache updates. 
+func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array { + out := New("SLICE_UPDATE", a, update) + cStarts := make([]C.int, len(starts)) + cEnds := make([]C.int, len(ends)) + for i := range starts { + cStarts[i] = C.int(starts[i]) + cEnds[i] = C.int(ends[i]) + } + strides := make([]C.int, len(starts)) + for i := range strides { + strides[i] = 1 + } + C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx) + return out +} diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..248a224 --- /dev/null +++ b/stream.go @@ -0,0 +1,79 @@ +//go:build darwin && arm64 + +package mlx + +/* +#include "mlx/c/mlx.h" +*/ +import "C" + +import "sync" + +// Stream wraps an mlx_stream handle for dispatching operations. +type Stream struct { + ctx C.mlx_stream +} + +var ( + defaultStream *Stream + defaultStreamOnce sync.Once +) + +// DefaultStream returns the default GPU stream, creating it on first use. +func DefaultStream() *Stream { + defaultStreamOnce.Do(func() { + Init() + defaultStream = &Stream{ctx: C.mlx_default_gpu_stream_new()} + }) + return defaultStream +} + +// DefaultGPUStream returns a new GPU stream. +func DefaultGPUStream() *Stream { + Init() + return &Stream{ctx: C.mlx_default_gpu_stream_new()} +} + +// DefaultCPUStream returns a new CPU stream. +func DefaultCPUStream() *Stream { + Init() + return &Stream{ctx: C.mlx_default_cpu_stream_new()} +} + +// Synchronize waits for all operations on the stream to complete. +func Synchronize(s *Stream) { + C.mlx_synchronize(s.ctx) +} + +// SetMemoryLimit sets the Metal memory limit. Returns the previous limit. +func SetMemoryLimit(limit uint64) uint64 { + var prev C.size_t + C.mlx_set_memory_limit(&prev, C.size_t(limit)) + return uint64(prev) +} + +// SetCacheLimit sets the Metal cache limit. Returns the previous limit. 
+func SetCacheLimit(limit uint64) uint64 {
+	var prev C.size_t
+	C.mlx_set_cache_limit(&prev, C.size_t(limit))
+	return uint64(prev)
+}
+
+// GetActiveMemory returns the current Metal memory usage in bytes.
+func GetActiveMemory() uint64 {
+	var mem C.size_t
+	C.mlx_get_active_memory(&mem)
+	return uint64(mem)
+}
+
+// GetPeakMemory returns the peak Metal memory usage in bytes.
+func GetPeakMemory() uint64 {
+	var mem C.size_t
+	C.mlx_get_peak_memory(&mem)
+	return uint64(mem)
+}
+
+// ClearCache releases Metal memory held in the MLX allocator cache.
+func ClearCache() {
+	C.mlx_clear_cache()
+}
diff --git a/tokenizer/tokenizer.go b/tokenizer/tokenizer.go
new file mode 100644
index 0000000..947a1a5
--- /dev/null
+++ b/tokenizer/tokenizer.go
@@ -0,0 +1,324 @@
+//go:build darwin && arm64
+
+// Package tokenizer provides BPE tokenization for transformer models.
+package tokenizer
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"strings"
+)
+
+// Tokenizer handles text-to-token and token-to-text conversion.
+// Zero value is not usable; construct via Load.
+type Tokenizer struct {
+	vocab    map[string]int32 // token string → ID (includes added tokens)
+	invVocab map[int32]string // ID → token string
+	merges   []mergePair      // NOTE(review): parsed but not consulted by the visible encode paths — confirm full BPE merging is applied elsewhere
+	special  map[string]int32 // special (control) tokens only
+
+	bosToken int32
+	eosToken int32
+
+	// GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.)
+	isGPT2BPE   bool
+	gpt2Decoder map[rune]byte // Unicode char → original byte
+	gpt2Encoder map[byte]rune // original byte → Unicode char
+}
+
+// mergePair is a single BPE merge rule; lower rank means higher priority.
+type mergePair struct {
+	a, b string
+	rank int
+}
+
+// tokenizerJSON is the HuggingFace tokenizer.json format.
+// Vocab and Merges stay as json.RawMessage because their shapes vary by
+// model family and are decoded in a second pass.
+type tokenizerJSON struct {
+	Model struct {
+		Type         string          `json:"type"`
+		Vocab        json.RawMessage `json:"vocab"`
+		Merges       json.RawMessage `json:"merges"`
+		ByteFallback bool            `json:"byte_fallback"`
+	} `json:"model"`
+	AddedTokens []struct {
+		ID      int32  `json:"id"`
+		Content string `json:"content"`
+		Special bool   `json:"special"`
+	} `json:"added_tokens"`
+}
+
+// Load reads a tokenizer.json file and creates a Tokenizer.
+// Load reads the HuggingFace tokenizer.json at path, builds the vocab,
+// inverse vocab, merge table and special-token map, detects GPT-2
+// byte-level BPE, and resolves BOS/EOS token IDs per model family.
+// Returns an error if the file cannot be read or parsed.
+func Load(path string) (*Tokenizer, error) {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
+	}
+
+	var tj tokenizerJSON
+	if err := json.Unmarshal(data, &tj); err != nil {
+		return nil, fmt.Errorf("tokenizer: parse: %w", err)
+	}
+
+	t := &Tokenizer{
+		vocab:    make(map[string]int32),
+		invVocab: make(map[int32]string),
+		special:  make(map[string]int32),
+	}
+
+	// Parse vocab
+	var vocab map[string]int32
+	if err := json.Unmarshal(tj.Model.Vocab, &vocab); err != nil {
+		return nil, fmt.Errorf("tokenizer: parse vocab: %w", err)
+	}
+	t.vocab = vocab
+	for k, v := range vocab {
+		t.invVocab[v] = k
+	}
+
+	// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
+	if len(tj.Model.Merges) > 0 {
+		var stringMerges []string
+		if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
+			for rank, merge := range stringMerges {
+				parts := strings.SplitN(merge, " ", 2)
+				if len(parts) == 2 {
+					t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
+				}
+			}
+		} else {
+			var arrayMerges [][]string
+			if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
+				for rank, pair := range arrayMerges {
+					if len(pair) == 2 {
+						t.merges = append(t.merges, mergePair{a: pair[0], b: pair[1], rank: rank})
+					}
+				}
+			}
+		}
+	}
+
+	// Parse special tokens
+	for _, tok := range tj.AddedTokens {
+		if tok.Special {
+			t.special[tok.Content] = tok.ID
+		}
+		t.vocab[tok.Content] = tok.ID
+		t.invVocab[tok.ID] = tok.Content
+	}
+
+	// Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
+	if _, ok := t.vocab["Ġ"]; ok {
+		t.isGPT2BPE = true
+		t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
+	}
+
+	// Set BOS/EOS — detect model family from special tokens.
+	// FIX(review): these three lookups were t.special[""] — the angle-bracket
+	// token literals had been stripped, so they could never match and BOS/EOS
+	// stayed at their zero values. Restored the SentencePiece/Gemma names.
+	if id, ok := t.special["<bos>"]; ok {
+		t.bosToken = id
+	}
+	if id, ok := t.special["<eos>"]; ok {
+		t.eosToken = id
+	}
+	// Gemma: <end_of_turn> is the generation stop token
+	if id, ok := t.special["<end_of_turn>"]; ok {
+		t.eosToken = id
+	}
+
+	// Qwen3: <|im_end|> is the generation stop token
+	if id, ok := t.special["<|im_end|>"]; ok {
+		t.eosToken = id
+	}
+	// Qwen3 BOS: <|im_start|>
+	if id, ok := t.special["<|im_start|>"]; ok {
+		t.bosToken = id
+	}
+
+	return t, nil
+}
+
+// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps.
+// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control chars
+// in the vocabulary. Printable ASCII + Latin-1 Supplement map to themselves;
+// everything else (0-32, 127-160, 173) maps to U+0100 onwards.
+func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
+	encoder = make(map[byte]rune, 256)
+	decoder = make(map[rune]byte, 256)
+
+	// Self-mapping ranges: printable ASCII + Latin-1 Supplement
+	// Use int loop variable to avoid byte overflow at 255.
+	selfMap := func(lo, hi int) {
+		for b := lo; b <= hi; b++ {
+			encoder[byte(b)] = rune(b)
+			decoder[rune(b)] = byte(b)
+		}
+	}
+	selfMap(33, 126)  // ! through ~
+	selfMap(161, 172) // ¡ through ¬
+	selfMap(174, 255) // ® through ÿ
+
+	// Non-self-mapping: control chars, space, DEL, and gaps
+	n := 0
+	for b := 0; b < 256; b++ {
+		if _, ok := encoder[byte(b)]; !ok {
+			r := rune(256 + n)
+			encoder[byte(b)] = r
+			decoder[r] = byte(b)
+			n++
+		}
+	}
+	return
+}
+
+// Encode converts text to token IDs. Prepends BOS token.
+// Encode converts text to token IDs. Prepends BOS token.
+// Special tokens embedded in the text are matched greedily; other text is
+// looked up character-by-character (simplified — merges are not applied here).
+func (t *Tokenizer) Encode(text string) []int32 {
+	// FIX(review): the GPT-2 path prepends BOS itself inside encodeGPT2, so
+	// check it before allocating a token slice that would be discarded.
+	if t.isGPT2BPE {
+		return t.encodeGPT2(text)
+	}
+
+	tokens := []int32{t.bosToken}
+
+	// SentencePiece style encoding
+	remaining := text
+	for remaining != "" {
+		// Greedily match any special token at the current position.
+		found := false
+		for tok, id := range t.special {
+			if strings.HasPrefix(remaining, tok) {
+				tokens = append(tokens, id)
+				remaining = remaining[len(tok):]
+				found = true
+				break
+			}
+		}
+		if !found {
+			// Single-rune lookup: try the ▁-prefixed (word-boundary) form
+			// first, then the bare rune. Unknown runes are silently dropped.
+			// NOTE(review): ▁ is tried for every rune, not only word starts —
+			// a known simplification of SentencePiece; confirm acceptable.
+			r := []rune(remaining)
+			ch := "▁" + string(r[0])
+			if id, ok := t.vocab[ch]; ok {
+				tokens = append(tokens, id)
+			} else if id, ok := t.vocab[string(r[0])]; ok {
+				tokens = append(tokens, id)
+			}
+			remaining = string(r[1:])
+		}
+	}
+
+	return tokens
+}
+
+// encodeGPT2 encodes text using GPT-2 byte-level BPE. Prepends BOS.
+func (t *Tokenizer) encodeGPT2(text string) []int32 {
+	tokens := []int32{t.bosToken}
+
+	// Convert text bytes to GPT-2 Unicode representation
+	var encoded strings.Builder
+	for _, b := range []byte(text) {
+		if r, ok := t.gpt2Encoder[b]; ok {
+			encoded.WriteRune(r)
+		}
+	}
+	gpt2Text := encoded.String()
+
+	// Scan for special tokens and regular text
+	remaining := gpt2Text
+	for remaining != "" {
+		// Check special tokens (these are stored as-is, not byte-encoded)
+		found := false
+		for tok, id := range t.special {
+			// Special tokens in GPT-2 tokenizers are stored in their original form
+			// Convert the special token to GPT-2 encoding for matching
+			var encTok strings.Builder
+			for _, b := range []byte(tok) {
+				if r, ok := t.gpt2Encoder[b]; ok {
+					encTok.WriteRune(r)
+				}
+			}
+			encStr := encTok.String()
+			if strings.HasPrefix(remaining, encStr) {
+				tokens = append(tokens, id)
+				remaining = remaining[len(encStr):]
+				found = true
+				break
+			}
+		}
+		if !found {
+			// Character-by-character lookup (simplified BPE); unknown
+			// characters are silently dropped.
+			r := []rune(remaining)
+			ch := string(r[0])
+			if id, ok := t.vocab[ch]; ok {
+				tokens = append(tokens, id)
+			}
+			remaining = string(r[1:])
+		}
+	}
+
+	return tokens
+}
+
+// Decode converts token IDs back to text.
+// For full-sequence decoding, the SentencePiece leading space is stripped.
+// Special tokens are omitted from the output.
+func (t *Tokenizer) Decode(tokens []int32) string {
+	var sb strings.Builder
+	for _, id := range tokens {
+		if text, ok := t.invVocab[id]; ok {
+			// Skip special tokens in decode output
+			if _, isSpecial := t.special[text]; isSpecial {
+				continue
+			}
+			sb.WriteString(text)
+		}
+	}
+	raw := sb.String()
+
+	if t.isGPT2BPE {
+		return t.decodeGPT2Bytes(raw)
+	}
+
+	// SentencePiece style: ▁ marks word boundaries; drop the leading one.
+	result := strings.ReplaceAll(raw, "▁", " ")
+	if strings.HasPrefix(result, " ") {
+		result = result[1:]
+	}
+	return result
+}
+
+// DecodeToken converts a single token ID to text for streaming.
+// Unlike Decode, it preserves the leading space (word boundary) so that
+// token-by-token output maintains correct spacing between words.
+// Unknown IDs and special tokens yield "".
+func (t *Tokenizer) DecodeToken(id int32) string {
+	text, ok := t.invVocab[id]
+	if !ok {
+		return ""
+	}
+	if _, isSpecial := t.special[text]; isSpecial {
+		return ""
+	}
+
+	if t.isGPT2BPE {
+		return t.decodeGPT2Bytes(text)
+	}
+
+	// SentencePiece: replace ▁ with space but keep it (it's the word boundary)
+	return strings.ReplaceAll(text, "▁", " ")
+}
+
+// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
+func (t *Tokenizer) decodeGPT2Bytes(s string) string {
+	var buf []byte
+	for _, r := range s {
+		if b, ok := t.gpt2Decoder[r]; ok {
+			buf = append(buf, b)
+		} else {
+			// Non-mapped runes pass through as UTF-8
+			buf = append(buf, []byte(string(r))...)
+		}
+	}
+	return string(buf)
+}
+
+// BOSToken returns the beginning-of-sequence token ID.
+func (t *Tokenizer) BOSToken() int32 { return t.bosToken }
+
+// EOSToken returns the end-of-sequence token ID.
+func (t *Tokenizer) EOSToken() int32 { return t.eosToken }
+
+// FormatGemmaPrompt applies the Gemma 3 chat template.
+// FIX(review): the <start_of_turn>/<end_of_turn> markers had been stripped
+// from the format string ("user\n%s\nmodel\n"), which is exactly the Gemma
+// template with its tags deleted — restored so the model sees its template.
+func FormatGemmaPrompt(prompt string) string {
+	return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
+}