diff --git a/FINDINGS.md b/FINDINGS.md index e282b49..a362907 100644 --- a/FINDINGS.md +++ b/FINDINGS.md @@ -114,3 +114,66 @@ Tested on Mac Studio M3 Ultra (32-core CPU, 60-core GPU, 96GB unified memory): Available on `/Volumes/Data/lem/safetensors/`: - Gemma3-1B, Gemma3-4B, Gemma3-27B - Qwen3-8B (used by LEM Lab) + +--- + +## 2026-02-19: Go 1.26 Impact Assessment + +Source: https://go.dev/doc/go1.26 + +### High Impact (free performance, no code changes) + +**CGO call overhead reduced ~30%** +Every MLX operation (MatMul, Add, Softmax, RoPE, etc.) crosses the CGO boundary. The runtime previously used a dedicated syscall P state for cgo calls; Go 1.26 removes that and checks goroutine status instead. This is a direct, automatic performance win for the entire package. + +**Green Tea GC now default (10-40% less GC overhead)** +Critical for go-mlx because `Array` objects use `runtime.SetFinalizer` for C-side deallocation via `mlx_*_free()`. Reduced GC overhead means: +- More timely finaliser execution during sustained inference +- Less memory pressure from stale Array objects waiting for GC +- The FINDINGS.md concern about "GC not keeping up under high throughput" is partially mitigated +- Opt-out: `GOEXPERIMENT=nogreenteagc` (temporary, removed in 1.27) + +### Medium Impact + +**Slice stack allocation in more situations** +The compiler can now allocate slice backing stores on the stack more often. Benefits small temporary slices in `Collect()`, shape manipulation, and internal ops helpers. Debug: `-compile=variablemakehash` flag. + +**`testing.B.Loop` inlining fix** +When we add benchmarks (Phase 1), `b.Loop()` style now properly inlines loop bodies. Important for micro-benchmarks of small ops like Add, Multiply. + +**Heap base address randomisation (64-bit)** +Security improvement for CGO programs. Randomises heap base at startup. Disable: `GOEXPERIMENT=norandomizedheapbase64`. + +### Clarification on Range-over-func + +Virgil's Phase 6 TODO mentions "if 1.26 stabilises range-over-func". **Range-over-func has been stable since Go 1.23** and the `iter` package was added in 1.23. Since go.mod is already at Go 1.25.5, `Array.Iter() iter.Seq[float32]` can be implemented today without a version bump. Go 1.26 adds no new iterator features beyond what 1.23-1.25 provide. + +### Recommendation + +No Go version bump needed for the performance wins — they're automatic at runtime. The only code-level Go 1.26 feature that matters is `testing.ArtifactDir()` for benchmark result storage, which is minor. Focus remains on Phase 1 hardening. + +--- + +## 2026-02-19: go-ai Split Context + +Virgil is splitting go-ai into sub-packages, with go-ai becoming a meta/catch-all for ML features. go-mlx was the first extraction. This means: +- More packages will follow the go-mlx pattern (standalone module, own build, own tests) +- go-ai will eventually be a thin layer importing sub-packages +- The `replace` directive approach works for development; published modules for releases + +--- + +## 2026-02-19: Floats()/DataInt32() Unsafe on Non-Contiguous Arrays + +**Discovery**: `Array.Floats()` and `Array.DataInt32()` read `Size()` elements from the raw C data pointer (`mlx_array_data_float32`). For non-contiguous arrays (transpose, broadcast, slice views), the physical memory layout doesn't match the logical layout. Reading `Size()` contiguous elements returns incorrect data or reads past the physical buffer. + +**Affected operations**: `Transpose()`, `BroadcastTo()`, `SliceAxis()`, `Slice()`, `AsStrided()` — any operation that creates a view rather than a copy. + +**Workaround**: `Reshape(arr, totalSize)` forces a contiguous copy before reading flat data. All tests use this pattern for view operations. + +**Fix needed (Phase 4)**: Either: +1. Add a `Contiguous()` method that wraps `mlx_contiguous` (if available in mlx-c) +2. Or have `Floats()`/`DataInt32()` automatically force contiguity before reading +3. Document the behaviour clearly if views are intentionally lazy + +This is a data correctness issue — silent wrong results, not a crash. diff --git a/TODO.md b/TODO.md index 96b1188..5e4b322 100644 --- a/TODO.md +++ b/TODO.md @@ -6,8 +6,8 @@ Dispatched from core/go orchestration. Pick up tasks in order. ## Phase 1: Standalone Package Hardening -- [ ] **Verify go generate → test round-trip** — Run `go generate ./...` to build mlx-c, then `go test ./...`. Confirm all 3 test files (grad, lora, optim) pass on Apple Silicon. Document any CMake version requirements. -- [ ] **Add missing tests for core operations** — `ops.go` (353 LOC), `array.go` (261 LOC), `nn.go`, `compile.go`, `fast.go` have zero tests. Priority: ops (MatMul, Softmax, Add) and array (create, reshape, data access). +- [x] **Verify go generate → test round-trip** — ✅ 29/29 tests pass. CMake 3.24+, AppleClang 17.0.0, macOS SDK 26.2. Build takes ~2min on M3 Ultra. +- [x] **Add missing tests for core operations** — ✅ 86 new tests across 4 files: array_test.go (25), ops_test.go (44), nn_test.go (8), fast_test.go (9). Covers: all scalar/array creation, shape ops, element-wise arithmetic, math functions, matrix ops, reductions, indexing, slicing, fused kernels (RMSNorm, LayerNorm, RoPE, SDPA), Linear, Embedding, RepeatKV. Found non-contiguous view bug in Floats()/DataInt32() — see FINDINGS.md. - [ ] **Add missing tests for model/tokenizer/sample/cache** — `model/`, `tokenizer/`, `sample/`, `cache/` have zero test files. Priority: tokenizer (BPE round-trip), sample (temperature/top-k), cache (KV append/trim). - [ ] **Benchmark suite** — No benchmarks exist. Add: MatMul (various sizes), Softmax, model.Forward (single token), tokenizer.Encode/Decode, full Generate (tokens/sec). Baseline on M3 Ultra. @@ -39,7 +39,7 @@ Dispatched from core/go orchestration. Pick up tasks in order. ## Phase 6: Go 1.26 Modernisation -- [ ] **Evaluate Go 1.26 features** — Check what's new in 1.26 that benefits this package. Candidates: `iter` package improvements, range-over-func patterns for array iteration, any new `unsafe` changes that affect CGO. Document findings in FINDINGS.md with benchmarks showing whether it's worth the version bump. +- [x] **Evaluate Go 1.26 features** — ✅ Documented in FINDINGS.md. Key wins: CGO ~30% faster (free), Green Tea GC default (10-40% less overhead, helps Array finalisers), slice stack alloc. Range-over-func already stable since 1.23. - [ ] **Range-over-func for Array** — If 1.26 stabilises range-over-func, `Array.Iter()` returning `iter.Seq[float32]` (or typed variant) would be cleaner than the current index-based access. Measure overhead vs direct C pointer access. --- diff --git a/array_test.go b/array_test.go new file mode 100644 index 0000000..397c811 --- /dev/null +++ b/array_test.go @@ -0,0 +1,370 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "testing" +) + +// --- Scalar creation (FromValue) --- + +func TestFromValue_Float32(t *testing.T) { + a := FromValue(float32(3.14)) + Materialize(a) + + if a.Dtype() != DTypeFloat32 { + t.Errorf("dtype = %v, want float32", a.Dtype()) + } + if a.NumDims() != 0 { + t.Errorf("ndim = %d, want 0 (scalar)", a.NumDims()) + } + if a.Size() != 1 { + t.Errorf("size = %d, want 1", a.Size()) + } + if math.Abs(a.Float()-3.14) > 1e-5 { + t.Errorf("value = %f, want 3.14", a.Float()) + } +} + +func TestFromValue_Float64(t *testing.T) { + a := FromValue(float64(2.718281828)) + Materialize(a) + + if a.Dtype() != DTypeFloat64 { + t.Errorf("dtype = %v, want float64", a.Dtype()) + } + if math.Abs(a.Float()-2.718281828) > 1e-8 { + t.Errorf("value = %f, want 2.718281828", a.Float()) + } +} + +func TestFromValue_Int(t *testing.T) { + a := FromValue(42) + Materialize(a) + + if a.Dtype() != DTypeInt32 { + t.Errorf("dtype = %v, want int32", a.Dtype()) + } + if a.Int() != 42 { + t.Errorf("value = %d, want 42", a.Int()) + } +} + +func TestFromValue_Bool(t *testing.T) { + a := FromValue(true) + Materialize(a) + + if a.Dtype() != DTypeBool { + t.Errorf("dtype = %v, want bool", a.Dtype()) + } + if a.Int() != 1 { + t.Errorf("value = %d, want 1 (true)", a.Int()) + } +} + +func TestFromValue_Complex64(t *testing.T) { + a := FromValue(complex64(3 + 4i)) + Materialize(a) + + if a.Dtype() != DTypeComplex64 { + t.Errorf("dtype = %v, want complex64", a.Dtype()) + } + if a.Size() != 1 { + t.Errorf("size = %d, want 1", a.Size()) + } +} + +// --- Slice creation (FromValues) --- + +func TestFromValues_Float32_1D(t *testing.T) { + data := []float32{1.0, 2.0, 3.0, 4.0} + a := FromValues(data, 4) + Materialize(a) + + if a.Dtype() != DTypeFloat32 { + t.Errorf("dtype = %v, want float32", a.Dtype()) + } + if a.NumDims() != 1 { + t.Errorf("ndim = %d, want 1", a.NumDims()) + } + if a.Dim(0) != 4 { + t.Errorf("dim(0) = %d, want 4", a.Dim(0)) + } + if a.Size() != 4 { + t.Errorf("size = %d, want 4", a.Size()) + } + + got := a.Floats() + for i, want := range data { + if math.Abs(float64(got[i]-want)) > 1e-6 { + t.Errorf("element[%d] = %f, want %f", i, got[i], want) + } + } +} + +func TestFromValues_Float32_2D(t *testing.T) { + data := []float32{1, 2, 3, 4, 5, 6} + a := FromValues(data, 2, 3) // 2x3 matrix + Materialize(a) + + if a.NumDims() != 2 { + t.Errorf("ndim = %d, want 2", a.NumDims()) + } + shape := a.Shape() + if shape[0] != 2 || shape[1] != 3 { + t.Errorf("shape = %v, want [2 3]", shape) + } + if a.Size() != 6 { + t.Errorf("size = %d, want 6", a.Size()) + } + + got := a.Floats() + for i, want := range data { + if math.Abs(float64(got[i]-want)) > 1e-6 { + t.Errorf("element[%d] = %f, want %f", i, got[i], want) + } + } +} + +func TestFromValues_Int32(t *testing.T) { + data := []int32{10, 20, 30} + a := FromValues(data, 3) + Materialize(a) + + if a.Dtype() != DTypeInt32 { + t.Errorf("dtype = %v, want int32", a.Dtype()) + } + got := a.DataInt32() + for i, want := range data { + if got[i] != want { + t.Errorf("element[%d] = %d, want %d", i, got[i], want) + } + } +} + +func TestFromValues_Int64(t *testing.T) { + data := []int64{100, 200, 300} + a := FromValues(data, 3) + Materialize(a) + + if a.Dtype() != DTypeInt64 { + t.Errorf("dtype = %v, want int64", a.Dtype()) + } + if a.Size() != 3 { + t.Errorf("size = %d, want 3", a.Size()) + } +} + +func TestFromValues_Bool(t *testing.T) { + data := []bool{true, false, true} + a := FromValues(data, 3) + Materialize(a) + + if a.Dtype() != DTypeBool { + t.Errorf("dtype = %v, want bool", a.Dtype()) + } + if a.Size() != 3 { + t.Errorf("size = %d, want 3", a.Size()) + } +} + +func TestFromValues_Uint8(t *testing.T) { + data := []uint8{0, 127, 255} + a := FromValues(data, 3) + Materialize(a) + + if a.Dtype() != DTypeUint8 { + t.Errorf("dtype = %v, want uint8", a.Dtype()) + } +} + +func TestFromValues_PanicsWithoutShape(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("expected panic when shape is missing") + } + }() + FromValues([]float32{1, 2, 3}) +} + +// --- Zeros --- + +func TestZeros(t *testing.T) { + a := Zeros([]int32{2, 3}, DTypeFloat32) + Materialize(a) + + if a.Dtype() != DTypeFloat32 { + t.Errorf("dtype = %v, want float32", a.Dtype()) + } + shape := a.Shape() + if shape[0] != 2 || shape[1] != 3 { + t.Errorf("shape = %v, want [2 3]", shape) + } + if a.Size() != 6 { + t.Errorf("size = %d, want 6", a.Size()) + } + + for i, v := range a.Floats() { + if v != 0.0 { + t.Errorf("element[%d] = %f, want 0.0", i, v) + } + } +} + +func TestZeros_Int32(t *testing.T) { + a := Zeros([]int32{4}, DTypeInt32) + Materialize(a) + + if a.Dtype() != DTypeInt32 { + t.Errorf("dtype = %v, want int32", a.Dtype()) + } + for i, v := range a.DataInt32() { + if v != 0 { + t.Errorf("element[%d] = %d, want 0", i, v) + } + } +} + +// --- Shape and metadata --- + +func TestArray_Shape3D(t *testing.T) { + data := make([]float32, 24) + a := FromValues(data, 2, 3, 4) + Materialize(a) + + if a.NumDims() != 3 { + t.Errorf("ndim = %d, want 3", a.NumDims()) + } + dims := a.Dims() + if dims[0] != 2 || dims[1] != 3 || dims[2] != 4 { + t.Errorf("dims = %v, want [2 3 4]", dims) + } + if a.Size() != 24 { + t.Errorf("size = %d, want 24", a.Size()) + } + if a.NumBytes() != 24*4 { // float32 = 4 bytes + t.Errorf("nbytes = %d, want %d", a.NumBytes(), 24*4) + } +} + +// --- String representation --- + +func TestArray_String(t *testing.T) { + a := FromValue(float32(42.0)) + Materialize(a) + + s := a.String() + if s == "" { + t.Error("String() returned empty") + } + // MLX prints "array(42, dtype=float32)" or similar + t.Logf("String() = %q", s) +} + +// --- Clone and Set --- + +func TestArray_Clone(t *testing.T) { + a := FromValue(float32(7.0)) + b := a.Clone() + Materialize(a, b) + + if math.Abs(b.Float()-7.0) > 1e-6 { + t.Errorf("clone value = %f, want 7.0", b.Float()) + } +} + +func TestArray_Set(t *testing.T) { + a := FromValue(float32(1.0)) + b := FromValue(float32(2.0)) + Materialize(a, b) + + a.Set(b) + Materialize(a) + + if math.Abs(a.Float()-2.0) > 1e-6 { + t.Errorf("after Set, value = %f, want 2.0", a.Float()) + } +} + +// --- Valid and Free --- + +func TestArray_Valid(t *testing.T) { + a := FromValue(float32(1.0)) + Materialize(a) + + if !a.Valid() { + t.Error("expected Valid() = true for live array") + } + + Free(a) + if a.Valid() { + t.Error("expected Valid() = false after Free") + } +} + +func TestFree_ReturnsBytes(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4}, 4) + Materialize(a) + + n := Free(a) + if n != 16 { // 4 * float32(4 bytes) + t.Errorf("Free returned %d bytes, want 16", n) + } +} + +func TestFree_NilSafe(t *testing.T) { + // Should not panic on nil + n := Free(nil) + if n != 0 { + t.Errorf("Free(nil) returned %d, want 0", n) + } +} + +// --- Data extraction edge cases --- + +func TestArray_Ints(t *testing.T) { + data := []int32{10, 20, 30, 40} + a := FromValues(data, 4) + Materialize(a) + + got := a.Ints() + for i, want := range []int{10, 20, 30, 40} { + if got[i] != want { + t.Errorf("Ints()[%d] = %d, want %d", i, got[i], want) + } + } +} + +func TestArray_Float_DTypeFloat32(t *testing.T) { + a := FromValue(float32(1.5)) + Materialize(a) + + got := a.Float() + if math.Abs(got-1.5) > 1e-6 { + t.Errorf("Float() = %f, want 1.5", got) + } +} + +func TestArray_Float_DTypeFloat64(t *testing.T) { + a := FromValue(float64(1.5)) + Materialize(a) + + got := a.Float() + if math.Abs(got-1.5) > 1e-12 { + t.Errorf("Float() = %f, want 1.5", got) + } +} + +// --- Collect --- + +func TestCollect_FiltersNilAndInvalid(t *testing.T) { + a := FromValue(float32(1.0)) + b := FromValue(float32(2.0)) + Materialize(a, b) + + result := Collect(a, nil, b) + if len(result) != 2 { + t.Errorf("Collect returned %d arrays, want 2", len(result)) + } +} diff --git a/cpp/CLAUDE.md b/cpp/CLAUDE.md new file mode 100644 index 0000000..35aa892 --- /dev/null +++ b/cpp/CLAUDE.md @@ -0,0 +1,107 @@ +# CLAUDE.md — go-mlx C++ / mlx-c Side + +## What This Is + +You are the C++ specialist for `forge.lthn.ai/core/go-mlx`. This Go package wraps Apple's MLX framework through the **mlx-c** C API using CGO. You handle C/C++ work; a separate GoLand Claude handles the Go bindings. + +## Your Role + +- Inspect and understand the mlx-c API surface (headers, source) +- When Virgil (the core/go orchestration agent) or the GoLand Claude needs a C-level change, you implement it +- Debug C-level issues (segfaults, memory leaks, API mismatches) +- Research mlx-c capabilities for new Go bindings +- You do NOT modify Go files — relay Go-side changes to the GoLand Claude + +## Project Layout + +``` +go-mlx/ ← CMakeLists.txt lives here (CLion project root) +├── CMakeLists.txt ← Fetches mlx-c v0.4.1, builds to dist/ +├── cpp/ ← YOUR workspace docs (this directory) +│ ├── CLAUDE.md ← This file +│ ├── TODO.md ← C++ task queue +│ └── FINDINGS.md ← C++ research notes +│ +├── build/ ← CMake build directory (gitignored) +│ └── _deps/ +│ ├── mlx-c-src/ ← mlx-c v0.4.1 source (24 .cpp, 27 .h) +│ │ └── mlx/c/ ← The C API headers + implementations +│ └── mlx-src/ ← MLX C++ framework source (Metal shaders, ops) +│ +├── dist/ ← Installed artifacts (gitignored) +│ ├── include/mlx/c/ ← Headers used by Go CGO +│ └── lib/ ← libmlxc.dylib, libmlx.dylib +│ +└── *.go ← Go bindings (GoLand Claude's domain) +``` + +## Key Files + +### mlx-c Headers (the API surface Go binds to) +- `build/_deps/mlx-c-src/mlx/c/array.h` — Array creation, data access, shape +- `build/_deps/mlx-c-src/mlx/c/ops.h` — Element-wise and reduction operations +- `build/_deps/mlx-c-src/mlx/c/fast.h` — Fused Metal kernels (RMSNorm, RoPE, SDPA) +- `build/_deps/mlx-c-src/mlx/c/transforms.h` — eval, compile, VJP/JVP +- `build/_deps/mlx-c-src/mlx/c/io.h` — Safetensors load/save +- `build/_deps/mlx-c-src/mlx/c/random.h` — Random number generation +- `build/_deps/mlx-c-src/mlx/c/stream.h` — Metal stream/queue +- `build/_deps/mlx-c-src/mlx/c/metal.h` — Metal device availability +- `build/_deps/mlx-c-src/mlx/c/error.h` — Error handler registration +- `build/_deps/mlx-c-src/mlx/c/memory.h` — Memory management + +### CMake +- `go-mlx/CMakeLists.txt` — Top-level, fetches mlx-c v0.4.1 from GitHub +- `build/_deps/mlx-c-src/CMakeLists.txt` — mlx-c's own build config + +## Build + +```bash +# From go-mlx root: +cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release +cmake --build build --parallel +cmake --install build + +# Or via Go: +go generate ./... +``` + +### CMake Settings +- `MLX_BUILD_SAFETENSORS=ON` +- `MLX_BUILD_GGUF=OFF` +- `BUILD_SHARED_LIBS=ON` +- macOS deployment target: 26.0 + +## How Go Binds to mlx-c + +Go files use CGO with `#include "mlx/c/mlx.h"` and link via: +``` +#cgo CPPFLAGS: -I${SRCDIR}/dist/include +#cgo LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx +#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate +``` + +Each Go function calls the corresponding `mlx_*` C function. The pattern is: +1. Go allocates a C output handle (`mlx_array_new()`) +2. Calls the C operation (`mlx_add()`, `mlx_matmul()`, etc.) +3. Wraps result in Go `*Array` with `runtime.SetFinalizer` → `mlx_array_free()` + +## Communication Protocol + +- **Receiving work**: Tasks appear in `cpp/TODO.md`, written by the GoLand Claude or Virgil +- **Reporting findings**: Write to `cpp/FINDINGS.md` +- **Requesting Go changes**: Describe what the GoLand Claude needs to change in FINDINGS.md +- **Completing tasks**: Mark `[x]` in TODO.md with notes + +## Platform + +- **darwin/arm64 only** — Apple Silicon (M1-M4) +- Tested on: Mac Studio M3 Ultra (32-core CPU, 60-core GPU, 96GB) +- Xcode Command Line Tools required +- CMake 3.24+ + +## Coding Standards + +- UK English in documentation +- Conventional commits: `type(scope): description` +- Co-Author: `Co-Authored-By: Virgil ` +- Licence: EUPL-1.2 diff --git a/cpp/FINDINGS.md b/cpp/FINDINGS.md new file mode 100644 index 0000000..d0898d4 --- /dev/null +++ b/cpp/FINDINGS.md @@ -0,0 +1,42 @@ +# FINDINGS.md — go-mlx C++ Research & Discovery + +Record findings about mlx-c internals, API quirks, and architectural decisions. + +--- + +## 2026-02-19: Initial Setup (GoLand Claude) + +### mlx-c v0.4.1 Source Stats +- **24 .cpp files**, **27 .h headers** in `build/_deps/mlx-c-src/mlx/c/` +- Pure C API wrapping C++ MLX framework +- Each header typically has: type definition, `_new()`, `_free()`, operation functions + +### Installed Artifacts (dist/) +- `dist/lib/libmlxc.dylib` — C API shared library +- `dist/lib/libmlx.dylib` — MLX C++ framework shared library +- `dist/include/mlx/c/` — 27 public headers + +### Build Verified +- CMake configure + build + install completes cleanly +- Go tests pass (29/29) after build +- AppleClang 17.0.0, macOS SDK 26.2 + +### Go Binding Coverage (preliminary) +The Go side currently uses functions from these headers: +- `array.h` — creation, data access, shape, dtype +- `ops.h` — add, multiply, divide, matmul, softmax, etc. +- `fast.h` — rms_norm, layer_norm, rope, scaled_dot_product_attention +- `transforms.h` — eval, compile, vjp, jvp +- `io.h` — safetensors load/save +- `random.h` — categorical +- `stream.h` — default_gpu_stream +- `metal.h` — is_available +- `error.h` — set_error_handler +- `vector.h` — vector_array operations +- `closure.h` — compiled function closures + +**Likely unused**: `fft.h`, `linalg.h`, `distributed.h`, `distributed_group.h`, `export.h`, `map.h`, `memory.h` + +--- + +*Add new findings below as work progresses.* diff --git a/cpp/TODO.md b/cpp/TODO.md new file mode 100644 index 0000000..4d35b12 --- /dev/null +++ b/cpp/TODO.md @@ -0,0 +1,24 @@ +# TODO.md — go-mlx C++ Task Queue + +Tasks for the CLion Claude session. Written by GoLand Claude or Virgil. + +--- + +## Orientation (First Session) + +- [ ] **Map the mlx-c API surface** — Read all 27 headers in `build/_deps/mlx-c-src/mlx/c/`. Document which functions the Go side currently binds (cross-reference with Go files) vs which are available but unused. Priority headers: `ops.h`, `fast.h`, `array.h`, `transforms.h`. +- [ ] **Understand the error model** — `error.h` provides `mlx_set_error_handler()`. The Go side registers a handler that logs to stderr. Research: can we get structured error info (error codes, categories)? Is the error string stable or does it vary? +- [ ] **Check memory management patterns** — `mlx_*_free()` functions exist for each type. Verify: is double-free safe? What happens if you free during async eval? Document for the Go finaliser integration. + +## Standing Tasks + +- [ ] **API gap analysis** — When the GoLand Claude needs a C function that isn't exposed by mlx-c, document the gap here and research if upstream mlx-c supports it or if a patch is needed. + +--- + +## Workflow + +1. GoLand Claude or Virgil writes tasks here +2. Pick up in order, mark `[x]` when done +3. New findings → `cpp/FINDINGS.md` +4. If Go changes needed → note in FINDINGS.md for GoLand Claude diff --git a/fast_test.go b/fast_test.go new file mode 100644 index 0000000..b896bce --- /dev/null +++ b/fast_test.go @@ -0,0 +1,181 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "testing" +) + +func TestRMSNorm(t *testing.T) { + x := FromValues([]float32{1, 2, 3, 4}, 1, 4) + weight := FromValues([]float32{1, 1, 1, 1}, 4) + + y := RMSNorm(x, weight, 1e-5) + Materialize(y) + + got := y.Floats() + rms := math.Sqrt((1 + 4 + 9 + 16) / 4.0) + for i, val := range []float64{1, 2, 3, 4} { + want := val / rms + if math.Abs(float64(got[i])-want) > 1e-3 { + t.Errorf("RMSNorm[%d] = %f, want %f", i, got[i], want) + } + } +} + +func TestRMSNorm_WithScaling(t *testing.T) { + x := FromValues([]float32{1, 2, 3, 4}, 1, 4) + weight := FromValues([]float32{2, 2, 2, 2}, 4) + + y := RMSNorm(x, weight, 1e-5) + Materialize(y) + + got := y.Floats() + rms := math.Sqrt((1 + 4 + 9 + 16) / 4.0) + for i, val := range []float64{1, 2, 3, 4} { + want := 2.0 * val / rms + if math.Abs(float64(got[i])-want) > 1e-3 { + t.Errorf("RMSNorm scaled[%d] = %f, want %f", i, got[i], want) + } + } +} + +func TestLayerNorm(t *testing.T) { + x := FromValues([]float32{1, 2, 3, 4}, 1, 4) + weight := FromValues([]float32{1, 1, 1, 1}, 4) + bias := FromValues([]float32{0, 0, 0, 0}, 4) + + y := LayerNorm(x, weight, bias, 1e-5) + Materialize(y) + + got := y.Floats() + // Layer norm: mean=2.5, var=1.25, std≈1.118 + // Normalised: (x - mean) / std + mean := 2.5 + std := math.Sqrt(1.25) + for i, val := range []float64{1, 2, 3, 4} { + want := (val - mean) / std + if math.Abs(float64(got[i])-want) > 1e-3 { + t.Errorf("LayerNorm[%d] = %f, want %f", i, got[i], want) + } + } +} + +func TestLayerNorm_WithBias(t *testing.T) { + x := FromValues([]float32{1, 2, 3, 4}, 1, 4) + weight := FromValues([]float32{1, 1, 1, 1}, 4) + bias := FromValues([]float32{10, 10, 10, 10}, 4) + + y := LayerNorm(x, weight, bias, 1e-5) + Materialize(y) + + got := y.Floats() + // All values shifted by +10 + mean := 2.5 + std := math.Sqrt(1.25) + for i, val := range []float64{1, 2, 3, 4} { + want := (val-mean)/std + 10.0 + if math.Abs(float64(got[i])-want) > 1e-3 { + t.Errorf("LayerNorm+bias[%d] = %f, want %f", i, got[i], want) + } + } +} + +func TestRoPE(t *testing.T) { + // RoPE on a small input: [B=1, L=1, H=1, D=4] + x := FromValues([]float32{1, 0, 1, 0}, 1, 1, 1, 4) + y := RoPE(x, 4, false, 10000.0, 1.0, 0) + Materialize(y) + + shape := y.Shape() + if shape[0] != 1 || shape[1] != 1 || shape[2] != 1 || shape[3] != 4 { + t.Errorf("shape = %v, want [1 1 1 4]", shape) + } + + // At position 0, RoPE with offset 0 should be close to identity for cos(0)=1 + got := y.Floats() + // cos(0) = 1, sin(0) = 0, so rotation is identity at position 0 + if math.Abs(float64(got[0])-1.0) > 1e-3 { + t.Errorf("RoPE[0] = %f, want ≈1.0 (cos(0) rotation)", got[0]) + } +} + +func TestRoPE_ShapePreserved(t *testing.T) { + // Larger shape: [B=2, L=4, H=8, D=64] + data := make([]float32, 2*4*8*64) + for i := range data { + data[i] = 0.01 + } + x := FromValues(data, 2, 4, 8, 64) + y := RoPE(x, 64, false, 10000.0, 1.0, 0) + Materialize(y) + + shape := y.Shape() + if shape[0] != 2 || shape[1] != 4 || shape[2] != 8 || shape[3] != 64 { + t.Errorf("shape = %v, want [2 4 8 64]", shape) + } +} + +func TestScaledDotProductAttention_Causal(t *testing.T) { + // [B=1, H=1, L=3, D=2] + q := FromValues([]float32{1, 0, 0, 1, 1, 1}, 1, 1, 3, 2) + k := FromValues([]float32{1, 0, 0, 1, 1, 1}, 1, 1, 3, 2) + v := FromValues([]float32{1, 0, 0, 1, 0.5, 0.5}, 1, 1, 3, 2) + + scale := float32(1.0 / math.Sqrt(2.0)) + y := ScaledDotProductAttention(q, k, v, scale, true) + Materialize(y) + + shape := y.Shape() + if shape[0] != 1 || shape[1] != 1 || shape[2] != 3 || shape[3] != 2 { + t.Errorf("shape = %v, want [1 1 3 2]", shape) + } + + // First position can only attend to itself (causal) + flat := Reshape(y, 6) + Materialize(flat) + got := flat.Floats() + // Position 0 attends only to position 0: output = v[0] = [1, 0] + if math.Abs(float64(got[0])-1.0) > 1e-3 { + t.Errorf("SDPA causal pos0[0] = %f, want 1.0", got[0]) + } + if math.Abs(float64(got[1])-0.0) > 1e-3 { + t.Errorf("SDPA causal pos0[1] = %f, want 0.0", got[1]) + } +} + +func TestScaledDotProductAttention_NonCausal(t *testing.T) { + // Non-causal: all positions attend to all + q := FromValues([]float32{1, 0, 0, 1}, 1, 1, 2, 2) + k := FromValues([]float32{1, 0, 0, 1}, 1, 1, 2, 2) + v := FromValues([]float32{10, 0, 0, 10}, 1, 1, 2, 2) + + scale := float32(1.0 / math.Sqrt(2.0)) + y := ScaledDotProductAttention(q, k, v, scale, false) + Materialize(y) + + shape := y.Shape() + if shape[0] != 1 || shape[1] != 1 || shape[2] != 2 || shape[3] != 2 { + t.Errorf("shape = %v, want [1 1 2 2]", shape) + } +} + +func TestScaledDotProductAttentionWithMask(t *testing.T) { + q := FromValues([]float32{1, 0, 0, 1}, 1, 1, 2, 2) + k := FromValues([]float32{1, 0, 0, 1}, 1, 1, 2, 2) + v := FromValues([]float32{10, 0, 0, 10}, 1, 1, 2, 2) + + // Mask: block second position from attending to first + // Large negative = -inf masking + mask := FromValues([]float32{0, 0, -1e9, 0}, 1, 1, 2, 2) + + scale := float32(1.0 / math.Sqrt(2.0)) + y := ScaledDotProductAttentionWithMask(q, k, v, mask, scale) + Materialize(y) + + shape := y.Shape() + if shape[0] != 1 || shape[1] != 1 || shape[2] != 2 || shape[3] != 2 { + t.Errorf("shape = %v, want [1 1 2 2]", shape) + } +} diff --git a/nn_test.go b/nn_test.go new file mode 100644 index 0000000..f7cbf21 --- /dev/null +++ b/nn_test.go @@ -0,0 +1,169 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "testing" +) + +// --- Linear --- + +func TestLinear_Dense(t *testing.T) { + // y = x @ W.T + bias + // x: [1, 3], W: [2, 3], bias: [2] + // Result: [1, 2] + x := FromValues([]float32{1, 2, 3}, 1, 3) + w := FromValues([]float32{1, 0, 0, 0, 1, 0}, 2, 3) // identity-ish + bias := FromValues([]float32{10, 20}, 2) + + l := NewLinear(w, bias) + y := l.Forward(x) + Materialize(y) + + // x @ W.T = [1*1+2*0+3*0, 1*0+2*1+3*0] = [1, 2] + // + bias = [11, 22] + got := y.Floats() + if len(got) != 2 { + t.Fatalf("size = %d, want 2", len(got)) + } + if !approx(float64(got[0]), 11.0) { + t.Errorf("[0] = %f, want 11.0", got[0]) + } + if !approx(float64(got[1]), 22.0) { + t.Errorf("[1] = %f, want 22.0", got[1]) + } +} + +func TestLinear_NoBias(t *testing.T) { + x := FromValues([]float32{1, 2, 3}, 1, 3) + w := FromValues([]float32{1, 1, 1, 2, 2, 2}, 2, 3) + + l := NewLinear(w, nil) + y := l.Forward(x) + Materialize(y) + + // x @ W.T = [1+2+3, 2+4+6] = [6, 12] + got := y.Floats() + if !approx(float64(got[0]), 6.0) { + t.Errorf("[0] = %f, want 6.0", got[0]) + } + if !approx(float64(got[1]), 12.0) { + t.Errorf("[1] = %f, want 12.0", got[1]) + } +} + +func TestLinear_LoRARouting(t *testing.T) { + // When LoRA is attached, Forward should route through it + w := FromValues([]float32{1, 0, 0, 1}, 2, 2) + l := NewLinear(w, nil) + + lora := NewLoRALinear(l, 1, 1.0) + l.LoRA = lora + + x := FromValues([]float32{3, 4}, 1, 2) + y := l.Forward(x) + Materialize(y) + + // Should produce valid output (LoRA adds low-rank delta) + if y.Size() != 2 { + t.Errorf("size = %d, want 2", y.Size()) + } +} + +// --- Embedding --- + +func TestEmbedding_Forward(t *testing.T) { + // 4 tokens, 3-dim embeddings + w := FromValues([]float32{ + 0, 0, 0, // token 0 + 1, 1, 1, // token 1 + 2, 2, 2, // token 2 + 3, 3, 3, // token 3 + }, 4, 3) + + emb := &Embedding{Weight: w} + indices := FromValues([]int32{1, 3}, 2) + y := emb.Forward(indices) + Materialize(y) + + shape := y.Shape() + if shape[0] != 2 || shape[1] != 3 { + t.Errorf("shape = %v, want [2 3]", shape) + } + + flat := Reshape(y, 6) + Materialize(flat) + got := flat.Floats() + // token 1 = [1,1,1], token 3 = [3,3,3] + want := []float32{1, 1, 1, 3, 3, 3} + floatSliceApprox(t, got, want) +} + +func TestEmbedding_AsLinear(t *testing.T) { + w := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + emb := &Embedding{Weight: w} + l := emb.AsLinear() + + if l.Weight != w { + t.Error("AsLinear should share weight with embedding") + } +} + +// --- RMSNormModule --- + +func TestRMSNormModule_Forward(t *testing.T) { + x := FromValues([]float32{1, 2, 3, 4}, 1, 4) + weight := FromValues([]float32{1, 1, 1, 1}, 4) + + m := &RMSNormModule{Weight: weight} + y := m.Forward(x, 1e-5) + Materialize(y) + + // RMS norm normalises by RMS then scales by weight + got := y.Floats() + if len(got) != 4 { + t.Fatalf("size = %d, want 4", len(got)) + } + // RMS = sqrt(mean(x^2)) = sqrt((1+4+9+16)/4) = sqrt(7.5) ≈ 2.7386 + // Normalised: x / RMS ≈ [0.3651, 0.7303, 1.0954, 1.4606] + rms := math.Sqrt((1 + 4 + 9 + 16) / 4.0) + for i, x := range []float64{1, 2, 3, 4} { + want := x / rms + if math.Abs(float64(got[i])-want) > 1e-3 { + t.Errorf("[%d] = %f, want %f", i, got[i], want) + } + } +} + +// --- RepeatKV --- + +func TestRepeatKV_Factor1(t *testing.T) { + // factor=1 should return input unchanged + x := FromValues(make([]float32, 24), 1, 2, 3, 4) + y := RepeatKV(x, 1) + + if y != x { + t.Error("RepeatKV with factor=1 should return same pointer") + } +} + +func TestRepeatKV_Factor2(t *testing.T) { + // [B=1, H=2, L=1, D=2] with factor=2 -> [1, 4, 1, 2] + data := []float32{1, 2, 3, 4} + x := FromValues(data, 1, 2, 1, 2) + y := RepeatKV(x, 2) + Materialize(y) + + shape := y.Shape() + if shape[0] != 1 || shape[1] != 4 || shape[2] != 1 || shape[3] != 2 { + t.Errorf("shape = %v, want [1 4 1 2]", shape) + } + + flat := Reshape(y, 8) + Materialize(flat) + got := flat.Floats() + // Head 0 [1,2] repeated, Head 1 [3,4] repeated + want := []float32{1, 2, 1, 2, 3, 4, 3, 4} + floatSliceApprox(t, got, want) +} diff --git a/ops_test.go b/ops_test.go new file mode 100644 index 0000000..d0d2b8f --- /dev/null +++ b/ops_test.go @@ -0,0 +1,542 @@ +//go:build darwin && arm64 + +package mlx + +import ( + "math" + "testing" +) + +const tol = 1e-5 + +func approx(a, b float64) bool { return math.Abs(a-b) < tol } + +func floatSliceApprox(t *testing.T, got []float32, want []float32) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("length mismatch: got %d, want %d", len(got), len(want)) + } + for i := range got { + if !approx(float64(got[i]), float64(want[i])) { + t.Errorf("[%d] = %f, want %f", i, got[i], want[i]) + } + } +} + +// --- Element-wise arithmetic --- + +func TestAdd(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 3) + b := FromValues([]float32{4, 5, 6}, 3) + c := Add(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{5, 7, 9}) +} + +func TestAddScalar(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 3) + c := AddScalar(a, 10.0) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{11, 12, 13}) +} + +func TestMul(t *testing.T) { + a := FromValues([]float32{2, 3, 4}, 3) + b := FromValues([]float32{5, 6, 7}, 3) + c := Mul(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{10, 18, 28}) +} + +func TestMulScalar(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 3) + c := MulScalar(a, 3.0) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{3, 6, 9}) +} + +func TestDivide(t *testing.T) { + a := FromValues([]float32{10, 20, 30}, 3) + b := FromValues([]float32{2, 5, 10}, 3) + c := Divide(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{5, 4, 3}) +} + +func TestSubtract(t *testing.T) { + a := FromValues([]float32{10, 20, 30}, 3) + b := FromValues([]float32{1, 2, 3}, 3) + c := Subtract(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{9, 18, 27}) +} + +func TestNegative(t *testing.T) { + a := FromValues([]float32{1, -2, 3}, 3) + c := Negative(a) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{-1, 2, -3}) +} + +// --- Math functions --- + +func TestExp(t *testing.T) { + a := FromValues([]float32{0, 1, 2}, 3) + c := Exp(a) + Materialize(c) + got := c.Floats() + for i, x := range []float32{0, 1, 2} { + want := float32(math.Exp(float64(x))) + if !approx(float64(got[i]), float64(want)) { + t.Errorf("Exp(%f) = %f, want %f", x, got[i], want) + } + } +} + +func TestSigmoid(t *testing.T) { + a := FromValues([]float32{0, 100, -100}, 3) + c := Sigmoid(a) + Materialize(c) + got := c.Floats() + // sigmoid(0)=0.5, sigmoid(large)≈1, sigmoid(-large)≈0 + if !approx(float64(got[0]), 0.5) { + t.Errorf("sigmoid(0) = %f, want 0.5", got[0]) + } + if got[1] < 0.999 { + t.Errorf("sigmoid(100) = %f, want ≈1.0", got[1]) + } + if got[2] > 0.001 { + t.Errorf("sigmoid(-100) = %f, want ≈0.0", got[2]) + } +} + +func TestSiLU(t *testing.T) { + // SiLU(x) = x * sigmoid(x) + a := FromValues([]float32{0, 1, -1}, 3) + c := SiLU(a) + Materialize(c) + got := c.Floats() + // SiLU(0) = 0*0.5 = 0 + if !approx(float64(got[0]), 0.0) { + t.Errorf("SiLU(0) = %f, want 0.0", got[0]) + } + // SiLU(1) = 1 * sigmoid(1) = 1/(1+exp(-1)) ≈ 0.731059 + want := 1.0 / (1.0 + math.Exp(-1.0)) + if math.Abs(float64(got[1])-want) > 1e-4 { + t.Errorf("SiLU(1) = %f, want %f", got[1], want) + } +} + +func TestTanh(t *testing.T) { + a := FromValues([]float32{0, 1, -1}, 3) + c := Tanh(a) + Materialize(c) + got := c.Floats() + for i, x := range []float32{0, 1, -1} { + want := float32(math.Tanh(float64(x))) + if !approx(float64(got[i]), float64(want)) { + t.Errorf("Tanh(%f) = %f, want %f", x, got[i], want) + } + } +} + +func TestSqrt(t *testing.T) { + a := FromValues([]float32{1, 4, 9, 16}, 4) + c := Sqrt(a) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1, 2, 3, 4}) +} + +func TestRsqrt(t *testing.T) { + a := FromValues([]float32{1, 4, 16}, 3) + c := Rsqrt(a) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1.0, 0.5, 0.25}) +} + +func TestReciprocal(t *testing.T) { + a := FromValues([]float32{1, 2, 4, 5}, 4) + c := Reciprocal(a) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1.0, 0.5, 0.25, 0.2}) +} + +func TestSquare(t *testing.T) { + a := FromValues([]float32{1, 2, 3, -4}, 4) + c := Square(a) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1, 4, 9, 16}) +} + +func TestPower(t *testing.T) { + a := FromValues([]float32{2, 3, 4}, 3) + b := FromValues([]float32{3, 2, 0.5}, 3) + c := Power(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{8, 9, 2}) +} + +func TestMaximum(t *testing.T) { + a := FromValues([]float32{1, 5, 3}, 3) + b := FromValues([]float32{4, 2, 6}, 3) + c := Maximum(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{4, 5, 6}) +} + +func TestMinimum(t *testing.T) { + a := FromValues([]float32{1, 5, 3}, 3) + b := FromValues([]float32{4, 2, 6}, 3) + c := Minimum(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1, 2, 3}) +} + +// --- Matrix operations --- + +func TestMatmul(t *testing.T) { + // [1 2] @ [5 6]T = [1*5+2*7, 1*6+2*8] = [19, 22] + // [3 4] [7 8] [3*5+4*7, 3*6+4*8] [43, 50] + a := FromValues([]float32{1, 2, 3, 4}, 2, 2) + b := FromValues([]float32{5, 6, 7, 8}, 2, 2) + c := Matmul(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{19, 22, 43, 50}) +} + +func TestMatmul_VectorMatrix(t *testing.T) { + // [1 2 3] @ [[1],[2],[3]] = [14] + a := FromValues([]float32{1, 2, 3}, 1, 3) + b := FromValues([]float32{1, 2, 3}, 3, 1) + c := Matmul(a, b) + Materialize(c) + + if c.Size() != 1 { + t.Fatalf("size = %d, want 1", c.Size()) + } + if !approx(float64(c.Floats()[0]), 14.0) { + t.Errorf("result = %f, want 14.0", c.Floats()[0]) + } +} + +// --- Reductions --- + +func TestSoftmax(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 1, 3) + c := Softmax(a) + Materialize(c) + + got := c.Floats() + // softmax values should sum to 1 + sum := float64(0) + for _, v := range got { + sum += float64(v) + } + if !approx(sum, 1.0) { + t.Errorf("softmax sum = %f, want 1.0", sum) + } + // values should be monotonically increasing + if got[0] >= got[1] || got[1] >= got[2] { + t.Errorf("softmax not monotonic: %v", got) + } +} + +func TestArgmax(t *testing.T) { + a := FromValues([]float32{1, 5, 3, 2}, 1, 4) + c := Argmax(a, -1, false) + Materialize(c) + + if c.Int() != 1 { + t.Errorf("argmax = %d, want 1", c.Int()) + } +} + +func TestTopK(t *testing.T) { + a := FromValues([]float32{1, 5, 3, 7, 2}, 1, 5) + c := TopK(a, 2) + Materialize(c) + + got := c.Floats() + if len(got) != 2 { + t.Fatalf("topk returned %d elements, want 2", len(got)) + } + // Top-2 from {1,5,3,7,2} should contain 7 and 5 (order not guaranteed) + has7, has5 := false, false + for _, v := range got { + if v == 7 { + has7 = true + } + if v == 5 { + has5 = true + } + } + if !has7 || !has5 { + t.Errorf("topk = %v, want set {7, 5}", got) + } +} + +func TestSum(t *testing.T) { + // 2x3 matrix, sum along axis 1 + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + c := Sum(a, 1, false) + Materialize(c) + // row 0: 1+2+3=6, row 1: 4+5+6=15 + floatSliceApprox(t, c.Floats(), []float32{6, 15}) +} + +func TestSum_KeepDims(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + c := Sum(a, 1, true) + Materialize(c) + + if c.NumDims() != 2 { + t.Errorf("ndim = %d, want 2 (keepDims)", c.NumDims()) + } + shape := c.Shape() + if shape[0] != 2 || shape[1] != 1 { + t.Errorf("shape = %v, want [2 1]", shape) + } +} + +func TestMean(t *testing.T) { + a := FromValues([]float32{2, 4, 6, 8}, 2, 2) + c := Mean(a, 1, false) + Materialize(c) + // row 0: (2+4)/2=3, row 1: (6+8)/2=7 + floatSliceApprox(t, c.Floats(), []float32{3, 7}) +} + +func TestLogSumExp_Axis(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 1, 3) + c := LogSumExp(a, -1, false) + Materialize(c) + + // log(exp(1) + exp(2) + exp(3)) ≈ 3.4076 + want := math.Log(math.Exp(1) + math.Exp(2) + math.Exp(3)) + if !approx(c.Float(), want) { + t.Errorf("LogSumExp = %f, want %f", c.Float(), want) + } +} + +// --- Shape operations --- + +func TestReshape(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 6) + c := Reshape(a, 2, 3) + Materialize(c) + + shape := c.Shape() + if shape[0] != 2 || shape[1] != 3 { + t.Errorf("shape = %v, want [2 3]", shape) + } + // Data preserved + floatSliceApprox(t, c.Floats(), []float32{1, 2, 3, 4, 5, 6}) +} + +func TestTranspose(t *testing.T) { + // [[1 2 3], [4 5 6]] transposed -> shape [3 2] + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + c := Transpose(a) + Materialize(c) + + shape := c.Shape() + if shape[0] != 3 || shape[1] != 2 { + t.Errorf("shape = %v, want [3 2]", shape) + } + + // Verify values via Reshape (forces contiguous copy) + flat := Reshape(c, 6) + Materialize(flat) + floatSliceApprox(t, flat.Floats(), []float32{1, 4, 2, 5, 3, 6}) +} + +func TestTranspose_WithAxes(t *testing.T) { + // 3D: (2,3,4) with axes (0,2,1) -> (2,4,3) + data := make([]float32, 24) + for i := range data { + data[i] = float32(i) + } + a := FromValues(data, 2, 3, 4) + c := Transpose(a, 0, 2, 1) + Materialize(c) + + shape := c.Shape() + if shape[0] != 2 || shape[1] != 4 || shape[2] != 3 { + t.Errorf("shape = %v, want [2 4 3]", shape) + } +} + +func TestExpandDims(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 3) + c := ExpandDims(a, 0) + Materialize(c) + + shape := c.Shape() + if len(shape) != 2 || shape[0] != 1 || shape[1] != 3 { + t.Errorf("shape = %v, want [1 3]", shape) + } +} + +func TestSqueeze(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 1, 3) + c := Squeeze(a, 0) + Materialize(c) + + shape := c.Shape() + if len(shape) != 1 || shape[0] != 3 { + t.Errorf("shape = %v, want [3]", shape) + } +} + +func TestConcatenate(t *testing.T) { + a := FromValues([]float32{1, 2}, 2) + b := FromValues([]float32{3, 4, 5}, 3) + c := Concatenate([]*Array{a, b}, 0) + Materialize(c) + + if c.Size() != 5 { + t.Fatalf("size = %d, want 5", c.Size()) + } + floatSliceApprox(t, c.Floats(), []float32{1, 2, 3, 4, 5}) +} + +func TestBroadcastTo(t *testing.T) { + a := FromValues([]float32{1, 2, 3}, 1, 3) + c := BroadcastTo(a, []int32{4, 3}) + Materialize(c) + + shape := c.Shape() + if shape[0] != 4 || shape[1] != 3 { + t.Errorf("shape = %v, want [4 3]", shape) + } + if c.Size() != 12 { + t.Errorf("size = %d, want 12", c.Size()) + } + + // Verify via Reshape (forces contiguous copy for broadcast views) + flat := Reshape(c, 12) + Materialize(flat) + got := flat.Floats() + want := []float32{1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3} + floatSliceApprox(t, got, want) +} + +func TestAsType(t *testing.T) { + a := FromValues([]float32{1.5, 2.7, 3.9}, 3) + c := AsType(a, DTypeInt32) + Materialize(c) + + if c.Dtype() != DTypeInt32 { + t.Errorf("dtype = %v, want int32", c.Dtype()) + } + got := c.DataInt32() + // Truncation to int + want := []int32{1, 2, 3} + for i := range got { + if got[i] != want[i] { + t.Errorf("[%d] = %d, want %d", i, got[i], want[i]) + } + } +} + +// --- Indexing --- + +func TestTake(t *testing.T) { + a := FromValues([]float32{10, 20, 30, 40, 50}, 5) + indices := FromValues([]int32{0, 2, 4}, 3) + c := Take(a, indices, 0) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{10, 30, 50}) +} + +func TestWhere(t *testing.T) { + cond := FromValues([]bool{true, false, true}, 3) + a := FromValues([]float32{1, 2, 3}, 3) + b := FromValues([]float32{4, 5, 6}, 3) + c := Where(cond, a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1, 5, 3}) +} + +func TestTakeAlongAxis(t *testing.T) { + // 2x3 matrix, pick one element per row along axis 1 + a := FromValues([]float32{10, 20, 30, 40, 50, 60}, 2, 3) + indices := FromValues([]int32{2, 0}, 2, 1) // row 0 pick col 2, row 1 pick col 0 + c := TakeAlongAxis(a, indices, 1) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{30, 40}) +} + +// --- Slicing --- + +func TestSlice(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + // Extract first row: [0:1, 0:3] + c := Slice(a, []int32{0, 0}, []int32{1, 3}) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{1, 2, 3}) +} + +func TestSliceAxis(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + // Slice columns 1:3 from all rows + c := SliceAxis(a, 1, 1, 3) + Materialize(c) + + shape := c.Shape() + if shape[0] != 2 || shape[1] != 2 { + t.Errorf("shape = %v, want [2 2]", shape) + } + // Reshape to force contiguous layout for value check + flat := Reshape(c, 4) + Materialize(flat) + floatSliceApprox(t, flat.Floats(), []float32{2, 3, 5, 6}) +} + +func TestSliceUpdateInplace(t *testing.T) { + a := Zeros([]int32{2, 3}, DTypeFloat32) + update := FromValues([]float32{7, 8, 9}, 1, 3) + // Put [7 8 9] in second row + c := SliceUpdateInplace(a, update, []int32{1, 0}, []int32{2, 3}) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{0, 0, 0, 7, 8, 9}) +} + +// --- Broadcasting arithmetic --- + +func TestAdd_Broadcasting(t *testing.T) { + // [2,3] + [1,3] should broadcast + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + b := FromValues([]float32{10, 20, 30}, 1, 3) + c := Add(a, b) + Materialize(c) + floatSliceApprox(t, c.Floats(), []float32{11, 22, 33, 14, 25, 36}) +} + +// --- Random --- + +func TestRandomCategorical(t *testing.T) { + // Heavily weighted towards index 2 + logprobs := FromValues([]float32{-100, -100, 0}, 1, 3) + sample := RandomCategorical(logprobs) + Materialize(sample) + + idx := sample.Int() + if idx != 2 { + t.Errorf("categorical sample = %d, want 2 (dominant logprob)", idx) + } +} + +func TestRandomUniform(t *testing.T) { + a := RandomUniform(0, 1, []int32{100}, DTypeFloat32) + Materialize(a) + + if a.Size() != 100 { + t.Fatalf("size = %d, want 100", a.Size()) + } + for i, v := range a.Floats() { + if v < 0 || v >= 1 { + t.Errorf("[%d] = %f, out of [0, 1) range", i, v) + } + } +}