go-mlx/array_test.go
Snider 37abc496ba test(core): add 86 tests for ops, array, nn, fast kernels
Phase 1 hardening: cover all previously-untested core operations.

- array_test.go (25): scalar/array creation, shape, clone, free, data access
- ops_test.go (44): arithmetic, math, matmul, reductions, shape ops, indexing, slicing, random
- nn_test.go (8): Linear (dense/bias/LoRA), Embedding, RMSNormModule, RepeatKV
- fast_test.go (9): RMSNorm, LayerNorm, RoPE, ScaledDotProductAttention

Found: Floats()/DataInt32() return wrong data on non-contiguous arrays
(transpose, broadcast, slice views). Documented in FINDINGS.md.

Also: cpp/ workspace docs for CLion Claude session, Go 1.26 impact
assessment, verified go generate → test round-trip (29→115 tests).

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 18:37:30 +00:00

370 lines
7.4 KiB
Go

//go:build darwin && arm64
package mlx
import (
"math"
"testing"
)
// --- Scalar creation (FromValue) ---

// TestFromValue_Float32 builds a scalar from a float32 and checks the dtype,
// rank, element count and stored value reported by the resulting array.
func TestFromValue_Float32(t *testing.T) {
	scalar := FromValue(float32(3.14))
	Materialize(scalar)

	if dt := scalar.Dtype(); dt != DTypeFloat32 {
		t.Errorf("dtype = %v, want float32", dt)
	}
	if rank := scalar.NumDims(); rank != 0 {
		t.Errorf("ndim = %d, want 0 (scalar)", rank)
	}
	if count := scalar.Size(); count != 1 {
		t.Errorf("size = %d, want 1", count)
	}
	// float32 cannot represent 3.14 exactly, hence the loose tolerance.
	if v := scalar.Float(); math.Abs(v-3.14) > 1e-5 {
		t.Errorf("value = %f, want 3.14", v)
	}
}
// TestFromValue_Float64 builds a scalar from a float64 and checks that the
// dtype is float64 and that the value reads back to within 1e-8.
func TestFromValue_Float64(t *testing.T) {
	scalar := FromValue(float64(2.718281828))
	Materialize(scalar)

	if dt := scalar.Dtype(); dt != DTypeFloat64 {
		t.Errorf("dtype = %v, want float64", dt)
	}
	if v := scalar.Float(); math.Abs(v-2.718281828) > 1e-8 {
		t.Errorf("value = %f, want 2.718281828", v)
	}
}
// TestFromValue_Int builds a scalar from an untyped integer constant and
// checks that it is stored with dtype int32 and reads back as 42.
func TestFromValue_Int(t *testing.T) {
	scalar := FromValue(42)
	Materialize(scalar)

	if dt := scalar.Dtype(); dt != DTypeInt32 {
		t.Errorf("dtype = %v, want int32", dt)
	}
	if v := scalar.Int(); v != 42 {
		t.Errorf("value = %d, want 42", v)
	}
}
// TestFromValue_Bool builds a scalar from true and checks that the dtype is
// bool and that Int() reports it as 1.
func TestFromValue_Bool(t *testing.T) {
	flag := FromValue(true)
	Materialize(flag)

	if dt := flag.Dtype(); dt != DTypeBool {
		t.Errorf("dtype = %v, want bool", dt)
	}
	if v := flag.Int(); v != 1 {
		t.Errorf("value = %d, want 1 (true)", v)
	}
}
// TestFromValue_Complex64 builds a scalar from a complex64 and checks only
// the dtype and element count; the stored value itself is not read back.
func TestFromValue_Complex64(t *testing.T) {
	z := FromValue(complex64(3 + 4i))
	Materialize(z)

	if dt := z.Dtype(); dt != DTypeComplex64 {
		t.Errorf("dtype = %v, want complex64", dt)
	}
	if count := z.Size(); count != 1 {
		t.Errorf("size = %d, want 1", count)
	}
}
// --- Slice creation (FromValues) ---

