fix(metal): auto-contiguous data access for non-contiguous arrays
Bind mlx_contiguous and _mlx_array_is_row_contiguous from mlx-c. Floats(), DataInt32(), and Ints() now automatically handle non-contiguous arrays (from Transpose, BroadcastTo, SliceAxis, etc.) by checking IsRowContiguous() and making a contiguous copy when needed. Previously these methods returned silently wrong data for view arrays. The old workaround of Reshape(arr, totalSize) is no longer needed. 7 new tests for contiguous handling (transpose, broadcast, slice views). Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
bff97ccf19
commit
df0b300b1a
2 changed files with 139 additions and 9 deletions
|
|
@ -218,28 +218,62 @@ func (t Array) Float() float64 {
|
|||
}
|
||||
}
|
||||
|
||||
// IsRowContiguous reports whether the array's physical memory layout is
|
||||
// row-major contiguous. Non-contiguous arrays (from Transpose, BroadcastTo,
|
||||
// SliceAxis, etc.) must be made contiguous before reading raw data.
|
||||
func (t Array) IsRowContiguous() bool {
|
||||
var res C.bool
|
||||
C._mlx_array_is_row_contiguous(&res, t.ctx)
|
||||
return bool(res)
|
||||
}
|
||||
|
||||
// Contiguous returns a row-major contiguous copy of the array.
|
||||
// If the array is already row-contiguous, this is a no-op.
|
||||
func Contiguous(a *Array) *Array {
|
||||
out := New("CONTIGUOUS", a)
|
||||
C.mlx_contiguous(&out.ctx, a.ctx, C._Bool(false), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// ensureContiguous returns a row-contiguous array, making a copy if needed.
|
||||
// This must be called before any mlx_array_data_* access.
|
||||
func ensureContiguous(a *Array) *Array {
|
||||
if a.IsRowContiguous() {
|
||||
return a
|
||||
}
|
||||
c := Contiguous(a)
|
||||
Materialize(c)
|
||||
return c
|
||||
}
|
||||
|
||||
// Ints extracts all elements as int slice (from int32 data).
|
||||
func (t Array) Ints() []int {
|
||||
ints := make([]int, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
||||
// Automatically handles non-contiguous arrays (transpose, broadcast, slice views).
|
||||
func (t *Array) Ints() []int {
|
||||
src := ensureContiguous(t)
|
||||
ints := make([]int, src.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(src.ctx), len(ints)) {
|
||||
ints[i] = int(f)
|
||||
}
|
||||
return ints
|
||||
}
|
||||
|
||||
// DataInt32 extracts all elements as int32 slice.
|
||||
func (t Array) DataInt32() []int32 {
|
||||
data := make([]int32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) {
|
||||
// Automatically handles non-contiguous arrays (transpose, broadcast, slice views).
|
||||
func (t *Array) DataInt32() []int32 {
|
||||
src := ensureContiguous(t)
|
||||
data := make([]int32, src.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(src.ctx), len(data)) {
|
||||
data[i] = int32(f)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// Floats extracts all elements as float32 slice.
|
||||
func (t Array) Floats() []float32 {
|
||||
floats := make([]float32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
||||
// Automatically handles non-contiguous arrays (transpose, broadcast, slice views).
|
||||
func (t *Array) Floats() []float32 {
|
||||
src := ensureContiguous(t)
|
||||
floats := make([]float32, src.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(src.ctx), len(floats)) {
|
||||
floats[i] = float32(f)
|
||||
}
|
||||
return floats
|
||||
|
|
|
|||
|
|
@ -321,6 +321,102 @@ func TestFree_NilSafe(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- Contiguous handling ---
|
||||
|
||||
func TestIsRowContiguous_Fresh(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
|
||||
Materialize(a)
|
||||
|
||||
if !a.IsRowContiguous() {
|
||||
t.Error("freshly created array should be row-contiguous")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsRowContiguous_Transposed(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
|
||||
b := Transpose(a)
|
||||
Materialize(b)
|
||||
|
||||
if b.IsRowContiguous() {
|
||||
t.Error("transposed array should not be row-contiguous")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContiguous_MakesContiguous(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
|
||||
b := Transpose(a) // non-contiguous
|
||||
c := Contiguous(b)
|
||||
Materialize(c)
|
||||
|
||||
if !c.IsRowContiguous() {
|
||||
t.Error("Contiguous() result should be row-contiguous")
|
||||
}
|
||||
shape := c.Shape()
|
||||
if shape[0] != 3 || shape[1] != 2 {
|
||||
t.Errorf("shape = %v, want [3 2]", shape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloats_NonContiguous(t *testing.T) {
|
||||
// [[1 2 3], [4 5 6]] transposed → [[1 4], [2 5], [3 6]]
|
||||
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
|
||||
b := Transpose(a)
|
||||
Materialize(b)
|
||||
|
||||
// Previously this returned wrong data without Reshape workaround
|
||||
got := b.Floats()
|
||||
want := []float32{1, 4, 2, 5, 3, 6}
|
||||
for i := range got {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataInt32_NonContiguous(t *testing.T) {
|
||||
a := FromValues([]int32{1, 2, 3, 4, 5, 6}, 2, 3)
|
||||
b := Transpose(a)
|
||||
Materialize(b)
|
||||
|
||||
got := b.DataInt32()
|
||||
want := []int32{1, 4, 2, 5, 3, 6}
|
||||
for i := range got {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("DataInt32()[%d] = %d, want %d", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloats_BroadcastView(t *testing.T) {
|
||||
// BroadcastTo creates a non-contiguous view
|
||||
a := FromValues([]float32{1, 2, 3}, 1, 3)
|
||||
b := BroadcastTo(a, []int32{2, 3})
|
||||
Materialize(b)
|
||||
|
||||
got := b.Floats()
|
||||
want := []float32{1, 2, 3, 1, 2, 3}
|
||||
for i := range got {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFloats_SliceView(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4, 5, 6}, 2, 3)
|
||||
// Slice columns 1:3 — creates a non-contiguous view
|
||||
b := SliceAxis(a, 1, 1, 3)
|
||||
Materialize(b)
|
||||
|
||||
got := b.Floats()
|
||||
want := []float32{2, 3, 5, 6}
|
||||
for i := range got {
|
||||
if got[i] != want[i] {
|
||||
t.Errorf("Floats()[%d] = %f, want %f", i, got[i], want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Data extraction edge cases ---
|
||||
|
||||
func TestArray_Ints(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue