cli/pkg/mlx/slice.go
Claude 8ee0c4bc4e feat: add native MLX backend for Apple Silicon inference (pkg/mlx)
CGo wrapper for mlx-c providing zero-Python Metal GPU inference.
Includes Gemma 3 model architecture, BPE tokenizer, KV cache,
composable sampling, and OpenAI-compatible serve command.

Build-tagged (darwin && arm64 && mlx) with stubs for cross-platform.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:53:52 +00:00

63 lines
1.7 KiB
Go

//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
// Slice extracts a sub-array using start and end indices for each dimension,
// with a stride of 1 along every axis. Entry i selects the half-open range
// [starts[i], ends[i]) along dimension i.
//
// starts and ends must have the same length as the array's dimensions.
// Slice panics if len(starts) != len(ends). Empty starts/ends (for a
// 0-dimensional array) are allowed.
func Slice(a *Array, starts, ends []int32) *Array {
	if len(starts) != len(ends) {
		panic("mlx: Slice: starts and ends must have the same length")
	}
	n := len(starts)
	out := New("SLICE", a)

	cStarts := make([]C.int, n)
	cEnds := make([]C.int, n)
	strides := make([]C.int, n)
	for i := 0; i < n; i++ {
		cStarts[i] = C.int(starts[i])
		cEnds[i] = C.int(ends[i])
		strides[i] = 1
	}

	// Taking &s[0] of an empty slice panics, so for zero bounds pass NULL
	// pointers together with a count of zero.
	var pStarts, pEnds, pStrides *C.int
	if n > 0 {
		pStarts, pEnds, pStrides = &cStarts[0], &cEnds[0], &strides[0]
	}

	// The *_num arguments are element counts. mlx-c declares them as size_t
	// (mlx/c/ops.h), and cgo does not convert C.int to C.size_t implicitly.
	C.mlx_slice(&out.ctx, a.ctx,
		pStarts, C.size_t(n),
		pEnds, C.size_t(n),
		pStrides, C.size_t(n),
		DefaultStream().ctx)
	return out
}
// SliceAxis extracts a sub-array along a single axis, selecting [start, end)
// on that axis and keeping every other dimension at its full extent.
// A negative axis is counted from the end (axis + number of dimensions).
func SliceAxis(a *Array, axis int, start, end int32) *Array {
	rank := a.NumDims()
	if axis < 0 {
		axis += rank
	}

	// Lower bounds default to 0 (make zero-fills); upper bounds default to
	// each dimension's size, i.e. the whole axis.
	lo := make([]int32, rank)
	hi := make([]int32, rank)
	for d := range hi {
		hi[d] = int32(a.Dim(d))
	}

	lo[axis], hi[axis] = start, end
	return Slice(a, lo, hi)
}
// SliceUpdateInplace produces an array equal to a with the region
// [starts[i], ends[i]) along each dimension i (stride 1) replaced by update.
// This is critical for KV cache updates.
//
// Despite the name, the result is written to a new Array that is returned.
// Callers must use the return value rather than expecting a to change.
// NOTE(review): whether MLX reuses a's buffer internally is not visible
// here — confirm before relying on it.
//
// starts and ends must have the same length; SliceUpdateInplace panics
// otherwise. Empty starts/ends (for a 0-dimensional array) are allowed.
func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
	if len(starts) != len(ends) {
		panic("mlx: SliceUpdateInplace: starts and ends must have the same length")
	}
	n := len(starts)
	out := New("SLICE_UPDATE", a, update)

	cStarts := make([]C.int, n)
	cEnds := make([]C.int, n)
	strides := make([]C.int, n)
	for i := 0; i < n; i++ {
		cStarts[i] = C.int(starts[i])
		cEnds[i] = C.int(ends[i])
		strides[i] = 1
	}

	// Taking &s[0] of an empty slice panics, so for zero bounds pass NULL
	// pointers together with a count of zero.
	var pStarts, pEnds, pStrides *C.int
	if n > 0 {
		pStarts, pEnds, pStrides = &cStarts[0], &cEnds[0], &strides[0]
	}

	// The *_num arguments are element counts. mlx-c declares them as size_t
	// (mlx/c/ops.h), and cgo does not convert C.int to C.size_t implicitly.
	C.mlx_slice_update(&out.ctx, a.ctx, update.ctx,
		pStarts, C.size_t(n),
		pEnds, C.size_t(n),
		pStrides, C.size_t(n),
		DefaultStream().ctx)
	return out
}