diff --git a/internal/metal/array.go b/internal/metal/array.go index e580304..2e79840 100644 --- a/internal/metal/array.go +++ b/internal/metal/array.go @@ -218,28 +218,62 @@ func (t Array) Float() float64 { } } +// IsRowContiguous reports whether the array's physical memory layout is +// row-major contiguous. Non-contiguous arrays (from Transpose, BroadcastTo, +// SliceAxis, etc.) must be made contiguous before reading raw data. +func (t Array) IsRowContiguous() bool { + var res C.bool + C._mlx_array_is_row_contiguous(&res, t.ctx) + return bool(res) +} + +// Contiguous returns a row-major contiguous copy of the array. +// If the array is already row-contiguous, this is a no-op. +func Contiguous(a *Array) *Array { + out := New("CONTIGUOUS", a) + C.mlx_contiguous(&out.ctx, a.ctx, C._Bool(false), DefaultStream().ctx) + return out +} + +// ensureContiguous returns a row-contiguous array, making a copy if needed. +// This must be called before any mlx_array_data_* access. +func ensureContiguous(a *Array) *Array { + if a.IsRowContiguous() { + return a + } + c := Contiguous(a) + Materialize(c) + return c +} + // Ints extracts all elements as int slice (from int32 data). -func (t Array) Ints() []int { - ints := make([]int, t.Size()) - for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { +// Automatically handles non-contiguous arrays (transpose, broadcast, slice views). +func (t *Array) Ints() []int { + src := ensureContiguous(t) + ints := make([]int, src.Size()) + for i, f := range unsafe.Slice(C.mlx_array_data_int32(src.ctx), len(ints)) { ints[i] = int(f) } return ints } // DataInt32 extracts all elements as int32 slice. -func (t Array) DataInt32() []int32 { - data := make([]int32, t.Size()) - for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) { +// Automatically handles non-contiguous arrays (transpose, broadcast, slice views). +func (t *Array) DataInt32() []int32 { + src := ensureContiguous(t) + data := make([]int32, src.Size()) + for i, f := range unsafe.Slice(C.mlx_array_data_int32(src.ctx), len(data)) { data[i] = int32(f) } return data } // Floats extracts all elements as float32 slice. -func (t Array) Floats() []float32 { - floats := make([]float32, t.Size()) - for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { +// Automatically handles non-contiguous arrays (transpose, broadcast, slice views). +func (t *Array) Floats() []float32 { + src := ensureContiguous(t) + floats := make([]float32, src.Size()) + for i, f := range unsafe.Slice(C.mlx_array_data_float32(src.ctx), len(floats)) { floats[i] = float32(f) } return floats diff --git a/internal/metal/array_test.go b/internal/metal/array_test.go index ff745ac..4871e56 100644 --- a/internal/metal/array_test.go +++ b/internal/metal/array_test.go @@ -321,6 +321,102 @@ func TestFree_NilSafe(t *testing.T) { } } +// --- Contiguous handling --- + +func TestIsRowContiguous_Fresh(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4}, 2, 2) + Materialize(a) + + if !a.IsRowContiguous() { + t.Error("freshly created array should be row-contiguous") + } +} + +func TestIsRowContiguous_Transposed(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + b := Transpose(a) + Materialize(b) + + if b.IsRowContiguous() { + t.Error("transposed array should not be row-contiguous") + } +} + +func TestContiguous_MakesContiguous(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + b := Transpose(a) // non-contiguous + c := Contiguous(b) + Materialize(c) + + if !c.IsRowContiguous() { + t.Error("Contiguous() result should be row-contiguous") + } + shape := c.Shape() + if shape[0] != 3 || shape[1] != 2 { + t.Errorf("shape = %v, want [3 2]", shape) + } +} + +func TestFloats_NonContiguous(t *testing.T) { + // [[1 2 3], [4 5 6]] transposed → [[1 4], [2 5], [3 6]] + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + b := Transpose(a) + Materialize(b) + + // Previously this returned wrong data without Reshape workaround + got := b.Floats() + want := []float32{1, 4, 2, 5, 3, 6} + for i := range got { + if got[i] != want[i] { + t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i]) + } + } +} + +func TestDataInt32_NonContiguous(t *testing.T) { + a := FromValues([]int32{1, 2, 3, 4, 5, 6}, 2, 3) + b := Transpose(a) + Materialize(b) + + got := b.DataInt32() + want := []int32{1, 4, 2, 5, 3, 6} + for i := range got { + if got[i] != want[i] { + t.Errorf("DataInt32()[%d] = %d, want %d", i, got[i], want[i]) + } + } +} + +func TestFloats_BroadcastView(t *testing.T) { + // BroadcastTo creates a non-contiguous view + a := FromValues([]float32{1, 2, 3}, 1, 3) + b := BroadcastTo(a, []int32{2, 3}) + Materialize(b) + + got := b.Floats() + want := []float32{1, 2, 3, 1, 2, 3} + for i := range got { + if got[i] != want[i] { + t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i]) + } + } +} + +func TestFloats_SliceView(t *testing.T) { + a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3) + // Slice columns 1:3 — creates a non-contiguous view + b := SliceAxis(a, 1, 1, 3) + Materialize(b) + + got := b.Floats() + want := []float32{2, 3, 5, 6} + for i := range got { + if got[i] != want[i] { + t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i]) + } + } +} + // --- Data extraction edge cases --- func TestArray_Ints(t *testing.T) {