diff --git a/pkg/mlx/compile.go b/pkg/mlx/compile.go index 4794270..f04d1dd 100644 --- a/pkg/mlx/compile.go +++ b/pkg/mlx/compile.go @@ -6,10 +6,10 @@ package mlx #include "mlx/c/mlx.h" // Callback for compiled functions. -extern void goCompiledFunc(mlx_vector_array inputs, mlx_vector_array outputs, void *payload); +extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload); static mlx_closure new_closure(void *payload) { - return mlx_closure_new_func_payload(&goCompiledFunc, payload); + return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL); } */ import "C" @@ -29,11 +29,11 @@ type CompiledFunc struct { var compiledFuncs sync.Map //export goCompiledFunc -func goCompiledFunc(inputs C.mlx_vector_array, outputs C.mlx_vector_array, payload unsafe.Pointer) { +func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int { id := uintptr(payload) fnI, ok := compiledFuncs.Load(id) if !ok { - return + return 1 } fn := fnI.(func([]*Array) []*Array) @@ -51,8 +51,9 @@ func goCompiledFunc(inputs C.mlx_vector_array, outputs C.mlx_vector_array, paylo // Set outputs for _, out := range goOutputs { - C.mlx_vector_array_append_value(outputs, out.ctx) + C.mlx_vector_array_append_value(*outputs, out.ctx) } + return 0 } var nextID uintptr