// TestFromValues_Float32_1D builds a 1-D float32 array from a slice and
// checks dtype, rank, dimension, size and the round-tripped element values.
func TestFromValues_Float32_1D(t *testing.T) {
	data := []float32{1.0, 2.0, 3.0, 4.0}
	a := FromValues(data, 4)
	Materialize(a)

	if a.Dtype() != DTypeFloat32 {
		t.Errorf("dtype = %v, want float32", a.Dtype())
	}
	if a.NumDims() != 1 {
		t.Errorf("ndim = %d, want 1", a.NumDims())
	}
	if a.Dim(0) != 4 {
		t.Errorf("dim(0) = %d, want 4", a.Dim(0))
	}
	if a.Size() != 4 {
		t.Errorf("size = %d, want 4", a.Size())
	}

	got := a.Floats()
	// Stop here on a length mismatch: indexing got below would otherwise
	// panic and hide the real failure behind an index-out-of-range trace.
	if len(got) != len(data) {
		t.Fatalf("len(Floats()) = %d, want %d", len(got), len(data))
	}
	for i, want := range data {
		if math.Abs(float64(got[i]-want)) > 1e-6 {
			t.Errorf("element[%d] = %f, want %f", i, got[i], want)
		}
	}
}
// TestFromValues_Float32_2D builds a 2x3 float32 matrix from a flat slice and
// checks rank, shape, size and that Floats() returns the elements in the
// same order they were supplied.
func TestFromValues_Float32_2D(t *testing.T) {
	data := []float32{1, 2, 3, 4, 5, 6}
	a := FromValues(data, 2, 3) // 2x3 matrix
	Materialize(a)

	if a.NumDims() != 2 {
		t.Errorf("ndim = %d, want 2", a.NumDims())
	}
	shape := a.Shape()
	// Check the length first so a wrong rank is reported as a test failure
	// rather than an index-out-of-range panic.
	if len(shape) != 2 || shape[0] != 2 || shape[1] != 3 {
		t.Errorf("shape = %v, want [2 3]", shape)
	}
	if a.Size() != 6 {
		t.Errorf("size = %d, want 6", a.Size())
	}

	got := a.Floats()
	if len(got) != len(data) {
		t.Fatalf("len(Floats()) = %d, want %d", len(got), len(data))
	}
	for i, want := range data {
		if math.Abs(float64(got[i]-want)) > 1e-6 {
			t.Errorf("element[%d] = %f, want %f", i, got[i], want)
		}
	}
}
// TestFromValues_Int32 builds a 1-D int32 array from a slice and checks the
// dtype and that DataInt32() returns the original values.
func TestFromValues_Int32(t *testing.T) {
	data := []int32{10, 20, 30}
	a := FromValues(data, 3)
	Materialize(a)

	if a.Dtype() != DTypeInt32 {
		t.Errorf("dtype = %v, want int32", a.Dtype())
	}

	got := a.DataInt32()
	// Guard the element loop: a short result must fail the test, not panic.
	if len(got) != len(data) {
		t.Fatalf("len(DataInt32()) = %d, want %d", len(got), len(data))
	}
	for i, want := range data {
		if got[i] != want {
			t.Errorf("element[%d] = %d, want %d", i, got[i], want)
		}
	}
}
// TestFromValues_Int64 builds a 1-D array from an int64 slice and checks the
// dtype and element count; element values are not read back.
func TestFromValues_Int64(t *testing.T) {
	arr := FromValues([]int64{100, 200, 300}, 3)
	Materialize(arr)

	if dt := arr.Dtype(); dt != DTypeInt64 {
		t.Errorf("dtype = %v, want int64", dt)
	}
	if count := arr.Size(); count != 3 {
		t.Errorf("size = %d, want 3", count)
	}
}
// TestFromValues_Bool builds a 1-D array from a bool slice and checks the
// dtype and element count; element values are not read back.
func TestFromValues_Bool(t *testing.T) {
	arr := FromValues([]bool{true, false, true}, 3)
	Materialize(arr)

	if dt := arr.Dtype(); dt != DTypeBool {
		t.Errorf("dtype = %v, want bool", dt)
	}
	if count := arr.Size(); count != 3 {
		t.Errorf("size = %d, want 3", count)
	}
}
// TestFromValues_Uint8 builds a 1-D array from a uint8 slice covering the
// minimum, a mid-range and the maximum byte value, and checks the dtype only.
func TestFromValues_Uint8(t *testing.T) {
	arr := FromValues([]uint8{0, 127, 255}, 3)
	Materialize(arr)

	if dt := arr.Dtype(); dt != DTypeUint8 {
		t.Errorf("dtype = %v, want uint8", dt)
	}
}
// TestFromValues_PanicsWithoutShape calls FromValues with data but no shape
// arguments and fails unless the call panics.
func TestFromValues_PanicsWithoutShape(t *testing.T) {
	defer func() {
		// recover must be called directly from the deferred function.
		if recover() != nil {
			return
		}
		t.Error("expected panic when shape is missing")
	}()

	FromValues([]float32{1, 2, 3})
}
// --- Zeros ---

