feat/updates #1

Merged
Snider merged 51 commits from feat/updates into dev 2026-02-16 05:54:07 +00:00
6 changed files with 54 additions and 49 deletions
Showing only changes of commit 5e2d941b4d - Show all commits

View file

@ -135,7 +135,7 @@ func Zeros(shape []int32, dtype DType) *Array {
cShape[i] = C.int(s)
}
tt := New("ZEROS")
C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
return tt
}

View file

@ -42,7 +42,7 @@ func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payl
goInputs := make([]*Array, nInputs)
for i := 0; i < nInputs; i++ {
a := New("INPUT")
C.mlx_vector_array_get(&a.ctx, inputs, C.int(i))
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
goInputs[i] = a
}

View file

@ -26,8 +26,9 @@ func LayerNorm(x, weight, bias *Array, eps float32) *Array {
// RoPE applies Rotary Position Embeddings using a fused Metal kernel.
func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
freqs := New("")
out := New("FAST_ROPE", x, freqs)
out := New("FAST_ROPE", x)
freqs := C.mlx_array_new()
defer C.mlx_array_free(freqs)
C.mlx_fast_rope(
&out.ctx,
x.ctx,
@ -39,43 +40,40 @@ func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, off
},
C.float(scale),
C.int(offset),
freqs.ctx,
freqs,
DefaultStream().ctx,
)
return out
}
// ScaledDotProductAttention computes attention using a fused Metal kernel.
// It takes no explicit mask: causal=true selects the "causal" mask mode, otherwise the mode is "none".
// Use ScaledDotProductAttentionWithMask to supply a mask array.
func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
var mask, sinks *Array
mode := "none"
if causal {
mask = New("")
sinks = New("")
} else {
mask = New("")
sinks = New("")
}
mode := "causal"
if !causal {
mode = "none"
mode = "causal"
}
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
maskArr := C.mlx_array_new()
defer C.mlx_array_free(maskArr)
sinksArr := C.mlx_array_new()
defer C.mlx_array_free(sinksArr)
out := New("FAST_SDPA", query, key, value)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
return out
}
// ScaledDotProductAttentionWithMask computes attention with an explicit mask.
func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
sinks := New("")
cMode := C.CString("none")
defer C.free(unsafe.Pointer(cMode))
out := New("FAST_SDPA", query, key, value, mask, sinks)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
sinksArr := C.mlx_array_new()
defer C.mlx_array_free(sinksArr)
out := New("FAST_SDPA", query, key, value, mask)
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
return out
}

View file

