go/pkg/mlx/slice.go
Claude 065b42a0be
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
  concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00

63 lines
1.8 KiB
Go

//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
// Slice extracts a sub-array using start and end indices for each dimension.
// starts and ends must have the same length as the array's dimensions.
func Slice(a *Array, starts, ends []int32) *Array {
out := New("SLICE", a)
cStarts := make([]C.int, len(starts))
cEnds := make([]C.int, len(ends))
for i := range starts {
cStarts[i] = C.int(starts[i])
cEnds[i] = C.int(ends[i])
}
strides := make([]C.int, len(starts))
for i := range strides {
strides[i] = 1
}
C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
return out
}
// SliceAxis extracts a sub-array along a single axis.
func SliceAxis(a *Array, axis int, start, end int32) *Array {
// Build full slice parameters
ndim := a.NumDims()
starts := make([]int32, ndim)
ends := make([]int32, ndim)
for i := 0; i < ndim; i++ {
starts[i] = 0
ends[i] = int32(a.Dim(i))
}
ax := axis
if ax < 0 {
ax = ndim + ax
}
starts[ax] = start
ends[ax] = end
return Slice(a, starts, ends)
}
// SliceUpdateInplace updates a slice of the array in-place.
// This is critical for KV cache updates.
func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
out := New("SLICE_UPDATE", a, update)
cStarts := make([]C.int, len(starts))
cEnds := make([]C.int, len(ends))
for i := range starts {
cStarts[i] = C.int(starts[i])
cEnds[i] = C.int(ends[i])
}
strides := make([]C.int, len(starts))
for i := range strides {
strides[i] = 1
}
C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
return out
}