go/pkg/mlx/compile.go
Claude 9d664c055a
feat: add native MLX backend for Apple Silicon inference (pkg/mlx)
CGo wrapper for mlx-c providing zero-Python Metal GPU inference.
Includes Gemma 3 model architecture, BPE tokenizer, KV cache,
composable sampling, and OpenAI-compatible serve command.

Build-tagged (darwin && arm64 && mlx), with stubs for cross-platform builds.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 01:19:04 +00:00

85 lines
2 KiB
Go

//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
// Callback for compiled functions.
extern void goCompiledFunc(mlx_vector_array inputs, mlx_vector_array outputs, void *payload);
static mlx_closure new_closure(void *payload) {
return mlx_closure_new_func_payload(&goCompiledFunc, payload);
}
*/
import "C"
import (
"sync"
"unsafe"
)
// CompiledFunc pairs a Go function over arrays with the MLX closure that
// CompileShapeless creates for it. Values are obtained from CompileShapeless
// and used through Call.
type CompiledFunc struct {
// fn is the user-supplied function; Call invokes it directly.
fn func([]*Array) []*Array
// closure is the C-side handle returned by new_closure at construction time.
closure C.mlx_closure
// mu is held for the duration of Call, serializing calls on one CompiledFunc.
mu sync.Mutex
}
// compiledFuncs is the registry consulted by goCompiledFunc. It maps the
// uintptr ID that is handed to C as the closure payload to the Go function
// registered under that ID. Entries are added by CompileShapeless; no removal
// is visible in this file.
var compiledFuncs sync.Map
// goCompiledFunc is the Go side of the C callback that new_closure installs.
//
// payload carries the registry key under which CompileShapeless stored the
// user function; it is an integer travelling in a pointer-sized slot and is
// never dereferenced. The callback looks the function up, wraps every element
// of inputs in an *Array, invokes the function, and appends each returned
// array to outputs. When the key is not registered the callback returns
// without touching outputs.
//
//export goCompiledFunc
func goCompiledFunc(inputs C.mlx_vector_array, outputs C.mlx_vector_array, payload unsafe.Pointer) {
	entry, found := compiledFuncs.Load(uintptr(payload))
	if !found {
		return
	}
	userFn := entry.(func([]*Array) []*Array)

	// Wrap each element of the C vector in a Go-side Array handle.
	// NOTE(review): the index is passed as C.int; confirm this matches the
	// parameter type declared for mlx_vector_array_get in the mlx-c header.
	args := make([]*Array, int(C.mlx_vector_array_size(inputs)))
	for idx := range args {
		arg := New("INPUT")
		C.mlx_vector_array_get(&arg.ctx, inputs, C.int(idx))
		args[idx] = arg
	}

	// Run the user function and hand its results back to the C side in order.
	for _, result := range userFn(args) {
		C.mlx_vector_array_append_value(outputs, result.ctx)
	}
}
var (
	// nextID is the most recently issued registry key; it is incremented
	// before use, so the first key handed out is 1.
	nextID uintptr
	// nextIDMu guards nextID.
	nextIDMu sync.Mutex
)
// CompileShapeless registers fn so the C callback can find it and returns a
// CompiledFunc holding both fn and a freshly created MLX closure.
//
// Each call takes the next value of the package-level ID counter, stores fn in
// the callback registry under that ID, and passes the ID to the C side as the
// closure payload. The shapeless argument is accepted but is not consulted by
// this function.
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
	// Reserve a unique registry key; the mutex serializes the increment.
	nextIDMu.Lock()
	nextID++
	key := nextID
	nextIDMu.Unlock()

	compiledFuncs.Store(key, fn)

	// The key crosses into C as an opaque pointer-sized payload and is turned
	// back into an integer inside the callback.
	// NOTE(review): go vet's unsafeptr check may flag this uintptr to
	// unsafe.Pointer conversion; runtime/cgo.Handle is the documented
	// alternative for passing Go values through C.
	return &CompiledFunc{
		fn:      fn,
		closure: C.new_closure(unsafe.Pointer(key)),
	}
}
// Call runs the wrapped Go function on inputs and returns whatever it returns.
//
// The receiver's mutex is held for the whole call, so concurrent calls on the
// same CompiledFunc run one at a time. The stored MLX closure is not used
// here: compilation is treated as an optimization and the Go function is
// invoked directly, which is the simpler path.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
	cf.mu.Lock()
	defer cf.mu.Unlock()

	results := cf.fn(inputs)
	return results
}