// TestZeros creates a 2x3 float32 array with Zeros and checks dtype, shape,
// size and that every element read back through Floats() is zero.
func TestZeros(t *testing.T) {
	a := Zeros([]int32{2, 3}, DTypeFloat32)
	Materialize(a)

	if a.Dtype() != DTypeFloat32 {
		t.Errorf("dtype = %v, want float32", a.Dtype())
	}
	shape := a.Shape()
	// Check the length first so a wrong rank is reported as a test failure
	// rather than an index-out-of-range panic.
	if len(shape) != 2 || shape[0] != 2 || shape[1] != 3 {
		t.Errorf("shape = %v, want [2 3]", shape)
	}
	if a.Size() != 6 {
		t.Errorf("size = %d, want 6", a.Size())
	}

	got := a.Floats()
	// Without this check an empty result would make the loop below pass
	// vacuously.
	if len(got) != 6 {
		t.Fatalf("len(Floats()) = %d, want 6", len(got))
	}
	for i, v := range got {
		if v != 0.0 {
			t.Errorf("element[%d] = %f, want 0.0", i, v)
		}
	}
}
// TestZeros_Int32 creates a 4-element int32 array with Zeros and checks the
// dtype and that every element read back through DataInt32() is zero.
func TestZeros_Int32(t *testing.T) {
	a := Zeros([]int32{4}, DTypeInt32)
	Materialize(a)

	if a.Dtype() != DTypeInt32 {
		t.Errorf("dtype = %v, want int32", a.Dtype())
	}

	got := a.DataInt32()
	// Without this check an empty result would make the loop below pass
	// vacuously.
	if len(got) != 4 {
		t.Fatalf("len(DataInt32()) = %d, want 4", len(got))
	}
	for i, v := range got {
		if v != 0 {
			t.Errorf("element[%d] = %d, want 0", i, v)
		}
	}
}
// --- Shape and metadata ---

// TestArray_Shape3D builds a 2x3x4 float32 array and checks the rank, the
// per-axis sizes from Dims(), the total element count and the byte size.
func TestArray_Shape3D(t *testing.T) {
	data := make([]float32, 24)
	a := FromValues(data, 2, 3, 4)
	Materialize(a)

	if a.NumDims() != 3 {
		t.Errorf("ndim = %d, want 3", a.NumDims())
	}
	dims := a.Dims()
	// Check the length first so a wrong rank is reported as a test failure
	// rather than an index-out-of-range panic.
	if len(dims) != 3 || dims[0] != 2 || dims[1] != 3 || dims[2] != 4 {
		t.Errorf("dims = %v, want [2 3 4]", dims)
	}
	if a.Size() != 24 {
		t.Errorf("size = %d, want 24", a.Size())
	}
	if a.NumBytes() != 24*4 { // float32 = 4 bytes
		t.Errorf("nbytes = %d, want %d", a.NumBytes(), 24*4)
	}
}
// --- String representation ---

// TestArray_String checks that String() on a materialized scalar returns a
// non-empty string and logs the result for manual inspection.
func TestArray_String(t *testing.T) {
	scalar := FromValue(float32(42.0))
	Materialize(scalar)

	repr := scalar.String()
	if len(repr) == 0 {
		t.Error("String() returned empty")
	}
	// The exact text is not asserted; MLX prints something like
	// "array(42, dtype=float32)".
	t.Logf("String() = %q", repr)
}
// --- Clone and Set ---

