2026-02-16 01:19:04 +00:00
|
|
|
//go:build darwin && arm64 && mlx
|
|
|
|
|
|
|
|
|
|
package mlx
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
#include <stdlib.h>
|
|
|
|
|
#include "mlx/c/mlx.h"
|
|
|
|
|
*/
|
|
|
|
|
import "C"
|
|
|
|
|
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
import "unsafe"
|
|
|
|
|
|
2026-02-16 01:19:04 +00:00
|
|
|
// --- Element-wise arithmetic ---
|
|
|
|
|
|
|
|
|
|
// Add returns element-wise a + b.
|
|
|
|
|
func Add(a, b *Array) *Array {
|
|
|
|
|
out := New("ADD", a, b)
|
|
|
|
|
C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// AddScalar returns a + scalar (broadcast).
|
|
|
|
|
func AddScalar(a *Array, s float32) *Array {
|
|
|
|
|
scalar := FromValue(s)
|
|
|
|
|
return Add(a, scalar)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Mul returns element-wise a * b.
|
|
|
|
|
func Mul(a, b *Array) *Array {
|
|
|
|
|
out := New("MUL", a, b)
|
|
|
|
|
C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// MulScalar returns a * scalar (broadcast).
|
|
|
|
|
func MulScalar(a *Array, s float32) *Array {
|
|
|
|
|
scalar := FromValue(s)
|
|
|
|
|
return Mul(a, scalar)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Divide returns element-wise a / b.
|
|
|
|
|
func Divide(a, b *Array) *Array {
|
|
|
|
|
out := New("DIV", a, b)
|
|
|
|
|
C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Subtract returns element-wise a - b.
|
|
|
|
|
func Subtract(a, b *Array) *Array {
|
|
|
|
|
out := New("SUB", a, b)
|
|
|
|
|
C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Negative returns element-wise -a.
|
|
|
|
|
func Negative(a *Array) *Array {
|
|
|
|
|
out := New("NEG", a)
|
|
|
|
|
C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// --- Math functions ---
|
|
|
|
|
|
|
|
|
|
// Exp returns element-wise exp(a).
|
|
|
|
|
func Exp(a *Array) *Array {
|
|
|
|
|
out := New("EXP", a)
|
|
|
|
|
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Tanh returns element-wise tanh(a).
|
|
|
|
|
func Tanh(a *Array) *Array {
|
|
|
|
|
out := New("TANH", a)
|
|
|
|
|
C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Sqrt returns element-wise sqrt(a).
|
|
|
|
|
func Sqrt(a *Array) *Array {
|
|
|
|
|
out := New("SQRT", a)
|
|
|
|
|
C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Rsqrt returns element-wise 1/sqrt(a).
|
|
|
|
|
func Rsqrt(a *Array) *Array {
|
|
|
|
|
out := New("RSQRT", a)
|
|
|
|
|
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Reciprocal returns element-wise 1/a.
|
|
|
|
|
func Reciprocal(a *Array) *Array {
|
|
|
|
|
out := New("RECIPROCAL", a)
|
|
|
|
|
C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Square returns element-wise a^2.
|
|
|
|
|
func Square(a *Array) *Array {
|
|
|
|
|
out := New("SQUARE", a)
|
|
|
|
|
C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Power returns element-wise a^b.
|
|
|
|
|
func Power(a, b *Array) *Array {
|
|
|
|
|
out := New("POWER", a, b)
|
|
|
|
|
C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Maximum returns element-wise max(a, b).
|
|
|
|
|
func Maximum(a, b *Array) *Array {
|
|
|
|
|
out := New("MAX", a, b)
|
|
|
|
|
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Minimum returns element-wise min(a, b).
|
|
|
|
|
func Minimum(a, b *Array) *Array {
|
|
|
|
|
out := New("MIN", a, b)
|
|
|
|
|
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// --- Matrix operations ---
|
|
|
|
|
|
|
|
|
|
// Matmul returns the matrix product of a and b.
|
|
|
|
|
func Matmul(a, b *Array) *Array {
|
|
|
|
|
out := New("MATMUL", a, b)
|
|
|
|
|
C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// QuantizedMatmul performs quantized matrix multiplication.
|
|
|
|
|
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
|
|
|
|
|
out := New("QMATMUL", x, w, scales, biases)
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
|
|
|
|
|
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
|
|
|
|
|
mode := C.CString("default")
|
|
|
|
|
defer C.free(unsafe.Pointer(mode))
|
2026-02-16 01:19:04 +00:00
|
|
|
C.mlx_quantized_matmul(
|
|
|
|
|
&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C._Bool(transpose), gs, b, mode,
|
2026-02-16 01:19:04 +00:00
|
|
|
DefaultStream().ctx,
|
|
|
|
|
)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// --- Reductions ---
|
|
|
|
|
|
|
|
|
|
// Softmax returns softmax along the last axis.
|
|
|
|
|
func Softmax(a *Array) *Array {
|
|
|
|
|
out := New("SOFTMAX", a)
|
|
|
|
|
axis := []C.int{C.int(-1)}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Argmax returns the index of the maximum value along an axis.
|
|
|
|
|
func Argmax(a *Array, axis int, keepDims bool) *Array {
|
|
|
|
|
out := New("ARGMAX", a)
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TopK returns the top k values along the last axis.
|
|
|
|
|
func TopK(a *Array, k int) *Array {
|
|
|
|
|
out := New("TOPK", a)
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Sum reduces by summation along the given axis.
|
|
|
|
|
func Sum(a *Array, axis int, keepDims bool) *Array {
|
|
|
|
|
out := New("SUM", a)
|
|
|
|
|
axes := []C.int{C.int(axis)}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Mean reduces by averaging along the given axis.
|
|
|
|
|
func Mean(a *Array, axis int, keepDims bool) *Array {
|
|
|
|
|
out := New("MEAN", a)
|
|
|
|
|
axes := []C.int{C.int(axis)}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// --- Shape operations ---
|
|
|
|
|
|
|
|
|
|
// Reshape changes the shape of an array.
|
|
|
|
|
func Reshape(a *Array, shape ...int32) *Array {
|
|
|
|
|
out := New("RESHAPE", a)
|
|
|
|
|
cShape := make([]C.int, len(shape))
|
|
|
|
|
for i, s := range shape {
|
|
|
|
|
cShape[i] = C.int(s)
|
|
|
|
|
}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Transpose permutes dimensions. If no axes given, reverses all dims.
|
|
|
|
|
func Transpose(a *Array, axes ...int) *Array {
|
|
|
|
|
out := New("TRANSPOSE", a)
|
|
|
|
|
if len(axes) == 0 {
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
} else {
|
|
|
|
|
cAxes := make([]C.int, len(axes))
|
|
|
|
|
for i, ax := range axes {
|
|
|
|
|
cAxes[i] = C.int(ax)
|
|
|
|
|
}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
}
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ExpandDims inserts a new axis at the given position.
|
|
|
|
|
func ExpandDims(a *Array, axis int) *Array {
|
|
|
|
|
out := New("EXPAND_DIMS", a)
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Squeeze removes dimensions of size 1.
|
|
|
|
|
func Squeeze(a *Array, axes ...int) *Array {
|
|
|
|
|
out := New("SQUEEZE", a)
|
|
|
|
|
cAxes := make([]C.int, len(axes))
|
|
|
|
|
for i, ax := range axes {
|
|
|
|
|
cAxes[i] = C.int(ax)
|
|
|
|
|
}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Concatenate joins arrays along the given axis.
|
|
|
|
|
func Concatenate(arrays []*Array, axis int) *Array {
|
|
|
|
|
vector := C.mlx_vector_array_new()
|
|
|
|
|
defer C.mlx_vector_array_free(vector)
|
|
|
|
|
|
|
|
|
|
inputs := make([]*Array, len(arrays))
|
|
|
|
|
for i, a := range arrays {
|
|
|
|
|
C.mlx_vector_array_append_value(vector, a.ctx)
|
|
|
|
|
inputs[i] = a
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out := New("CONCAT", inputs...)
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BroadcastTo broadcasts an array to the given shape.
|
|
|
|
|
func BroadcastTo(a *Array, shape []int32) *Array {
|
|
|
|
|
out := New("BROADCAST", a)
|
|
|
|
|
cShape := make([]C.int, len(shape))
|
|
|
|
|
for i, s := range shape {
|
|
|
|
|
cShape[i] = C.int(s)
|
|
|
|
|
}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// AsType casts an array to a different dtype.
|
|
|
|
|
func AsType(a *Array, dtype DType) *Array {
|
|
|
|
|
out := New("ASTYPE", a)
|
|
|
|
|
C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// AsStrided creates a view with custom strides.
|
|
|
|
|
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
|
|
|
|
|
out := New("AS_STRIDED", a)
|
|
|
|
|
cShape := make([]C.int, len(shape))
|
|
|
|
|
for i, s := range shape {
|
|
|
|
|
cShape[i] = C.int(s)
|
|
|
|
|
}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
cStrides := make([]C.int64_t, len(strides))
|
2026-02-16 01:19:04 +00:00
|
|
|
for i, s := range strides {
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
cStrides[i] = C.int64_t(s)
|
2026-02-16 01:19:04 +00:00
|
|
|
}
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Take gathers elements from a along axis using indices.
|
|
|
|
|
func Take(a, indices *Array, axis int) *Array {
|
|
|
|
|
out := New("TAKE", a, indices)
|
|
|
|
|
C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Where selects elements from a or b based on condition.
|
|
|
|
|
func Where(condition, a, b *Array) *Array {
|
|
|
|
|
out := New("WHERE", condition, a, b)
|
|
|
|
|
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Argpartition partially sorts and returns indices for top-k selection.
|
|
|
|
|
func Argpartition(a *Array, kth, axis int) *Array {
|
|
|
|
|
out := New("ARGPARTITION", a)
|
fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze,
concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:52:29 +00:00
|
|
|
C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
2026-02-16 01:19:04 +00:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
2026-02-16 02:12:31 +00:00
|
|
|
// Dequantize restores a quantized array to full precision.
|
|
|
|
|
func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
|
|
|
|
|
out := New("DEQUANTIZE", w, scales, biases)
|
|
|
|
|
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
|
|
|
|
|
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
|
|
|
|
|
mode := C.CString("default")
|
|
|
|
|
defer C.free(unsafe.Pointer(mode))
|
|
|
|
|
noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
|
|
|
|
|
C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
2026-02-16 01:19:04 +00:00
|
|
|
// PutAlongAxis places values into array at indices along axis.
|
|
|
|
|
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
|
|
|
|
|
out := New("PUT_ALONG_AXIS", a, indices, values)
|
|
|
|
|
// Use scatter approach: src[indices] = values
|
|
|
|
|
C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
|
|
|
|
return out
|
|
|
|
|
}
|