fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze, concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent c25e1a633c
commit a0435a84ea

6 changed files with 54 additions and 49 deletions
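Note: the recurring pattern in this commit is that axis-taking ops move to their _axis/_axes variants, count parameters are passed as C.size_t, and formerly nil optional arguments become empty mlx_array handles. A minimal sketch of the corrected call shape, mirroring the Sum hunk below (illustrative only, not part of the diff):

package mlx

// sumLastAxis is a sketch (not in this commit) of the corrected cgo call
// shape: the _axes variant takes a pointer to the axes plus a C.size_t count.
func sumLastAxis(a *Array, keepDims bool) *Array {
	out := New("SUM", a)
	axes := []C.int{C.int(-1)}
	C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(len(axes)), C._Bool(keepDims), DefaultStream().ctx)
	return out
}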
@@ -135,7 +135,7 @@ func Zeros(shape []int32, dtype DType) *Array {
 		cShape[i] = C.int(s)
 	}
 	tt := New("ZEROS")
-	C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
+	C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
 	return tt
 }
 
@@ -42,7 +42,7 @@ func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payl
 	goInputs := make([]*Array, nInputs)
 	for i := 0; i < nInputs; i++ {
 		a := New("INPUT")
-		C.mlx_vector_array_get(&a.ctx, inputs, C.int(i))
+		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
 		goInputs[i] = a
 	}
 
@@ -26,8 +26,9 @@ func LayerNorm(x, weight, bias *Array, eps float32) *Array {
 
 // RoPE applies Rotary Position Embeddings using a fused Metal kernel.
 func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
-	freqs := New("")
-	out := New("FAST_ROPE", x, freqs)
+	out := New("FAST_ROPE", x)
+	freqs := C.mlx_array_new()
+	defer C.mlx_array_free(freqs)
 	C.mlx_fast_rope(
 		&out.ctx,
 		x.ctx,
@@ -39,43 +40,40 @@ func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, off
 		},
 		C.float(scale),
 		C.int(offset),
-		freqs.ctx,
+		freqs,
 		DefaultStream().ctx,
 	)
 	return out
 }
 
 // ScaledDotProductAttention computes attention using a fused Metal kernel.
-// mask can be nil for causal masking, or set causal=true for auto causal mask.
 func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
-	var mask, sinks *Array
+	mode := "none"
 	if causal {
-		mask = New("")
-		sinks = New("")
-	} else {
-		mask = New("")
-		sinks = New("")
-	}
-
-	mode := "causal"
-	if !causal {
-		mode = "none"
+		mode = "causal"
 	}
 	cMode := C.CString(mode)
 	defer C.free(unsafe.Pointer(cMode))
 
-	out := New("FAST_SDPA", query, key, value, mask, sinks)
-	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
+	maskArr := C.mlx_array_new()
+	defer C.mlx_array_free(maskArr)
+	sinksArr := C.mlx_array_new()
+	defer C.mlx_array_free(sinksArr)
+
+	out := New("FAST_SDPA", query, key, value)
+	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
 	return out
 }
 
 // ScaledDotProductAttentionWithMask computes attention with an explicit mask.
 func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
-	sinks := New("")
 	cMode := C.CString("none")
 	defer C.free(unsafe.Pointer(cMode))
 
-	out := New("FAST_SDPA", query, key, value, mask, sinks)
-	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
+	sinksArr := C.mlx_array_new()
+	defer C.mlx_array_free(sinksArr)
+
+	out := New("FAST_SDPA", query, key, value, mask)
+	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
 	return out
 }
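A usage sketch of the two attention wrappers after this hunk (illustrative helper, not in the commit; only the functions shown above are assumed to exist):

package mlx

// attend is an illustrative helper (not in this commit): with an explicit
// additive mask it calls the WithMask variant; otherwise it relies on the
// mode string ("causal" or "none") and the empty mask/sinks handles set up above.
func attend(q, k, v *Array, scale float32, causal bool, mask *Array) *Array {
	if mask != nil {
		return ScaledDotProductAttentionWithMask(q, k, v, mask, scale)
	}
	return ScaledDotProductAttention(q, k, v, scale, causal)
}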
@@ -8,6 +8,8 @@ package mlx
 */
 import "C"
 
+import "unsafe"
+
 // --- Element-wise arithmetic ---
 
 // Add returns element-wise a + b.
@@ -134,9 +136,13 @@ func Matmul(a, b *Array) *Array {
 // QuantizedMatmul performs quantized matrix multiplication.
 func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
 	out := New("QMATMUL", x, w, scales, biases)
+	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
+	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
+	mode := C.CString("default")
+	defer C.free(unsafe.Pointer(mode))
 	C.mlx_quantized_matmul(
 		&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
-		C._Bool(transpose), C.int(groupSize), C.int(bits),
+		C._Bool(transpose), gs, b, mode,
 		DefaultStream().ctx,
 	)
 	return out
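For reference, a sketch of calling the updated wrapper (illustrative; 4-bit weights with group size 64 are common settings, not values taken from this repo):

package mlx

// quantizedProject is an illustrative wrapper (not in this commit) around
// QuantizedMatmul as fixed above: group size and bit width stay plain Go ints;
// the cgo layer wraps them in mlx_optional_int and passes the "default" mode.
func quantizedProject(x, w, scales, biases *Array) *Array {
	return QuantizedMatmul(x, w, scales, biases, true, 64, 4)
}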
@@ -148,21 +154,21 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
 func Softmax(a *Array) *Array {
 	out := New("SOFTMAX", a)
 	axis := []C.int{C.int(-1)}
-	C.mlx_softmax(&out.ctx, a.ctx, &axis[0], C.int(1), C._Bool(false), DefaultStream().ctx)
+	C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
 	return out
 }
 
 // Argmax returns the index of the maximum value along an axis.
 func Argmax(a *Array, axis int, keepDims bool) *Array {
 	out := New("ARGMAX", a)
-	C.mlx_argmax(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
+	C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
 // TopK returns the top k values along the last axis.
 func TopK(a *Array, k int) *Array {
 	out := New("TOPK", a)
-	C.mlx_topk(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
+	C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
 	return out
 }
 
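A sketch combining the reduction wrappers fixed above into a greedy decode step (illustrative, not part of the commit):

package mlx

// greedyNext is an illustrative sketch (not in this commit): softmax over the
// last axis via mlx_softmax_axes, then a greedy pick via mlx_argmax_axis with
// the reduced axis dropped.
func greedyNext(logits *Array) *Array {
	probs := Softmax(logits)
	return Argmax(probs, -1, false)
}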
@@ -170,7 +176,7 @@ func TopK(a *Array, k int) *Array {
 func Sum(a *Array, axis int, keepDims bool) *Array {
 	out := New("SUM", a)
 	axes := []C.int{C.int(axis)}
-	C.mlx_sum(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
+	C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
@@ -178,7 +184,7 @@ func Sum(a *Array, axis int, keepDims bool) *Array {
 func Mean(a *Array, axis int, keepDims bool) *Array {
 	out := New("MEAN", a)
 	axes := []C.int{C.int(axis)}
-	C.mlx_mean(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
+	C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
@@ -191,7 +197,7 @@ func Reshape(a *Array, shape ...int32) *Array {
 	for i, s := range shape {
 		cShape[i] = C.int(s)
 	}
-	C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
+	C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
 	return out
 }
 
@@ -199,13 +205,13 @@ func Reshape(a *Array, shape ...int32) *Array {
 func Transpose(a *Array, axes ...int) *Array {
 	out := New("TRANSPOSE", a)
 	if len(axes) == 0 {
-		C.mlx_transpose_all(&out.ctx, a.ctx, DefaultStream().ctx)
+		C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
 	} else {
 		cAxes := make([]C.int, len(axes))
 		for i, ax := range axes {
 			cAxes[i] = C.int(ax)
 		}
-		C.mlx_transpose(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
+		C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
 	}
 	return out
 }
@@ -213,8 +219,7 @@ func Transpose(a *Array, axes ...int) *Array {
 // ExpandDims inserts a new axis at the given position.
 func ExpandDims(a *Array, axis int) *Array {
 	out := New("EXPAND_DIMS", a)
-	axes := []C.int{C.int(axis)}
-	C.mlx_expand_dims(&out.ctx, a.ctx, &axes[0], C.int(1), DefaultStream().ctx)
+	C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }
 
@@ -225,7 +230,7 @@ func Squeeze(a *Array, axes ...int) *Array {
 	for i, ax := range axes {
 		cAxes[i] = C.int(ax)
 	}
-	C.mlx_squeeze(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
+	C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
 	return out
 }
 
@@ -241,7 +246,7 @@ func Concatenate(arrays []*Array, axis int) *Array {
 	}
 
 	out := New("CONCAT", inputs...)
-	C.mlx_concatenate(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
+	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
 	return out
 }
 
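A sketch of the shape wrappers touched above working together (illustrative, not part of the commit):

package mlx

// stackPair is an illustrative helper (not in this commit): ExpandDims now
// takes the axis directly, and Concatenate forwards to mlx_concatenate_axis.
func stackPair(a, b *Array, axis int) *Array {
	return Concatenate([]*Array{ExpandDims(a, axis), ExpandDims(b, axis)}, axis)
}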
@@ -252,7 +257,7 @@ func BroadcastTo(a *Array, shape []int32) *Array {
 	for i, s := range shape {
 		cShape[i] = C.int(s)
 	}
-	C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
+	C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
 	return out
 }
 
@@ -270,11 +275,11 @@ func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
 	for i, s := range shape {
 		cShape[i] = C.int(s)
 	}
-	cStrides := make([]C.size_t, len(strides))
+	cStrides := make([]C.int64_t, len(strides))
 	for i, s := range strides {
-		cStrides[i] = C.size_t(s)
+		cStrides[i] = C.int64_t(s)
 	}
-	C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), &cStrides[0], C.int(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
+	C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
 	return out
 }
 
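An illustrative use of AsStrided after the int64 stride fix (a sketch, not in the commit; it assumes a contiguous 1-D input, which is not checked here):

package mlx

// slidingWindows is a sketch (not in this commit): strides are element counts
// passed through as int64_t. For a contiguous 1-D array of length n it views
// overlapping windows of width w, one per row.
func slidingWindows(a *Array, n, w int32) *Array {
	return AsStrided(a, []int32{n - w + 1, w}, []int64{1, 1}, 0)
}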
@@ -295,7 +300,7 @@ func Where(condition, a, b *Array) *Array {
 // Argpartition partially sorts and returns indices for top-k selection.
 func Argpartition(a *Array, kth, axis int) *Array {
 	out := New("ARGPARTITION", a)
-	C.mlx_argpartition(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
+	C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
 	return out
 }
 
@@ -11,13 +11,13 @@ import "C"
 // Returns indices sampled according to the log-probability distribution along the last axis.
 func RandomCategorical(logprobs *Array) *Array {
 	out := New("RANDOM_CATEGORICAL", logprobs)
-	// shape for output: same as input but last dim removed
-	C.mlx_random_categorical_shape(
+	key := C.mlx_array_new()
+	defer C.mlx_array_free(key)
+	C.mlx_random_categorical(
 		&out.ctx,
 		logprobs.ctx,
 		C.int(-1), // axis
-		nil, C.int(0), // empty shape = infer from input
-		nil, // key (use default)
+		key, // null key = use default RNG
 		DefaultStream().ctx,
 	)
 	return out
@@ -32,12 +32,14 @@ func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
 	}
 	lo := FromValue(low)
 	hi := FromValue(high)
+	key := C.mlx_array_new()
+	defer C.mlx_array_free(key)
 	C.mlx_random_uniform(
 		&out.ctx,
 		lo.ctx, hi.ctx,
-		&cShape[0], C.int(len(cShape)),
+		&cShape[0], C.size_t(len(cShape)),
 		C.mlx_dtype(dtype),
-		nil, // key
+		key,
 		DefaultStream().ctx,
 	)
 	return out
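A sketch exercising both random wrappers after the fixes above, using the default RNG that the empty mlx_array key now selects (illustrative, not in the commit):

package mlx

// sampleOne is an illustrative helper (not in this commit): draw stand-in
// logits with RandomUniform, then sample indices along the last axis with
// RandomCategorical; the dtype is left to the caller.
func sampleOne(dtype DType) *Array {
	logits := RandomUniform(-1, 1, []int32{1, 8}, dtype)
	return RandomCategorical(logits)
}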
@@ -21,7 +21,7 @@ func Slice(a *Array, starts, ends []int32) *Array {
 	for i := range strides {
 		strides[i] = 1
 	}
-	C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
+	C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
 	return out
 }
 
@@ -58,6 +58,6 @@ func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
 	for i := range strides {
 		strides[i] = 1
 	}
-	C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
+	C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
 	return out
 }
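A sketch of the slice wrappers after the size_t count fix (illustrative, not part of the commit; assumes a 2-D (rows, cols) input):

package mlx

// firstRow is an illustrative helper (not in this commit): take row 0 of a
// (rows, cols) array with Slice, then write it back via the in-place update.
func firstRow(a *Array, cols int32) *Array {
	row := Slice(a, []int32{0, 0}, []int32{1, cols})
	return SliceUpdateInplace(a, row, []int32{0, 0}, []int32{1, cols})
}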