// TestArray_Clone clones a float32 scalar and checks that the clone reads
// back the same value as the source.
func TestArray_Clone(t *testing.T) {
	src := FromValue(float32(7.0))
	dup := src.Clone()
	Materialize(src, dup)

	if v := dup.Float(); math.Abs(v-7.0) > 1e-6 {
		t.Errorf("clone value = %f, want 7.0", v)
	}
}
// TestArray_Set overwrites one scalar with another via Set and checks that
// the destination then reads back the source's value.
func TestArray_Set(t *testing.T) {
	dst := FromValue(float32(1.0))
	src := FromValue(float32(2.0))
	Materialize(dst, src)

	dst.Set(src)
	Materialize(dst)

	if v := dst.Float(); math.Abs(v-2.0) > 1e-6 {
		t.Errorf("after Set, value = %f, want 2.0", v)
	}
}
// --- Valid and Free ---

// TestArray_Valid checks that Valid() is true for a materialized array and
// becomes false once the array has been passed to Free.
func TestArray_Valid(t *testing.T) {
	arr := FromValue(float32(1.0))
	Materialize(arr)

	if live := arr.Valid(); !live {
		t.Error("expected Valid() = true for live array")
	}

	Free(arr)

	if stillLive := arr.Valid(); stillLive {
		t.Error("expected Valid() = false after Free")
	}
}
// TestFree_ReturnsBytes frees a four-element float32 array and checks that
// Free reports 16 bytes.
func TestFree_ReturnsBytes(t *testing.T) {
	arr := FromValues([]float32{1, 2, 3, 4}, 4)
	Materialize(arr)

	// 4 elements * 4 bytes per float32.
	if freed := Free(arr); freed != 16 {
		t.Errorf("Free returned %d bytes, want 16", freed)
	}
}
// TestFree_NilSafe checks that Free accepts nil without panicking and
// reports zero bytes freed.
func TestFree_NilSafe(t *testing.T) {
	// A panic here would fail the test on its own.
	if freed := Free(nil); freed != 0 {
		t.Errorf("Free(nil) returned %d, want 0", freed)
	}
}
// --- Data extraction edge cases ---

// TestArray_Ints builds an int32 array and checks that Ints() returns the
// same values as Go ints.
func TestArray_Ints(t *testing.T) {
	data := []int32{10, 20, 30, 40}
	a := FromValues(data, 4)
	Materialize(a)

	want := []int{10, 20, 30, 40}
	got := a.Ints()
	// Guard the element loop: a short result must fail the test, not panic.
	if len(got) != len(want) {
		t.Fatalf("len(Ints()) = %d, want %d", len(got), len(want))
	}
	for i, w := range want {
		if got[i] != w {
			t.Errorf("Ints()[%d] = %d, want %d", i, got[i], w)
		}
	}
}
// TestArray_Float_DTypeFloat32 checks that Float() on a float32 scalar
// returns the stored value (1.5, exactly representable) to within 1e-6.
func TestArray_Float_DTypeFloat32(t *testing.T) {
	scalar := FromValue(float32(1.5))
	Materialize(scalar)

	if v := scalar.Float(); math.Abs(v-1.5) > 1e-6 {
		t.Errorf("Float() = %f, want 1.5", v)
	}
}
// TestArray_Float_DTypeFloat64 checks that Float() on a float64 scalar
// returns the stored value to within 1e-12.
func TestArray_Float_DTypeFloat64(t *testing.T) {
	scalar := FromValue(float64(1.5))
	Materialize(scalar)

	if v := scalar.Float(); math.Abs(v-1.5) > 1e-12 {
		t.Errorf("Float() = %f, want 1.5", v)
	}
}
// --- Collect ---

// TestCollect_FiltersNilAndInvalid passes two live arrays and one nil to
// Collect and checks that exactly two arrays come back.
//
// NOTE(review): despite the name, only the nil case is exercised here; no
// freed/invalid array is passed in.
func TestCollect_FiltersNilAndInvalid(t *testing.T) {
	first := FromValue(float32(1.0))
	second := FromValue(float32(2.0))
	Materialize(first, second)

	kept := Collect(first, nil, second)
	if n := len(kept); n != 2 {
		t.Errorf("Collect returned %d arrays, want 2", n)
	}
}