@ -8,6 +8,8 @@ package mlx
*/
import "C"
import "unsafe"
// --- Element-wise arithmetic ---
// Add returns element-wise a + b.
@ -134,9 +136,13 @@ func Matmul(a, b *Array) *Array {
// QuantizedMatmul performs quantized matrix multiplication.
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
out := New("QMATMUL", x, w, scales, biases)
gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
mode := C.CString("default")
defer C.free(unsafe.Pointer(mode))
C.mlx_quantized_matmul(
&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
C._Bool(transpose), C.int(groupSize), C.int(bits),
C._Bool(transpose), gs, b, mode,
DefaultStream().ctx,
)
return out
@ -148,21 +154,21 @@ func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bit
func Softmax(a *Array) *Array {
out := New("SOFTMAX", a)
axis := []C.int{C.int(-1)}
C.mlx_softmax(&out.ctx, a.ctx, &axis[0], C.int(1), C._Bool(false), DefaultStream().ctx)
C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
return out
}
// Argmax returns the index of the maximum value along an axis.
func Argmax(a *Array, axis int, keepDims bool) *Array {
out := New("ARGMAX", a)
C.mlx_argmax(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}
// TopK returns the top k values along the last axis.
func TopK(a *Array, k int) *Array {
out := New("TOPK", a)
C.mlx_topk(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
return out
}
@ -170,7 +176,7 @@ func TopK(a *Array, k int) *Array {
func Sum(a *Array, axis int, keepDims bool) *Array {
out := New("SUM", a)
axes := []C.int{C.int(axis)}
C.mlx_sum(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
return out
}
@ -178,7 +184,7 @@ func Sum(a *Array, axis int, keepDims bool) *Array {
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN", a)
axes := []C.int{C.int(axis)}
C.mlx_mean(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
return out
}
@ -191,7 +197,7 @@ func Reshape(a *Array, shape ...int32) *Array {
for i, s := range shape {
cShape[i] = C.int(s)
}
C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
return out
}
@ -199,13 +205,13 @@ func Reshape(a *Array, shape ...int32) *Array {
func Transpose(a *Array, axes ...int) *Array {
out := New("TRANSPOSE", a)
if len(axes) == 0 {
C.mlx_transpose_all(&out.ctx, a.ctx, DefaultStream().ctx)
C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
} else {
cAxes := make([]C.int, len(axes))
for i, ax := range axes {
cAxes[i] = C.int(ax)
}
C.mlx_transpose(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
}
return out
}
@ -213,8 +219,7 @@ func Transpose(a *Array, axes ...int) *Array {
// ExpandDims inserts a new axis at the given position.
func ExpandDims(a *Array, axis int) *Array {
out := New("EXPAND_DIMS", a)
axes := []C.int{C.int(axis)}
C.mlx_expand_dims(&out.ctx, a.ctx, &axes[0], C.int(1), DefaultStream().ctx)
C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
return out
}
@ -225,7 +230,7 @@ func Squeeze(a *Array, axes ...int) *Array {
for i, ax := range axes {
cAxes[i] = C.int(ax)
}
C.mlx_squeeze(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
@ -241,7 +246,7 @@ func Concatenate(arrays []*Array, axis int) *Array {
}
out := New("CONCAT", inputs...)
C.mlx_concatenate(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
return out
}
@ -252,7 +257,7 @@ func BroadcastTo(a *Array, shape []int32) *Array {
for i, s := range shape {
cShape[i] = C.int(s)
}
C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
return out
}
@ -270,11 +275,11 @@ func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.size_t, len(strides))
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.size_t(s)
cStrides[i] = C.int64_t(s)
}
C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), &cStrides[0], C.int(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
return out
}
@ -295,7 +300,7 @@ func Where(condition, a, b *Array) *Array {
// Argpartition partially sorts and returns indices for top-k selection.
func Argpartition(a *Array, kth, axis int) *Array {
out := New("ARGPARTITION", a)
C.mlx_argpartition(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}

View file

@ -11,13 +11,13 @@ import "C"
// Returns indices sampled according to the log-probability distribution along the last axis.
func RandomCategorical(logprobs *Array) *Array {
out := New("RANDOM_CATEGORICAL", logprobs)
// shape for output: same as input but last dim removed
C.mlx_random_categorical_shape(
key := C.mlx_array_new()
defer C.mlx_array_free(key)
C.mlx_random_categorical(
&out.ctx,
logprobs.ctx,
C.int(-1), // axis
nil, C.int(0), // empty shape = infer from input
nil, // key (use default)
C.int(-1), // axis
key, // null key = use default RNG
DefaultStream().ctx,
)
return out
@ -32,12 +32,14 @@ func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
}
lo := FromValue(low)
hi := FromValue(high)
key := C.mlx_array_new()
defer C.mlx_array_free(key)
C.mlx_random_uniform(
&out.ctx,
lo.ctx, hi.ctx,
&cShape[0], C.int(len(cShape)),
&cShape[0], C.size_t(len(cShape)),
C.mlx_dtype(dtype),
nil, // key
key,
DefaultStream().ctx,
)
return out

View file

@ -21,7 +21,7 @@ func Slice(a *Array, starts, ends []int32) *Array {
for i := range strides {
strides[i] = 1
}
C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
return out
}
@ -58,6 +58,6 @@ func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
for i := range strides {
strides[i] = 1
}
C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
return out
}