Phase 1 hardening: cover all previously-untested core operations. - array_test.go (25): scalar/array creation, shape, clone, free, data access - ops_test.go (44): arithmetic, math, matmul, reductions, shape ops, indexing, slicing, random - nn_test.go (8): Linear (dense/bias/LoRA), Embedding, RMSNormModule, RepeatKV - fast_test.go (9): RMSNorm, LayerNorm, RoPE, ScaledDotProductAttention Found: Floats()/DataInt32() return wrong data on non-contiguous arrays (transpose, broadcast, slice views). Documented in FINDINGS.md. Also: cpp/ workspace docs for CLion Claude session, Go 1.26 impact assessment, verified go generate → test round-trip (29→115 tests). Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
370 lines
7.4 KiB
Go
370 lines
7.4 KiB
Go
//go:build darwin && arm64
|
|
|
|
package mlx
|
|
|
|
import (
|
|
"math"
|
|
"testing"
|
|
)
|
|
|
|
// --- Scalar creation (FromValue) ---
|
|
|
|
func TestFromValue_Float32(t *testing.T) {
|
|
a := FromValue(float32(3.14))
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeFloat32 {
|
|
t.Errorf("dtype = %v, want float32", a.Dtype())
|
|
}
|
|
if a.NumDims() != 0 {
|
|
t.Errorf("ndim = %d, want 0 (scalar)", a.NumDims())
|
|
}
|
|
if a.Size() != 1 {
|
|
t.Errorf("size = %d, want 1", a.Size())
|
|
}
|
|
if math.Abs(a.Float()-3.14) > 1e-5 {
|
|
t.Errorf("value = %f, want 3.14", a.Float())
|
|
}
|
|
}
|
|
|
|
func TestFromValue_Float64(t *testing.T) {
|
|
a := FromValue(float64(2.718281828))
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeFloat64 {
|
|
t.Errorf("dtype = %v, want float64", a.Dtype())
|
|
}
|
|
if math.Abs(a.Float()-2.718281828) > 1e-8 {
|
|
t.Errorf("value = %f, want 2.718281828", a.Float())
|
|
}
|
|
}
|
|
|
|
func TestFromValue_Int(t *testing.T) {
|
|
a := FromValue(42)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeInt32 {
|
|
t.Errorf("dtype = %v, want int32", a.Dtype())
|
|
}
|
|
if a.Int() != 42 {
|
|
t.Errorf("value = %d, want 42", a.Int())
|
|
}
|
|
}
|
|
|
|
func TestFromValue_Bool(t *testing.T) {
|
|
a := FromValue(true)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeBool {
|
|
t.Errorf("dtype = %v, want bool", a.Dtype())
|
|
}
|
|
if a.Int() != 1 {
|
|
t.Errorf("value = %d, want 1 (true)", a.Int())
|
|
}
|
|
}
|
|
|
|
func TestFromValue_Complex64(t *testing.T) {
|
|
a := FromValue(complex64(3 + 4i))
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeComplex64 {
|
|
t.Errorf("dtype = %v, want complex64", a.Dtype())
|
|
}
|
|
if a.Size() != 1 {
|
|
t.Errorf("size = %d, want 1", a.Size())
|
|
}
|
|
}
|
|
|
|
// --- Slice creation (FromValues) ---
|
|
|
|
func TestFromValues_Float32_1D(t *testing.T) {
|
|
data := []float32{1.0, 2.0, 3.0, 4.0}
|
|
a := FromValues(data, 4)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeFloat32 {
|
|
t.Errorf("dtype = %v, want float32", a.Dtype())
|
|
}
|
|
if a.NumDims() != 1 {
|
|
t.Errorf("ndim = %d, want 1", a.NumDims())
|
|
}
|
|
if a.Dim(0) != 4 {
|
|
t.Errorf("dim(0) = %d, want 4", a.Dim(0))
|
|
}
|
|
if a.Size() != 4 {
|
|
t.Errorf("size = %d, want 4", a.Size())
|
|
}
|
|
|
|
got := a.Floats()
|
|
for i, want := range data {
|
|
if math.Abs(float64(got[i]-want)) > 1e-6 {
|
|
t.Errorf("element[%d] = %f, want %f", i, got[i], want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFromValues_Float32_2D(t *testing.T) {
|
|
data := []float32{1, 2, 3, 4, 5, 6}
|
|
a := FromValues(data, 2, 3) // 2x3 matrix
|
|
Materialize(a)
|
|
|
|
if a.NumDims() != 2 {
|
|
t.Errorf("ndim = %d, want 2", a.NumDims())
|
|
}
|
|
shape := a.Shape()
|
|
if shape[0] != 2 || shape[1] != 3 {
|
|
t.Errorf("shape = %v, want [2 3]", shape)
|
|
}
|
|
if a.Size() != 6 {
|
|
t.Errorf("size = %d, want 6", a.Size())
|
|
}
|
|
|
|
got := a.Floats()
|
|
for i, want := range data {
|
|
if math.Abs(float64(got[i]-want)) > 1e-6 {
|
|
t.Errorf("element[%d] = %f, want %f", i, got[i], want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFromValues_Int32(t *testing.T) {
|
|
data := []int32{10, 20, 30}
|
|
a := FromValues(data, 3)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeInt32 {
|
|
t.Errorf("dtype = %v, want int32", a.Dtype())
|
|
}
|
|
got := a.DataInt32()
|
|
for i, want := range data {
|
|
if got[i] != want {
|
|
t.Errorf("element[%d] = %d, want %d", i, got[i], want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestFromValues_Int64(t *testing.T) {
|
|
data := []int64{100, 200, 300}
|
|
a := FromValues(data, 3)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeInt64 {
|
|
t.Errorf("dtype = %v, want int64", a.Dtype())
|
|
}
|
|
if a.Size() != 3 {
|
|
t.Errorf("size = %d, want 3", a.Size())
|
|
}
|
|
}
|
|
|
|
func TestFromValues_Bool(t *testing.T) {
|
|
data := []bool{true, false, true}
|
|
a := FromValues(data, 3)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeBool {
|
|
t.Errorf("dtype = %v, want bool", a.Dtype())
|
|
}
|
|
if a.Size() != 3 {
|
|
t.Errorf("size = %d, want 3", a.Size())
|
|
}
|
|
}
|
|
|
|
func TestFromValues_Uint8(t *testing.T) {
|
|
data := []uint8{0, 127, 255}
|
|
a := FromValues(data, 3)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeUint8 {
|
|
t.Errorf("dtype = %v, want uint8", a.Dtype())
|
|
}
|
|
}
|
|
|
|
func TestFromValues_PanicsWithoutShape(t *testing.T) {
|
|
defer func() {
|
|
if r := recover(); r == nil {
|
|
t.Error("expected panic when shape is missing")
|
|
}
|
|
}()
|
|
FromValues([]float32{1, 2, 3})
|
|
}
|
|
|
|
// --- Zeros ---
|
|
|
|
func TestZeros(t *testing.T) {
|
|
a := Zeros([]int32{2, 3}, DTypeFloat32)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeFloat32 {
|
|
t.Errorf("dtype = %v, want float32", a.Dtype())
|
|
}
|
|
shape := a.Shape()
|
|
if shape[0] != 2 || shape[1] != 3 {
|
|
t.Errorf("shape = %v, want [2 3]", shape)
|
|
}
|
|
if a.Size() != 6 {
|
|
t.Errorf("size = %d, want 6", a.Size())
|
|
}
|
|
|
|
for i, v := range a.Floats() {
|
|
if v != 0.0 {
|
|
t.Errorf("element[%d] = %f, want 0.0", i, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestZeros_Int32(t *testing.T) {
|
|
a := Zeros([]int32{4}, DTypeInt32)
|
|
Materialize(a)
|
|
|
|
if a.Dtype() != DTypeInt32 {
|
|
t.Errorf("dtype = %v, want int32", a.Dtype())
|
|
}
|
|
for i, v := range a.DataInt32() {
|
|
if v != 0 {
|
|
t.Errorf("element[%d] = %d, want 0", i, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
// --- Shape and metadata ---
|
|
|
|
func TestArray_Shape3D(t *testing.T) {
|
|
data := make([]float32, 24)
|
|
a := FromValues(data, 2, 3, 4)
|
|
Materialize(a)
|
|
|
|
if a.NumDims() != 3 {
|
|
t.Errorf("ndim = %d, want 3", a.NumDims())
|
|
}
|
|
dims := a.Dims()
|
|
if dims[0] != 2 || dims[1] != 3 || dims[2] != 4 {
|
|
t.Errorf("dims = %v, want [2 3 4]", dims)
|
|
}
|
|
if a.Size() != 24 {
|
|
t.Errorf("size = %d, want 24", a.Size())
|
|
}
|
|
if a.NumBytes() != 24*4 { // float32 = 4 bytes
|
|
t.Errorf("nbytes = %d, want %d", a.NumBytes(), 24*4)
|
|
}
|
|
}
|
|
|
|
// --- String representation ---
|
|
|
|
func TestArray_String(t *testing.T) {
|
|
a := FromValue(float32(42.0))
|
|
Materialize(a)
|
|
|
|
s := a.String()
|
|
if s == "" {
|
|
t.Error("String() returned empty")
|
|
}
|
|
// MLX prints "array(42, dtype=float32)" or similar
|
|
t.Logf("String() = %q", s)
|
|
}
|
|
|
|
// --- Clone and Set ---
|
|
|
|
func TestArray_Clone(t *testing.T) {
|
|
a := FromValue(float32(7.0))
|
|
b := a.Clone()
|
|
Materialize(a, b)
|
|
|
|
if math.Abs(b.Float()-7.0) > 1e-6 {
|
|
t.Errorf("clone value = %f, want 7.0", b.Float())
|
|
}
|
|
}
|
|
|
|
func TestArray_Set(t *testing.T) {
|
|
a := FromValue(float32(1.0))
|
|
b := FromValue(float32(2.0))
|
|
Materialize(a, b)
|
|
|
|
a.Set(b)
|
|
Materialize(a)
|
|
|
|
if math.Abs(a.Float()-2.0) > 1e-6 {
|
|
t.Errorf("after Set, value = %f, want 2.0", a.Float())
|
|
}
|
|
}
|
|
|
|
// --- Valid and Free ---
|
|
|
|
func TestArray_Valid(t *testing.T) {
|
|
a := FromValue(float32(1.0))
|
|
Materialize(a)
|
|
|
|
if !a.Valid() {
|
|
t.Error("expected Valid() = true for live array")
|
|
}
|
|
|
|
Free(a)
|
|
if a.Valid() {
|
|
t.Error("expected Valid() = false after Free")
|
|
}
|
|
}
|
|
|
|
func TestFree_ReturnsBytes(t *testing.T) {
|
|
a := FromValues([]float32{1, 2, 3, 4}, 4)
|
|
Materialize(a)
|
|
|
|
n := Free(a)
|
|
if n != 16 { // 4 * float32(4 bytes)
|
|
t.Errorf("Free returned %d bytes, want 16", n)
|
|
}
|
|
}
|
|
|
|
func TestFree_NilSafe(t *testing.T) {
|
|
// Should not panic on nil
|
|
n := Free(nil)
|
|
if n != 0 {
|
|
t.Errorf("Free(nil) returned %d, want 0", n)
|
|
}
|
|
}
|
|
|
|
// --- Data extraction edge cases ---
|
|
|
|
func TestArray_Ints(t *testing.T) {
|
|
data := []int32{10, 20, 30, 40}
|
|
a := FromValues(data, 4)
|
|
Materialize(a)
|
|
|
|
got := a.Ints()
|
|
for i, want := range []int{10, 20, 30, 40} {
|
|
if got[i] != want {
|
|
t.Errorf("Ints()[%d] = %d, want %d", i, got[i], want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestArray_Float_DTypeFloat32(t *testing.T) {
|
|
a := FromValue(float32(1.5))
|
|
Materialize(a)
|
|
|
|
got := a.Float()
|
|
if math.Abs(got-1.5) > 1e-6 {
|
|
t.Errorf("Float() = %f, want 1.5", got)
|
|
}
|
|
}
|
|
|
|
func TestArray_Float_DTypeFloat64(t *testing.T) {
|
|
a := FromValue(float64(1.5))
|
|
Materialize(a)
|
|
|
|
got := a.Float()
|
|
if math.Abs(got-1.5) > 1e-12 {
|
|
t.Errorf("Float() = %f, want 1.5", got)
|
|
}
|
|
}
|
|
|
|
// --- Collect ---
|
|
|
|
func TestCollect_FiltersNilAndInvalid(t *testing.T) {
|
|
a := FromValue(float32(1.0))
|
|
b := FromValue(float32(2.0))
|
|
Materialize(a, b)
|
|
|
|
result := Collect(a, nil, b)
|
|
if len(result) != 2 {
|
|
t.Errorf("Collect returned %d arrays, want 2", len(result))
|
|
}
|
|
}
|