From a0435a84ea9f28914ec1909295640be66cc9af01 Mon Sep 17 00:00:00 2001
From: Claude
Date: Mon, 16 Feb 2026 01:52:29 +0000
Subject: [PATCH] fix: correct 20 mlx-c API mismatches for v0.4.1

- Use _axis/_axes variants for softmax, argmax, topk, sum, mean, squeeze, concatenate, argpartition
- Fix size_t vs int for count parameters throughout
- Fix int64_t strides in as_strided
- Add mlx_optional_int + mode param to quantized_matmul
- Use mlx_array_new() for null arrays (freqs, key, mask, sinks)
- Fix expand_dims to single-axis signature
- Fix compile callback signature (size_t index)

Co-Authored-By: Claude Opus 4.6
---
 pkg/mlx/array.go   |  2 +-
 pkg/mlx/compile.go |  2 +-
 pkg/mlx/fast.go    | 38 ++++++++++++++++++--------------------
 pkg/mlx/ops.go     | 41 +++++++++++++++++++++++------------------
 pkg/mlx/random.go  | 16 +++++++++-------
 pkg/mlx/slice.go   |  4 ++--
 6 files changed, 54 insertions(+), 49 deletions(-)

diff --git a/pkg/mlx/array.go b/pkg/mlx/array.go
index 7b990eb0..091dab82 100644
--- a/pkg/mlx/array.go
+++ b/pkg/mlx/array.go
@@ -135,7 +135,7 @@ func Zeros(shape []int32, dtype DType) *Array {
 		cShape[i] = C.int(s)
 	}
 	tt := New("ZEROS")
-	C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
+	C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
 	return tt
 }
 
diff --git a/pkg/mlx/compile.go b/pkg/mlx/compile.go
index f04d1dda..7727344a 100644
--- a/pkg/mlx/compile.go
+++ b/pkg/mlx/compile.go
@@ -42,7 +42,7 @@ func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payl
 	goInputs := make([]*Array, nInputs)
 	for i := 0; i < nInputs; i++ {
 		a := New("INPUT")
-		C.mlx_vector_array_get(&a.ctx, inputs, C.int(i))
+		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
 		goInputs[i] = a
 	}
 
diff --git a/pkg/mlx/fast.go b/pkg/mlx/fast.go
index f04c931f..58e9e5e4 100644
--- a/pkg/mlx/fast.go
+++ b/pkg/mlx/fast.go
@@ -26,8 +26,9 @@ func LayerNorm(x, weight, bias *Array, eps float32) *Array {
 
 // RoPE applies Rotary Position Embeddings using a fused Metal kernel.
 func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
-	freqs := New("")
-	out := New("FAST_ROPE", x, freqs)
+	out := New("FAST_ROPE", x)
+	freqs := C.mlx_array_new()
+	defer C.mlx_array_free(freqs)
 	C.mlx_fast_rope(
 		&out.ctx,
 		x.ctx,
@@ -39,43 +40,40 @@ func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, off
 		},
 		C.float(scale),
 		C.int(offset),
-		freqs.ctx,
+		freqs,
 		DefaultStream().ctx,
 	)
 	return out
 }
 
 // ScaledDotProductAttention computes attention using a fused Metal kernel.
-// mask can be nil for causal masking, or set causal=true for auto causal mask.
 func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
-	var mask, sinks *Array
+	mode := "none"
 	if causal {
-		mask = New("")
-		sinks = New("")
-	} else {
-		mask = New("")
-		sinks = New("")
-	}
-
-	mode := "causal"
-	if !causal {
-		mode = "none"
+		mode = "causal"
 	}
 	cMode := C.CString(mode)
 	defer C.free(unsafe.Pointer(cMode))
 
-	out := New("FAST_SDPA", query, key, value, mask, sinks)
-	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
+	maskArr := C.mlx_array_new()
+	defer C.mlx_array_free(maskArr)
+	sinksArr := C.mlx_array_new()
+	defer C.mlx_array_free(sinksArr)
+
+	out := New("FAST_SDPA", query, key, value)
+	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
 	return out
 }
 
 // ScaledDotProductAttentionWithMask computes attention with an explicit mask.
 func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
-	sinks := New("")
 	cMode := C.CString("none")
 	defer C.free(unsafe.Pointer(cMode))
 
-	out := New("FAST_SDPA", query, key, value, mask, sinks)
-	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
+	sinksArr := C.mlx_array_new()
+	defer C.mlx_array_free(sinksArr)
+
+	out := New("FAST_SDPA", query, key, value, mask)
+	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
 	return out
 }
diff --git a/pkg/mlx/ops.go b/pkg/mlx/ops.go
index 3e3bada3..c9ba959e 100644
--- a/pkg/mlx/ops.go
+++ b/pkg/mlx/ops.go
@@ -8,6 +8,8 @@ package mlx
 */
 import "C"
 
+import "unsafe"
+
 // --- Element-wise arithmetic ---
 
 // Add returns element-wise a + b.
@@ -134,9 +136,13 @@ func Matmul(a, b *Array) *Array {
 // QuantizedMatmul performs quantized matrix multiplication.
 func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
 	out := New("QMATMUL", x, w, scales, biases)
+	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
+	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
+	mode := C.CString("default")
+	defer C.free(unsafe.Pointer(mode))
 	C.mlx_quantized_matmul(
 		&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
-		C._Bool(transpose), C.int(groupSize), C.int(bits),
+		C._Bool(transpose), gs, b, mode,
 		DefaultStream().ctx,
 	)
 	return out
@@ -148,21 +154,21 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
 func Softmax(a *Array) *Array {
 	out := New("SOFTMAX", a)
 	axis := []C.int{C.int(-1)}
-	C.mlx_softmax(&out.ctx, a.ctx, &axis[0], C.int(1), C._Bool(false), DefaultStream().ctx)
+	C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
 	return out
 }
 
 // Argmax returns the index of the maximum value along an axis.
 func Argmax(a *Array, axis int, keepDims bool) *Array {
 	out := New("ARGMAX", a)
-	C.mlx_argmax(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
+	C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
 // TopK returns the top k values along the last axis.
 func TopK(a *Array, k int) *Array {
 	out := New("TOPK", a)
-	C.mlx_topk(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
+	C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
 	return out
 }
 
@@ -170,7 +176,7 @@ func TopK(a *Array, k int) *Array {
 func Sum(a *Array, axis int, keepDims bool) *Array {
 	out := New("SUM", a)
 	axes := []C.int{C.int(axis)}
-	C.mlx_sum(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
+	C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
@@ -178,7 +184,7 @@ func Sum(a *Array, axis int, keepDims bool) *Array {
 func Mean(a *Array, axis int, keepDims bool) *Array {
 	out := New("MEAN", a)
 	axes := []C.int{C.int(axis)}
-	C.mlx_mean(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
+	C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
 	return out
 }
 
@@ -191,7 +197,7 @@ func Reshape(a *Array, shape ...int32) *Array {
 	for i, s := range shape {
 		cShape[i] = C.int(s)
 	}
-	C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
+	C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
 	return out
 }
 
@@ -199,13 +205,13 @@ func Reshape(a *Array, shape ...int32) *Array {
 func Transpose(a *Array, axes ...int) *Array {
 	out := New("TRANSPOSE", a)
 	if len(axes) == 0 {
-		C.mlx_transpose_all(&out.ctx, a.ctx, DefaultStream().ctx)
+		C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
 	} else {
 		cAxes := make([]C.int, len(axes))
 		for i, ax := range axes {
 			cAxes[i] = C.int(ax)
 		}
-		C.mlx_transpose(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
+		C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
 	}
 	return out
 }
@@ -213,8 +219,7 @@ func Transpose(a *Array, axes ...int) *Array {
 // ExpandDims inserts a new axis at the given position.
 func ExpandDims(a *Array, axis int) *Array {
 	out := New("EXPAND_DIMS", a)
-	axes := []C.int{C.int(axis)}
-	C.mlx_expand_dims(&out.ctx, a.ctx, &axes[0], C.int(1), DefaultStream().ctx)
+	C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
 	return out
 }
 
@@ -225,7 +230,7 @@ func ExpandDims(a *Array, axis int) *Array {
 	for i, ax := range axes {
 		cAxes[i] = C.int(ax)
 	}
-	C.mlx_squeeze(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
+	C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
 	return out
 }
 
@@ -241,7 +246,7 @@ func Concatenate(arrays []*Array, axis int) *Array {
 	}
 
 	out := New("CONCAT", inputs...)
-	C.mlx_concatenate(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
+	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
 	return out
 }
 
@@ -252,7 +257,7 @@ func BroadcastTo(a *Array, shape []int32) *Array {
 	for i, s := range shape {
 		cShape[i] = C.int(s)
 	}
-	C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
+	C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
 	return out
 }
 
@@ -270,11 +275,11 @@ func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
 	for i, s := range shape {
 		cShape[i] = C.int(s)
 	}
-	cStrides := make([]C.size_t, len(strides))
+	cStrides := make([]C.int64_t, len(strides))
 	for i, s := range strides {
-		cStrides[i] = C.size_t(s)
+		cStrides[i] = C.int64_t(s)
 	}
-	C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), &cStrides[0], C.int(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
+	C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
 	return out
 }
 
@@ -295,7 +300,7 @@ func Where(condition, a, b *Array) *Array {
 // Argpartition partially sorts and returns indices for top-k selection.
 func Argpartition(a *Array, kth, axis int) *Array {
 	out := New("ARGPARTITION", a)
-	C.mlx_argpartition(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
+	C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
 	return out
 }
 
diff --git a/pkg/mlx/random.go b/pkg/mlx/random.go
index e9b48fd4..bfadada5 100644
--- a/pkg/mlx/random.go
+++ b/pkg/mlx/random.go
@@ -11,13 +11,13 @@ import "C"
 // Returns indices sampled according to the log-probability distribution along the last axis.
 func RandomCategorical(logprobs *Array) *Array {
 	out := New("RANDOM_CATEGORICAL", logprobs)
-	// shape for output: same as input but last dim removed
-	C.mlx_random_categorical_shape(
+	key := C.mlx_array_new()
+	defer C.mlx_array_free(key)
+	C.mlx_random_categorical(
 		&out.ctx,
 		logprobs.ctx,
-		C.int(-1),     // axis
-		nil, C.int(0), // empty shape = infer from input
-		nil,           // key (use default)
+		C.int(-1), // axis
+		key,       // null key = use default RNG
 		DefaultStream().ctx,
 	)
 	return out
@@ -32,12 +32,14 @@ func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
 	}
 	lo := FromValue(low)
 	hi := FromValue(high)
+	key := C.mlx_array_new()
+	defer C.mlx_array_free(key)
 	C.mlx_random_uniform(
 		&out.ctx,
 		lo.ctx, hi.ctx,
-		&cShape[0], C.int(len(cShape)),
+		&cShape[0], C.size_t(len(cShape)),
 		C.mlx_dtype(dtype),
-		nil, // key
+		key,
 		DefaultStream().ctx,
 	)
 	return out
diff --git a/pkg/mlx/slice.go b/pkg/mlx/slice.go
index 9c3fdd43..da5ff743 100644
--- a/pkg/mlx/slice.go
+++ b/pkg/mlx/slice.go
@@ -21,7 +21,7 @@ func Slice(a *Array, starts, ends []int32) *Array {
 	for i := range strides {
 		strides[i] = 1
 	}
-	C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
+	C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
 	return out
 }
 
@@ -58,6 +58,6 @@ func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
 	for i := range strides {
 		strides[i] = 1
 	}
-	C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
+	C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
 	return out
 }
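
Usage sketch (illustrative only, not part of the patch): the snippet below exercises a few of the rewrapped ops touched above, to show the Go-side API these signature fixes are expected to support. It relies only on the function signatures visible in this diff; the import path and the Float32 DType constant are assumed placeholders and may not match the actual repository.

package main

// Hypothetical smoke test for the wrappers touched by this patch.
// Assumptions not confirmed by the diff: the module path below and the
// existence of a Float32 DType constant. Everything else uses only the
// signatures shown in pkg/mlx above.
import (
	"example.com/project/pkg/mlx" // placeholder import path
)

func main() {
	// A [2, 4] tensor of zeros; mlx_zeros now receives a size_t dim count.
	x := mlx.Zeros([]int32{2, 4}, mlx.Float32) // mlx.Float32 is an assumed constant

	// Reshape to [4, 2] and transpose back; the no-axes Transpose path now
	// maps to mlx_transpose, the explicit-axes path to mlx_transpose_axes.
	y := mlx.Transpose(mlx.Reshape(x, 4, 2))

	// Softmax over the last axis (mlx_softmax_axes under the hood), then
	// per-row argmax and the top-2 entries via the _axis variants.
	p := mlx.Softmax(y)
	_ = mlx.Argmax(p, -1, false)
	_ = mlx.TopK(p, 2)

	// Reductions via mlx_sum_axes / mlx_mean_axes.
	_ = mlx.Sum(p, -1, true)
	_ = mlx.Mean(p, 0, false)

	// Evaluation/printing helpers are not shown in this diff, so the graph is
	// only constructed here.
}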