fix: correct 20 mlx-c API mismatches for v0.4.1
- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze, concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c25e1a633c
commit
a0435a84ea
6 changed files with 54 additions and 49 deletions
|
|
@ -135,7 +135,7 @@ func Zeros(shape []int32, dtype DType) *Array {
|
|||
cShape[i] = C.int(s)
|
||||
}
|
||||
tt := New("ZEROS")
|
||||
C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payl
|
|||
goInputs := make([]*Array, nInputs)
|
||||
for i := 0; i < nInputs; i++ {
|
||||
a := New("INPUT")
|
||||
C.mlx_vector_array_get(&a.ctx, inputs, C.int(i))
|
||||
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
|
||||
goInputs[i] = a
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -26,8 +26,9 @@ func LayerNorm(x, weight, bias *Array, eps float32) *Array {
|
|||
|
||||
// RoPE applies Rotary Position Embeddings using a fused Metal kernel.
|
||||
func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE", x, freqs)
|
||||
out := New("FAST_ROPE", x)
|
||||
freqs := C.mlx_array_new()
|
||||
defer C.mlx_array_free(freqs)
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
|
|
@ -39,43 +40,40 @@ func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, off
|
|||
},
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
freqs,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention computes attention using a fused Metal kernel.
|
||||
// mask can be nil for causal masking, or set causal=true for auto causal mask.
|
||||
func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
|
||||
var mask, sinks *Array
|
||||
mode := "none"
|
||||
if causal {
|
||||
mask = New("")
|
||||
sinks = New("")
|
||||
} else {
|
||||
mask = New("")
|
||||
sinks = New("")
|
||||
}
|
||||
|
||||
mode := "causal"
|
||||
if !causal {
|
||||
mode = "none"
|
||||
mode = "causal"
|
||||
}
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA", query, key, value, mask, sinks)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
maskArr := C.mlx_array_new()
|
||||
defer C.mlx_array_free(maskArr)
|
||||
sinksArr := C.mlx_array_new()
|
||||
defer C.mlx_array_free(sinksArr)
|
||||
|
||||
out := New("FAST_SDPA", query, key, value)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttentionWithMask computes attention with an explicit mask.
|
||||
func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
|
||||
sinks := New("")
|
||||
cMode := C.CString("none")
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA", query, key, value, mask, sinks)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
sinksArr := C.mlx_array_new()
|
||||
defer C.mlx_array_free(sinksArr)
|
||||
|
||||
out := New("FAST_SDPA", query, key, value, mask)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ package mlx
|
|||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// --- Element-wise arithmetic ---
|
||||
|
||||
// Add returns element-wise a + b.
|
||||
|
|
@ -134,9 +136,13 @@ func Matmul(a, b *Array) *Array {
|
|||
// QuantizedMatmul performs quantized matrix multiplication.
|
||||
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
|
||||
out := New("QMATMUL", x, w, scales, biases)
|
||||
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
|
||||
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
|
||||
mode := C.CString("default")
|
||||
defer C.free(unsafe.Pointer(mode))
|
||||
C.mlx_quantized_matmul(
|
||||
&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
|
||||
C._Bool(transpose), C.int(groupSize), C.int(bits),
|
||||
C._Bool(transpose), gs, b, mode,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
|
|
@ -148,21 +154,21 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
|
|||
func Softmax(a *Array) *Array {
|
||||
out := New("SOFTMAX", a)
|
||||
axis := []C.int{C.int(-1)}
|
||||
C.mlx_softmax(&out.ctx, a.ctx, &axis[0], C.int(1), C._Bool(false), DefaultStream().ctx)
|
||||
C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Argmax returns the index of the maximum value along an axis.
|
||||
func Argmax(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("ARGMAX", a)
|
||||
C.mlx_argmax(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
|
||||
C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// TopK returns the top k values along the last axis.
|
||||
func TopK(a *Array, k int) *Array {
|
||||
out := New("TOPK", a)
|
||||
C.mlx_topk(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
|
||||
C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -170,7 +176,7 @@ func TopK(a *Array, k int) *Array {
|
|||
func Sum(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("SUM", a)
|
||||
axes := []C.int{C.int(axis)}
|
||||
C.mlx_sum(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
|
||||
C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -178,7 +184,7 @@ func Sum(a *Array, axis int, keepDims bool) *Array {
|
|||
func Mean(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("MEAN", a)
|
||||
axes := []C.int{C.int(axis)}
|
||||
C.mlx_mean(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
|
||||
C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -191,7 +197,7 @@ func Reshape(a *Array, shape ...int32) *Array {
|
|||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
|
||||
C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -199,13 +205,13 @@ func Reshape(a *Array, shape ...int32) *Array {
|
|||
func Transpose(a *Array, axes ...int) *Array {
|
||||
out := New("TRANSPOSE", a)
|
||||
if len(axes) == 0 {
|
||||
C.mlx_transpose_all(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
} else {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i, ax := range axes {
|
||||
cAxes[i] = C.int(ax)
|
||||
}
|
||||
C.mlx_transpose(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
|
||||
C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
|
@ -213,8 +219,7 @@ func Transpose(a *Array, axes ...int) *Array {
|
|||
// ExpandDims inserts a new axis at the given position.
|
||||
func ExpandDims(a *Array, axis int) *Array {
|
||||
out := New("EXPAND_DIMS", a)
|
||||
axes := []C.int{C.int(axis)}
|
||||
C.mlx_expand_dims(&out.ctx, a.ctx, &axes[0], C.int(1), DefaultStream().ctx)
|
||||
C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -225,7 +230,7 @@ func Squeeze(a *Array, axes ...int) *Array {
|
|||
for i, ax := range axes {
|
||||
cAxes[i] = C.int(ax)
|
||||
}
|
||||
C.mlx_squeeze(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
|
||||
C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -241,7 +246,7 @@ func Concatenate(arrays []*Array, axis int) *Array {
|
|||
}
|
||||
|
||||
out := New("CONCAT", inputs...)
|
||||
C.mlx_concatenate(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -252,7 +257,7 @@ func BroadcastTo(a *Array, shape []int32) *Array {
|
|||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
|
||||
C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -270,11 +275,11 @@ func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
|
|||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
cStrides := make([]C.size_t, len(strides))
|
||||
cStrides := make([]C.int64_t, len(strides))
|
||||
for i, s := range strides {
|
||||
cStrides[i] = C.size_t(s)
|
||||
cStrides[i] = C.int64_t(s)
|
||||
}
|
||||
C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), &cStrides[0], C.int(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
|
||||
C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -295,7 +300,7 @@ func Where(condition, a, b *Array) *Array {
|
|||
// Argpartition partially sorts and returns indices for top-k selection.
|
||||
func Argpartition(a *Array, kth, axis int) *Array {
|
||||
out := New("ARGPARTITION", a)
|
||||
C.mlx_argpartition(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
||||
C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ import "C"
|
|||
// Returns indices sampled according to the log-probability distribution along the last axis.
|
||||
func RandomCategorical(logprobs *Array) *Array {
|
||||
out := New("RANDOM_CATEGORICAL", logprobs)
|
||||
// shape for output: same as input but last dim removed
|
||||
C.mlx_random_categorical_shape(
|
||||
key := C.mlx_array_new()
|
||||
defer C.mlx_array_free(key)
|
||||
C.mlx_random_categorical(
|
||||
&out.ctx,
|
||||
logprobs.ctx,
|
||||
C.int(-1), // axis
|
||||
nil, C.int(0), // empty shape = infer from input
|
||||
nil, // key (use default)
|
||||
C.int(-1), // axis
|
||||
key, // null key = use default RNG
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
|
|
@ -32,12 +32,14 @@ func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
|
|||
}
|
||||
lo := FromValue(low)
|
||||
hi := FromValue(high)
|
||||
key := C.mlx_array_new()
|
||||
defer C.mlx_array_free(key)
|
||||
C.mlx_random_uniform(
|
||||
&out.ctx,
|
||||
lo.ctx, hi.ctx,
|
||||
&cShape[0], C.int(len(cShape)),
|
||||
&cShape[0], C.size_t(len(cShape)),
|
||||
C.mlx_dtype(dtype),
|
||||
nil, // key
|
||||
key,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ func Slice(a *Array, starts, ends []int32) *Array {
|
|||
for i := range strides {
|
||||
strides[i] = 1
|
||||
}
|
||||
C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
|
||||
C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
|
|
@ -58,6 +58,6 @@ func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
|
|||
for i := range strides {
|
||||
strides[i] = 1
|
||||
}
|
||||
C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
|
||||
C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue