fix(metal): auto-contiguous data access for non-contiguous arrays

Bind mlx_contiguous and _mlx_array_is_row_contiguous from mlx-c.
Floats(), DataInt32(), and Ints() now automatically handle non-contiguous
arrays (from Transpose, BroadcastTo, SliceAxis, etc.) by checking
IsRowContiguous() and making a contiguous copy when needed.

Previously these methods returned silently wrong data for view arrays.
The old workaround of Reshape(arr, totalSize) is no longer needed.

7 new tests for contiguous handling (transpose, broadcast, slice views).

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-19 20:39:36 +00:00
parent bff97ccf19
commit df0b300b1a
2 changed files with 139 additions and 9 deletions

View file

@ -218,28 +218,62 @@ func (t Array) Float() float64 {
}
}
// IsRowContiguous reports whether the array's physical memory layout is
// row-major contiguous. Non-contiguous arrays (from Transpose, BroadcastTo,
// SliceAxis, etc.) must be made contiguous before reading raw data.
func (t Array) IsRowContiguous() bool {
var res C.bool
C._mlx_array_is_row_contiguous(&res, t.ctx)
return bool(res)
}
// Contiguous returns a row-major contiguous copy of the array.
// If the array is already row-contiguous, this is a no-op.
func Contiguous(a *Array) *Array {
out := New("CONTIGUOUS", a)
C.mlx_contiguous(&out.ctx, a.ctx, C._Bool(false), DefaultStream().ctx)
return out
}
// ensureContiguous returns a row-contiguous array, making a copy if needed.
// This must be called before any mlx_array_data_* access.
func ensureContiguous(a *Array) *Array {
if a.IsRowContiguous() {
return a
}
c := Contiguous(a)
Materialize(c)
return c
}
// Ints extracts all elements as int slice (from int32 data).
func (t Array) Ints() []int {
ints := make([]int, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
// Automatically handles non-contiguous arrays (transpose, broadcast, slice views).
func (t *Array) Ints() []int {
src := ensureContiguous(t)
ints := make([]int, src.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(src.ctx), len(ints)) {
ints[i] = int(f)
}
return ints
}
// DataInt32 extracts all elements as int32 slice.
func (t Array) DataInt32() []int32 {
data := make([]int32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) {
// Automatically handles non-contiguous arrays (transpose, broadcast, slice views).
func (t *Array) DataInt32() []int32 {
src := ensureContiguous(t)
data := make([]int32, src.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_int32(src.ctx), len(data)) {
data[i] = int32(f)
}
return data
}
// Floats extracts all elements as float32 slice.
func (t Array) Floats() []float32 {
floats := make([]float32, t.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
// Automatically handles non-contiguous arrays (transpose, broadcast, slice views).
func (t *Array) Floats() []float32 {
src := ensureContiguous(t)
floats := make([]float32, src.Size())
for i, f := range unsafe.Slice(C.mlx_array_data_float32(src.ctx), len(floats)) {
floats[i] = float32(f)
}
return floats

View file

@ -321,6 +321,102 @@ func TestFree_NilSafe(t *testing.T) {
}
}
// --- Contiguous handling ---
func TestIsRowContiguous_Fresh(t *testing.T) {
a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
Materialize(a)
if !a.IsRowContiguous() {
t.Error("freshly created array should be row-contiguous")
}
}
func TestIsRowContiguous_Transposed(t *testing.T) {
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
b := Transpose(a)
Materialize(b)
if b.IsRowContiguous() {
t.Error("transposed array should not be row-contiguous")
}
}
func TestContiguous_MakesContiguous(t *testing.T) {
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
b := Transpose(a) // non-contiguous
c := Contiguous(b)
Materialize(c)
if !c.IsRowContiguous() {
t.Error("Contiguous() result should be row-contiguous")
}
shape := c.Shape()
if shape[0] != 3 || shape[1] != 2 {
t.Errorf("shape = %v, want [3 2]", shape)
}
}
func TestFloats_NonContiguous(t *testing.T) {
// [[1 2 3], [4 5 6]] transposed → [[1 4], [2 5], [3 6]]
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
b := Transpose(a)
Materialize(b)
// Previously this returned wrong data without Reshape workaround
got := b.Floats()
want := []float32{1, 4, 2, 5, 3, 6}
for i := range got {
if got[i] != want[i] {
t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i])
}
}
}
func TestDataInt32_NonContiguous(t *testing.T) {
a := FromValues([]int32{1, 2, 3, 4, 5, 6}, 2, 3)
b := Transpose(a)
Materialize(b)
got := b.DataInt32()
want := []int32{1, 4, 2, 5, 3, 6}
for i := range got {
if got[i] != want[i] {
t.Errorf("DataInt32()[%d] = %d, want %d", i, got[i], want[i])
}
}
}
func TestFloats_BroadcastView(t *testing.T) {
// BroadcastTo creates a non-contiguous view
a := FromValues([]float32{1, 2, 3}, 1, 3)
b := BroadcastTo(a, []int32{2, 3})
Materialize(b)
got := b.Floats()
want := []float32{1, 2, 3, 1, 2, 3}
for i := range got {
if got[i] != want[i] {
t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i])
}
}
}
func TestFloats_SliceView(t *testing.T) {
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
// Slice columns 1:3 — creates a non-contiguous view
b := SliceAxis(a, 1, 1, 3)
Materialize(b)
got := b.Floats()
want := []float32{2, 3, 5, 6}
for i := range got {
if got[i] != want[i] {
t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i])
}
}
}
// --- Data extraction edge cases ---
func TestArray_Ints(t *testing.T) {