refactor: extract mlx/ to standalone core/go-mlx module

Remove mlx/ directory (now lives at forge.lthn.ai/core/go-mlx).
Update ml/backend_mlx.go imports to reference the new module.
Add replace directive for local development.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-19 18:24:32 +00:00
parent 906a535899
commit 34d0f9ce41
27 changed files with 8 additions and 4387 deletions

3
go.mod
View file

@ -4,6 +4,7 @@ go 1.25.5
require (
forge.lthn.ai/core/go v0.0.0
forge.lthn.ai/core/go-mlx v0.0.0
github.com/gorilla/websocket v1.5.3
github.com/marcboeker/go-duckdb v1.8.5
github.com/modelcontextprotocol/go-sdk v1.3.0
@ -54,3 +55,5 @@ require (
)
replace forge.lthn.ai/core/go => ../core
replace forge.lthn.ai/core/go-mlx => ../go-mlx

View file

@ -10,11 +10,11 @@ import (
"strings"
"sync"
"forge.lthn.ai/core/go-ai/mlx"
"forge.lthn.ai/core/go-ai/mlx/cache"
"forge.lthn.ai/core/go-ai/mlx/model"
"forge.lthn.ai/core/go-ai/mlx/sample"
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
"forge.lthn.ai/core/go-mlx"
"forge.lthn.ai/core/go-mlx/cache"
"forge.lthn.ai/core/go-mlx/model"
"forge.lthn.ai/core/go-mlx/sample"
"forge.lthn.ai/core/go-mlx/tokenizer"
)
// MLXBackend implements Backend and StreamingBackend for native Metal inference.

View file

@ -1,28 +0,0 @@
# Build configuration for the MLX-C library consumed by the Go bindings.
cmake_minimum_required(VERSION 3.24)
project(mlx)

# Minimum macOS version for the produced dylibs.
set(CMAKE_OSX_DEPLOYMENT_TARGET "26.0" CACHE STRING "Minimum macOS version")

# Default the install prefix to ./dist unless the caller overrode it.
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()

# Feature toggles for the mlx-c build: safetensors on, GGUF and examples off.
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)

# Resolve dylib dependencies relative to the loading binary.
set(CMAKE_INSTALL_RPATH "@loader_path")

# Fetch mlx-c at a pinned tag (overridable via -DMLX_C_GIT_TAG=...).
include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)

View file

@ -1,261 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import (
"encoding/binary"
"reflect"
"runtime"
"strings"
"unsafe"
)
// Array wraps an mlx_array handle.
// Memory management relies on Go GC finalizers to call mlx_array_free,
// which decrements MLX-C's internal reference count. MLX-C handles all
// cross-array references internally — the Go wrapper does not track them.
type Array struct {
	ctx  C.mlx_array // underlying MLX-C handle; zero value means "no array yet"
	name string      // debug label
}

// New creates a named Array and registers a GC finalizer.
// The inputs parameter is accepted for API compatibility but not stored —
// MLX-C tracks inter-array references via its own refcounting.
// The returned Array has no C handle yet; callers assign ctx themselves.
func New(name string, inputs ...*Array) *Array {
	t := &Array{name: name}
	runtime.SetFinalizer(t, finalizeArray)
	return t
}

// finalizeArray is called by Go GC to release the underlying C array handle.
// Nil-handle Arrays (never populated, or already freed via Free) are skipped.
// NOTE(review): finalizers run on a GC goroutine; this assumes
// mlx_array_free is safe to call from any thread — confirm with MLX-C docs.
func finalizeArray(t *Array) {
	if t != nil && t.ctx.ctx != nil {
		C.mlx_array_free(t.ctx)
		t.ctx.ctx = nil
	}
}
// scalarTypes enumerates the Go types accepted by FromValue.
type scalarTypes interface {
	~bool | ~int | ~float32 | ~float64 | ~complex64
}

// FromValue creates a scalar Array from a Go value.
// Note the type switch matches dynamic types exactly, so a named type
// whose underlying type satisfies the constraint (e.g. `type Flag bool`)
// falls through to the panic below.
func FromValue[T scalarTypes](t T) *Array {
	Init() // ensure the MLX runtime is initialized before any C calls
	tt := New("")
	switch v := any(t).(type) {
	case bool:
		tt.ctx = C.mlx_array_new_bool(C.bool(v))
	case int:
		tt.ctx = C.mlx_array_new_int(C.int(v))
	case float32:
		tt.ctx = C.mlx_array_new_float32(C.float(v))
	case float64:
		tt.ctx = C.mlx_array_new_float64(C.double(v))
	case complex64:
		tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
	default:
		panic("mlx: unsupported scalar type")
	}
	return tt
}
// arrayTypes enumerates the element types accepted by FromValues.
type arrayTypes interface {
	~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~int8 | ~int16 | ~int32 | ~int64 |
		~float32 | ~float64 |
		~complex64
}

// FromValues creates an Array from a Go slice with the given shape.
// The slice is serialized little-endian into a scratch buffer and copied
// into MLX memory by mlx_array_new_data. Panics if shape is empty, if the
// element kind is unsupported, or if serialization fails.
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
	Init()
	if len(shape) == 0 {
		panic("mlx: shape required for non-scalar tensors")
	}
	cShape := make([]C.int, len(shape))
	for i := range shape {
		cShape[i] = C.int(shape[i])
	}
	// Map the Go element kind onto the matching MLX dtype.
	var dtype DType
	switch reflect.TypeOf(s).Elem().Kind() {
	case reflect.Bool:
		dtype = DTypeBool
	case reflect.Uint8:
		dtype = DTypeUint8
	case reflect.Uint16:
		dtype = DTypeUint16
	case reflect.Uint32:
		dtype = DTypeUint32
	case reflect.Uint64:
		dtype = DTypeUint64
	case reflect.Int8:
		dtype = DTypeInt8
	case reflect.Int16:
		dtype = DTypeInt16
	case reflect.Int32:
		dtype = DTypeInt32
	case reflect.Int64:
		dtype = DTypeInt64
	case reflect.Float32:
		dtype = DTypeFloat32
	case reflect.Float64:
		dtype = DTypeFloat64
	case reflect.Complex64:
		dtype = DTypeComplex64
	default:
		panic("mlx: unsupported element type")
	}
	bts := make([]byte, binary.Size(s))
	if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
		panic(err)
	}
	tt := New("")
	// unsafe.SliceData returns nil for an empty slice, avoiding the
	// index-out-of-range panic that &bts[0] would cause for len(s) == 0
	// (e.g. a zero-element shape like [0]).
	tt.ctx = C.mlx_array_new_data(unsafe.Pointer(unsafe.SliceData(bts)), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
	return tt
}
// Zeros creates a zero-filled Array with the given shape and dtype.
// An empty shape yields a zero-dimensional array (unsafe.SliceData returns
// nil for the empty slice, paired with a zero dimension count — presumably
// treated as a scalar by MLX-C; confirm).
func Zeros(shape []int32, dtype DType) *Array {
	Init()
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	tt := New("ZEROS")
	C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
	return tt
}
// Set replaces this array's C handle with another's.
// Handle refcount bookkeeping is performed inside mlx_array_set.
func (t *Array) Set(other *Array) {
	C.mlx_array_set(&t.ctx, other.ctx)
}

// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
// Both wrappers carry finalizers; MLX-C refcounting makes the double release safe.
func (t *Array) Clone() *Array {
	tt := New(t.name)
	C.mlx_array_set(&tt.ctx, t.ctx)
	return tt
}

// Valid reports whether this Array has a non-nil mlx handle.
func (t *Array) Valid() bool {
	return t.ctx.ctx != nil
}

// String returns a human-readable representation of the array.
// The mlx_string is copied into Go memory and freed before returning.
func (t *Array) String() string {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_array_tostring(&str, t.ctx)
	return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}
// Shape returns the dimensions as int32 slice.
func (t *Array) Shape() []int32 {
	n := t.NumDims()
	shape := make([]int32, 0, n)
	for i := 0; i < n; i++ {
		shape = append(shape, int32(t.Dim(i)))
	}
	return shape
}

// Size returns the total number of elements.
func (t Array) Size() int {
	return int(C.mlx_array_size(t.ctx))
}

// NumBytes returns the total byte size.
func (t Array) NumBytes() int {
	return int(C.mlx_array_nbytes(t.ctx))
}

// NumDims returns the number of dimensions.
func (t Array) NumDims() int {
	return int(C.mlx_array_ndim(t.ctx))
}

// Dim returns the size of dimension i.
func (t Array) Dim(i int) int {
	return int(C.mlx_array_dim(t.ctx, C.int(i)))
}

// Dims returns all dimensions as int slice.
func (t Array) Dims() []int {
	n := t.NumDims()
	dims := make([]int, 0, n)
	for i := 0; i < n; i++ {
		dims = append(dims, t.Dim(i))
	}
	return dims
}

// Dtype returns the array's data type.
func (t Array) Dtype() DType {
	return DType(C.mlx_array_dtype(t.ctx))
}
// Int extracts a scalar int64 value.
func (t Array) Int() int {
	var v C.int64_t
	C.mlx_array_item_int64(&v, t.ctx)
	return int(v)
}

// Float extracts a scalar float64 value.
// Handles both float32 and float64 array dtypes.
func (t Array) Float() float64 {
	if t.Dtype() == DTypeFloat32 {
		var v C.float
		C.mlx_array_item_float32(&v, t.ctx)
		return float64(v)
	}
	var v C.double
	C.mlx_array_item_float64(&v, t.ctx)
	return float64(v)
}

// Ints extracts all elements as int slice (from int32 data).
func (t Array) Ints() []int {
	n := t.Size()
	out := make([]int, n)
	src := unsafe.Slice(C.mlx_array_data_int32(t.ctx), n)
	for i := range src {
		out[i] = int(src[i])
	}
	return out
}

// DataInt32 extracts all elements as int32 slice.
func (t Array) DataInt32() []int32 {
	n := t.Size()
	out := make([]int32, n)
	src := unsafe.Slice(C.mlx_array_data_int32(t.ctx), n)
	for i := range src {
		out[i] = int32(src[i])
	}
	return out
}

// Floats extracts all elements as float32 slice.
func (t Array) Floats() []float32 {
	n := t.Size()
	out := make([]float32, n)
	src := unsafe.Slice(C.mlx_array_data_float32(t.ctx), n)
	for i := range src {
		out[i] = float32(src[i])
	}
	return out
}
// Free explicitly releases C array handles. Does not cascade — MLX-C's
// internal refcounting handles dependent arrays automatically.
// Returns the total bytes released (sum of NumBytes over freed arrays).
// Nil or already-freed entries are skipped, and each freed Array's GC
// finalizer is cancelled so the handle cannot be released twice.
func Free(s ...*Array) int {
	var n int
	for _, t := range s {
		if t != nil && t.Valid() {
			n += t.NumBytes()
			C.mlx_array_free(t.ctx)
			t.ctx.ctx = nil
			runtime.SetFinalizer(t, nil) // cancel finalizer
		}
	}
	return n
}

201
mlx/cache/cache.go vendored
View file

@ -1,201 +0,0 @@
//go:build darwin && arm64
// Package cache provides KV cache implementations for transformer inference.
package cache
import "forge.lthn.ai/core/go-ai/mlx"
// Cache manages key-value pairs for transformer attention layers.
type Cache interface {
	// Update adds new key/value tensors and returns the full cached K/V.
	Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
	// Offset returns the total number of tokens processed.
	Offset() int
	// Len returns the number of cached tokens (may differ from Offset for rotating caches).
	Len() int
	// State returns the cached K/V arrays, or nil if empty.
	State() []*mlx.Array
	// Reset clears the cache for a new generation session.
	Reset()
}

// KVCache implements an unbounded cache that grows as needed.
// Pre-allocates in chunks of `step` tokens to reduce allocations.
type KVCache struct {
	keys, values *mlx.Array // backing buffers [B, H, cap, D]; nil until first Update
	offset       int        // total tokens written so far
	step         int        // allocation chunk size in tokens
}

// NewKVCache creates a new unbounded KV cache with 256-token chunks.
func NewKVCache() *KVCache {
	return &KVCache{step: 256}
}
// Update appends k/v (expected [B, H, seqLen, D]) to the cache and returns
// views over the valid region [0:offset]. Backing buffers grow in multiples
// of c.step to amortize allocations.
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
	prev := c.offset
	shape := k.Shape()
	if len(shape) < 4 {
		// K/V must be [B, H, L, D] — if not, pass through unchanged
		if c.keys == nil {
			c.keys, c.values = k, v
		}
		c.offset += seqLen
		return c.keys, c.values
	}
	B, H, Dk := shape[0], shape[1], shape[3]
	Dv := v.Shape()[3]
	// Grow buffer if needed.
	if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
		// Round the new chunk up to a whole number of steps.
		nSteps := (c.step + seqLen - 1) / c.step
		newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
		newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
		if c.keys != nil {
			if prev%c.step != 0 {
				// Trim the partially-filled tail so valid data stays
				// contiguous at [0:prev] before concatenating.
				c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
				c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
			}
			c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
			c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
		} else {
			c.keys, c.values = newK, newV
		}
	}
	c.offset += seqLen
	// Write the new tokens into [prev:offset] along the sequence axis.
	c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
	c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
	// Return only the valid prefix, not the preallocated tail.
	return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
		mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
// State returns the cached K/V arrays, or nil if empty.
func (c *KVCache) State() []*mlx.Array {
	if c.keys != nil {
		return []*mlx.Array{c.keys, c.values}
	}
	return nil
}

// Offset returns the total number of tokens processed.
func (c *KVCache) Offset() int { return c.offset }

// Len returns the number of cached tokens (same as Offset for this cache).
func (c *KVCache) Len() int { return c.offset }

// Reset clears the cache for a new generation session.
func (c *KVCache) Reset() {
	c.offset = 0
	c.keys, c.values = nil, nil
}
// RotatingKVCache implements a bounded sliding window cache.
type RotatingKVCache struct {
	keys, values *mlx.Array
	offset       int // total tokens processed (unbounded)
	maxSize      int // window capacity in tokens
	step         int // allocation chunk size
	idx          int // next ring-buffer write position
}

// NewRotatingKVCache creates a cache bounded to maxSize tokens.
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
	return &RotatingKVCache{
		maxSize: maxSize,
		step:    256,
	}
}

// Update routes multi-token (prefill) updates through concatenation and
// single-token decode updates through the in-place ring write.
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
	if seqLen <= 1 {
		return c.updateInPlace(k, v)
	}
	return c.updateConcat(k, v, seqLen)
}
// updateInPlace writes a single decode token into the ring buffer, growing
// the buffer in c.step chunks until it reaches maxSize, then wrapping the
// write index. Returns views over the valid token range.
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
	shape := k.Shape()
	if len(shape) < 4 {
		// K/V must be [B, H, L, D] — if not, pass through unchanged.
		if c.keys == nil {
			c.keys, c.values = k, v
		}
		c.offset++
		return c.keys, c.values
	}
	B, H, Dk := shape[0], shape[1], shape[3]
	Dv := v.Shape()[3]
	// Grow while the current capacity is exhausted and still below maxSize.
	if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
		var cap int // current capacity in tokens (shadows the builtin)
		if c.keys != nil {
			cap = int(c.keys.Shape()[2])
		}
		newSize := min(c.step, c.maxSize-cap)
		newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
		newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
		if c.keys != nil {
			c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
			c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
		} else {
			c.keys, c.values = newK, newV
		}
	}
	// Wrap the ring index once the window is full.
	if c.idx >= c.maxSize {
		c.idx = 0
	}
	c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
	c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
	c.offset++
	c.idx++
	// Expose only tokens actually written, capped at the window size.
	validLen := int32(min(c.offset, c.maxSize))
	return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
		mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
// updateConcat handles multi-token (prefill) updates by concatenating the
// new K/V onto the cache, then trimming the oldest tokens so at most
// maxSize remain.
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
	shape := k.Shape()
	if len(shape) < 4 {
		// K/V must be [B, H, L, D] — if not, pass through unchanged
		if c.keys == nil {
			c.keys, c.values = k, v
		}
		c.offset += seqLen
		return c.keys, c.values
	}
	B, H, Dk := shape[0], shape[1], shape[3]
	Dv := v.Shape()[3]
	if c.keys == nil {
		c.keys, c.values = k, v
	} else {
		c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
		c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
	}
	c.offset += seqLen
	cap := int(c.keys.Shape()[2]) // current length in tokens (shadows the builtin)
	if trim := cap - c.maxSize; trim > 0 {
		// Drop the oldest `trim` tokens from the front of the window.
		c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
		c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
	}
	// Subsequent single-token writes continue from the end of the buffer.
	c.idx = int(c.keys.Shape()[2])
	return c.keys, c.values
}
// State returns the cached K/V arrays, or nil if empty.
func (c *RotatingKVCache) State() []*mlx.Array {
	if c.keys != nil {
		return []*mlx.Array{c.keys, c.values}
	}
	return nil
}

// Offset returns the total number of tokens processed.
func (c *RotatingKVCache) Offset() int { return c.offset }

// Len returns the number of cached tokens, capped at the window size.
func (c *RotatingKVCache) Len() int {
	if c.offset < c.maxSize {
		return c.offset
	}
	return c.maxSize
}

// Reset clears the cache for a new generation session.
func (c *RotatingKVCache) Reset() {
	c.idx = 0
	c.offset = 0
	c.keys, c.values = nil, nil
}

View file

@ -1,90 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include "mlx/c/mlx.h"
// Callback for compiled functions.
extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
static mlx_closure new_closure(void *payload) {
return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL);
}
*/
import "C"
import (
"sync"
"unsafe"
)
// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
type CompiledFunc struct {
	fn      func([]*Array) []*Array // the original Go function
	closure C.mlx_closure           // C closure built over the trampoline below
	mu      sync.Mutex              // serializes Call
}

// compiledFuncs maps payload IDs to registered Go functions for the
// goCompiledFunc trampoline.
var compiledFuncs sync.Map

// goCompiledFunc is the C-visible trampoline invoked by MLX closures.
// It resolves the registered Go function via the integer payload ID,
// wraps the input vector in Go Arrays, runs the function, and copies the
// results into *outputs. Returns 0 on success, 1 for an unknown ID.
//
//export goCompiledFunc
func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
	id := uintptr(payload)
	fnI, ok := compiledFuncs.Load(id)
	if !ok {
		return 1
	}
	fn := fnI.(func([]*Array) []*Array)
	// Convert inputs
	nInputs := int(C.mlx_vector_array_size(inputs))
	goInputs := make([]*Array, nInputs)
	for i := 0; i < nInputs; i++ {
		a := New("INPUT")
		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
		goInputs[i] = a
	}
	// Call user function
	goOutputs := fn(goInputs)
	// The output vector arrives with ctx=nullptr. Create a valid vector,
	// fill it, then copy to the output pointer.
	tmp := C.mlx_vector_array_new()
	for _, out := range goOutputs {
		C.mlx_vector_array_append_value(tmp, out.ctx)
	}
	C.mlx_vector_array_set(outputs, tmp)
	C.mlx_vector_array_free(tmp)
	return 0
}
// nextID issues unique payload IDs for registered closures; nextIDMu guards it.
var nextID uintptr
var nextIDMu sync.Mutex

// CompileShapeless compiles a function for efficient repeated execution.
// The function must accept and return arrays of consistent shapes.
// The shapeless parameter is currently unused by this implementation.
//
// NOTE(review): entries stored in compiledFuncs are never deleted and the
// C closure is never freed, so each call leaks a small amount per process
// lifetime. unsafe.Pointer(id) smuggles an integer through a pointer
// payload — go vet flags this pattern; confirm it is intentional.
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
	nextIDMu.Lock()
	nextID++
	id := nextID
	nextIDMu.Unlock()
	compiledFuncs.Store(id, fn)
	cf := &CompiledFunc{fn: fn}
	cf.closure = C.new_closure(unsafe.Pointer(id))
	return cf
}

// Call executes the compiled function with the given inputs.
// NOTE(review): the compiled closure built above is not actually invoked
// here — Call always takes the direct Go path, as described below.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
	cf.mu.Lock()
	defer cf.mu.Unlock()
	// Fall back to direct call — compilation is an optimization.
	// The compiled closure can be used via mlx_compiled but the
	// direct path is simpler and still benefits from MLX's lazy evaluation.
	return cf.fn(inputs)
}

View file

@ -1,83 +0,0 @@
//go:build darwin && arm64
package mlx
// #include "mlx/c/mlx.h"
import "C"
import "encoding/json"
// DType represents an MLX array data type.
type DType C.mlx_dtype

// Supported dtypes, mirroring MLX-C's mlx_dtype enum values.
const (
	DTypeBool      DType = C.MLX_BOOL
	DTypeUint8     DType = C.MLX_UINT8
	DTypeUint16    DType = C.MLX_UINT16
	DTypeUint32    DType = C.MLX_UINT32
	DTypeUint64    DType = C.MLX_UINT64
	DTypeInt8      DType = C.MLX_INT8
	DTypeInt16     DType = C.MLX_INT16
	DTypeInt32     DType = C.MLX_INT32
	DTypeInt64     DType = C.MLX_INT64
	DTypeFloat16   DType = C.MLX_FLOAT16
	DTypeFloat32   DType = C.MLX_FLOAT32
	DTypeFloat64   DType = C.MLX_FLOAT64
	DTypeBFloat16  DType = C.MLX_BFLOAT16
	DTypeComplex64 DType = C.MLX_COMPLEX64
)

// dtypeNames provides the canonical lowercase name for each DType.
var dtypeNames = map[DType]string{
	DTypeBool:      "bool",
	DTypeUint8:     "uint8",
	DTypeUint16:    "uint16",
	DTypeUint32:    "uint32",
	DTypeUint64:    "uint64",
	DTypeInt8:      "int8",
	DTypeInt16:     "int16",
	DTypeInt32:     "int32",
	DTypeInt64:     "int64",
	DTypeFloat16:   "float16",
	DTypeFloat32:   "float32",
	DTypeFloat64:   "float64",
	DTypeBFloat16:  "bfloat16",
	DTypeComplex64: "complex64",
}

// String returns the lowercase name of the dtype, or "unknown".
func (d DType) String() string {
	if s, ok := dtypeNames[d]; ok {
		return s
	}
	return "unknown"
}
// dtypeFromString accepts both Go-style names ("float16") and
// safetensors-style codes ("F16").
var dtypeFromString = map[string]DType{
	"bool": DTypeBool, "BOOL": DTypeBool,
	"uint8": DTypeUint8, "U8": DTypeUint8,
	"uint16": DTypeUint16, "U16": DTypeUint16,
	"uint32": DTypeUint32, "U32": DTypeUint32,
	"uint64": DTypeUint64, "U64": DTypeUint64,
	"int8": DTypeInt8, "I8": DTypeInt8,
	"int16": DTypeInt16, "I16": DTypeInt16,
	"int32": DTypeInt32, "I32": DTypeInt32,
	"int64": DTypeInt64, "I64": DTypeInt64,
	"float16": DTypeFloat16, "F16": DTypeFloat16,
	"float32": DTypeFloat32, "F32": DTypeFloat32,
	"float64": DTypeFloat64, "F64": DTypeFloat64,
	"bfloat16": DTypeBFloat16, "BF16": DTypeBFloat16,
	"complex64": DTypeComplex64,
}

// UnmarshalJSON parses a DType from JSON strings like "F32", "BF16", etc.
// Unrecognized strings silently fall back to DTypeFloat32 rather than
// returning an error.
// NOTE(review): this masks typos in model metadata; consider surfacing
// an error to the caller instead.
func (d *DType) UnmarshalJSON(b []byte) error {
	var s string
	if err := json.Unmarshal(b, &s); err != nil {
		return err
	}
	if dt, ok := dtypeFromString[s]; ok {
		*d = dt
		return nil
	}
	*d = DTypeFloat32 // default
	return nil
}

View file

@ -1,79 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import "unsafe"
// RMSNorm applies Root Mean Square normalization using a fused Metal kernel.
// eps guards the denominator against division by zero.
func RMSNorm(x, weight *Array, eps float32) *Array {
	out := New("FAST_RMSNORM", x)
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

// LayerNorm applies Layer normalization using a fused Metal kernel.
// weight scales and bias shifts the normalized output; eps guards the
// denominator against division by zero.
func LayerNorm(x, weight, bias *Array, eps float32) *Array {
	out := New("FAST_LAYERNORM", x)
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx)
	return out
}
// RoPE applies Rotary Position Embeddings using a fused Metal kernel.
//
// dims selects how many feature dimensions to rotate; traditional picks the
// original RoPE formulation; scale and offset adjust the token positions.
// A base of 0 is passed with has_value=false — presumably letting MLX use
// its default frequency base; confirm against mlx_fast_rope docs.
// An empty freqs array means "no custom frequencies".
func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
	out := New("FAST_ROPE", x)
	freqs := C.mlx_array_new()
	defer C.mlx_array_free(freqs)
	C.mlx_fast_rope(
		&out.ctx,
		x.ctx,
		C.int(dims),
		C._Bool(traditional),
		C.mlx_optional_float{
			value:     C.float(base),
			has_value: C._Bool(base != 0),
		},
		C.float(scale),
		C.int(offset),
		freqs,
		DefaultStream().ctx,
	)
	return out
}
// ScaledDotProductAttention computes attention using a fused Metal kernel.
// When causal is true the mask mode string "causal" is passed; otherwise an
// empty mode with empty mask/sinks arrays (MLX-C's "no value" convention).
func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
	mode := ""
	if causal {
		mode = "causal"
	}
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	maskArr := C.mlx_array_new()
	defer C.mlx_array_free(maskArr)
	sinksArr := C.mlx_array_new()
	defer C.mlx_array_free(sinksArr)
	out := New("FAST_SDPA", query, key, value)
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
	return out
}

// ScaledDotProductAttentionWithMask computes attention with an explicit mask.
// Mode "array" tells the kernel to read the mask argument.
func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
	cMode := C.CString("array")
	defer C.free(unsafe.Pointer(cMode))
	sinksArr := C.mlx_array_new()
	defer C.mlx_array_free(sinksArr)
	out := New("FAST_SDPA", query, key, value, mask)
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
	return out
}

View file

@ -1,353 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include "mlx/c/mlx.h"
// Callback for gradient closures — same signature as goCompiledFunc.
extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
static mlx_closure new_grad_closure(void *payload) {
return mlx_closure_new_func_payload(&goGradFunc, payload, NULL);
}
*/
import "C"
import (
"fmt"
"sync"
"sync/atomic"
"unsafe"
)
// --- Closure registry (separate from compile.go's registry) ---

// gradFuncs maps payload IDs to registered Go callbacks; gradNextID issues IDs.
var (
	gradFuncs  sync.Map
	gradNextID atomic.Uintptr
)

// goGradFunc is the C-visible trampoline invoked by MLX autograd closures.
// It resolves the registered Go function by payload ID, wraps the input
// vector, runs the function, and copies the results into *outputs.
// Returns 0 on success, 1 if the payload ID is unknown.
//
//export goGradFunc
func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
	id := uintptr(payload)
	fnI, ok := gradFuncs.Load(id)
	if !ok {
		return 1
	}
	fn := fnI.(func([]*Array) []*Array)
	nInputs := int(C.mlx_vector_array_size(inputs))
	goInputs := make([]*Array, nInputs)
	for i := 0; i < nInputs; i++ {
		a := New("GRAD_INPUT")
		C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
		goInputs[i] = a
	}
	goOutputs := fn(goInputs)
	// The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_).
	// Create a valid vector, fill it, then copy to the output pointer.
	tmp := C.mlx_vector_array_new()
	for _, out := range goOutputs {
		C.mlx_vector_array_append_value(tmp, out.ctx)
	}
	C.mlx_vector_array_set(outputs, tmp)
	C.mlx_vector_array_free(tmp)
	return 0
}

// newClosure registers a Go function as an MLX closure for autograd.
// NOTE(review): registry entries are never removed, so every closure leaks
// one gradFuncs entry for the process lifetime.
func newClosure(fn func([]*Array) []*Array) C.mlx_closure {
	id := gradNextID.Add(1)
	gradFuncs.Store(id, fn)
	return C.new_grad_closure(unsafe.Pointer(id))
}
// --- VJP (Vector-Jacobian Product) — Reverse Mode ---

// VJP computes the vector-Jacobian product (reverse-mode autodiff).
// Given a function fn, input primals, and output cotangents (upstream gradients),
// returns (outputs, gradients) where gradients are w.r.t. the primals.
//
// This is the fundamental backward pass operation.
// All temporary C vectors are freed before returning; the produced Arrays
// own fresh handles managed by GC finalizers.
func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) {
	Init()
	closure := newClosure(fn)
	defer C.mlx_closure_free(closure)
	// Pack primals into vector
	primalsVec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(primalsVec)
	for _, p := range primals {
		C.mlx_vector_array_append_value(primalsVec, p.ctx)
	}
	// Pack cotangents into vector
	cotangentsVec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(cotangentsVec)
	for _, c := range cotangents {
		C.mlx_vector_array_append_value(cotangentsVec, c.ctx)
	}
	// Call mlx_vjp
	var outVec, vjpVec C.mlx_vector_array
	outVec = C.mlx_vector_array_new()
	vjpVec = C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(outVec)
	defer C.mlx_vector_array_free(vjpVec)
	rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec)
	if rc != 0 {
		// Surface any pending MLX error state before reporting failure.
		checkError()
		return nil, nil, fmt.Errorf("mlx: vjp failed")
	}
	outputs = vectorToArrays(outVec)
	vjps = vectorToArrays(vjpVec)
	return outputs, vjps, nil
}
// --- JVP (Jacobian-Vector Product) — Forward Mode ---

// JVP computes the Jacobian-vector product (forward-mode autodiff).
// Given a function fn, input primals, and input tangents (perturbation directions),
// returns (outputs, output_tangents).
//
// Useful for directional derivatives and Hessian-vector products.
// Mirrors VJP's structure: pack inputs into C vectors, invoke mlx_jvp,
// unpack the results, and free all temporaries via defers.
func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) {
	Init()
	closure := newClosure(fn)
	defer C.mlx_closure_free(closure)
	primalsVec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(primalsVec)
	for _, p := range primals {
		C.mlx_vector_array_append_value(primalsVec, p.ctx)
	}
	tangentsVec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(tangentsVec)
	for _, t := range tangents {
		C.mlx_vector_array_append_value(tangentsVec, t.ctx)
	}
	var outVec, jvpVec C.mlx_vector_array
	outVec = C.mlx_vector_array_new()
	jvpVec = C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(outVec)
	defer C.mlx_vector_array_free(jvpVec)
	rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec)
	if rc != 0 {
		checkError()
		return nil, nil, fmt.Errorf("mlx: jvp failed")
	}
	outputs = vectorToArrays(outVec)
	jvps = vectorToArrays(jvpVec)
	return outputs, jvps, nil
}
// --- ValueAndGrad — Combined Forward + Backward ---

// GradFn is a function that computes both the loss value and gradients
// with respect to specified arguments. Call it with inputs to get
// (values, gradients). Release the underlying C closure with Free when done.
type GradFn struct {
	cls C.mlx_closure_value_and_grad
}

// ValueAndGrad creates a GradFn that computes both the function value and
// gradients with respect to the arguments at the given indices.
//
// The returned GradFn can be called repeatedly with different inputs.
// This is the primary API for training loops:
//
//	lossFn := func(params []*Array) []*Array { return []*Array{loss} }
//	grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg
//	values, grads := grad.Apply(params...)
func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn {
	Init()
	closure := newClosure(fn)
	// Freed immediately after mlx_value_and_grad — presumably the returned
	// value_and_grad closure retains its own reference; confirm with MLX-C docs.
	defer C.mlx_closure_free(closure)
	// Default: differentiate w.r.t. first argument
	if len(argnums) == 0 {
		argnums = []int{0}
	}
	cArgs := make([]C.int, len(argnums))
	for i, a := range argnums {
		cArgs[i] = C.int(a)
	}
	g := &GradFn{}
	C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs)))
	return g
}

// Apply calls the gradient function with the given inputs.
// Returns (values, gradients) where values are the function outputs
// and gradients are w.r.t. the arguments specified in ValueAndGrad.
func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) {
	inputVec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(inputVec)
	for _, in := range inputs {
		C.mlx_vector_array_append_value(inputVec, in.ctx)
	}
	var valVec, gradVec C.mlx_vector_array
	valVec = C.mlx_vector_array_new()
	gradVec = C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(valVec)
	defer C.mlx_vector_array_free(gradVec)
	rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec)
	if rc != 0 {
		checkError()
		return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed")
	}
	values = vectorToArrays(valVec)
	grads = vectorToArrays(gradVec)
	return values, grads, nil
}

// Free releases the underlying C closure. Safe to call more than once;
// the handle is nilled after the first release.
func (g *GradFn) Free() {
	if g.cls.ctx != nil {
		C.mlx_closure_value_and_grad_free(g.cls)
		g.cls.ctx = nil
	}
}
// --- Checkpoint — Memory-Efficient Gradient Recomputation ---

// Checkpoint wraps a function so that during backward pass, intermediate
// activations are recomputed rather than stored. Trades compute for memory.
//
// Use this for memory-constrained training (large models on limited VRAM).
//
// The checkpointed C closure is kept alive simply by being captured in the
// returned Go closure — no registry entry is needed (the previous version
// allocated an ID and claimed it "prevented GC", but never stored anything
// under it; that dead code is removed here).
// NOTE(review): the checkpointed closure is never freed, so each Checkpoint
// call leaks one C closure for the lifetime of the returned function.
func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array {
	Init()
	closure := newClosure(fn)
	defer C.mlx_closure_free(closure)
	var checkpointed C.mlx_closure
	C.mlx_checkpoint(&checkpointed, closure)
	// Return a Go function that calls through the checkpointed closure.
	return func(inputs []*Array) []*Array {
		inputVec := C.mlx_vector_array_new()
		defer C.mlx_vector_array_free(inputVec)
		for _, in := range inputs {
			C.mlx_vector_array_append_value(inputVec, in.ctx)
		}
		outVec := C.mlx_vector_array_new()
		defer C.mlx_vector_array_free(outVec)
		C.mlx_closure_apply(&outVec, checkpointed, inputVec)
		return vectorToArrays(outVec)
	}
}
// --- Loss Functions ---

// CrossEntropyLoss computes cross-entropy loss between logits and integer targets.
// logits: [..., V] (raw model output, pre-softmax, last dim = vocab)
// targets: [...] (integer token IDs, same shape as logits minus last dim)
// Returns scalar loss averaged over all positions.
//
// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i]
func CrossEntropyLoss(logits, targets *Array) *Array {
	Init()
	// logsumexp over the vocab axis keeps the computation stable;
	// TakeAlongAxis picks the logit at each target index (targets must be
	// expanded to [..., 1] first, then squeezed back).
	lse := LogSumExp(logits, -1, false)
	picked := Squeeze(TakeAlongAxis(logits, ExpandDims(targets, -1), -1), -1)
	return MeanAll(Subtract(lse, picked))
}

// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions.
// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore)
// Returns scalar loss averaged over masked positions only.
func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array {
	Init()
	lse := LogSumExp(logits, -1, false)
	picked := Squeeze(TakeAlongAxis(logits, ExpandDims(targets, -1), -1), -1)
	// Zero out ignored positions and normalize by the number of active
	// positions rather than the full B*L.
	maskedLoss := Mul(Subtract(lse, picked), mask)
	return Divide(SumAll(maskedLoss), SumAll(mask))
}

// MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
func MSELoss(predictions, targets *Array) *Array {
	return MeanAll(Square(Subtract(predictions, targets)))
}
// --- Helpers ---

// vectorToArrays extracts all arrays from an mlx_vector_array.
func vectorToArrays(vec C.mlx_vector_array) []*Array {
	out := make([]*Array, int(C.mlx_vector_array_size(vec)))
	for i := range out {
		a := New("VEC_OUT")
		C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
		out[i] = a
	}
	return out
}

// Log returns element-wise natural logarithm.
func Log(a *Array) *Array {
	out := New("LOG", a)
	C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// SumAll reduces by summation over all elements, returning a scalar.
func SumAll(a *Array) *Array {
	return Sum(Reshape(a, -1), 0, false)
}

// MeanAll reduces by averaging over all elements, returning a scalar.
func MeanAll(a *Array) *Array {
	return Mean(Reshape(a, -1), 0, false)
}

// OnesLike creates an array of ones with the same shape and type as the input.
func OnesLike(a *Array) *Array {
	out := New("ONES_LIKE", a)
	C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

View file

@ -1,332 +0,0 @@
//go:build darwin && arm64
package mlx
import (
"math"
"testing"
)
func TestVJP_SimpleSquare(t *testing.T) {
	// f(x) = x^2 so df/dx = 2x; at x=3 expect f=9 and grad=6.
	square := func(in []*Array) []*Array {
		return []*Array{Mul(in[0], in[0])}
	}
	input := FromValue(float32(3.0))
	upstream := FromValue(float32(1.0)) // seed gradient of 1
	outs, grads, err := VJP(square, []*Array{input}, []*Array{upstream})
	if err != nil {
		t.Fatalf("VJP failed: %v", err)
	}
	Materialize(outs[0], grads[0])
	if v := outs[0].Float(); math.Abs(v-9.0) > 1e-5 {
		t.Errorf("output = %f, want 9.0", v)
	}
	if g := grads[0].Float(); math.Abs(g-6.0) > 1e-5 {
		t.Errorf("grad = %f, want 6.0", g)
	}
}
func TestVJP_Addition(t *testing.T) {
	// f(x, y) = x + y has unit gradients in both arguments.
	add := func(in []*Array) []*Array {
		return []*Array{Add(in[0], in[1])}
	}
	a := FromValue(float32(2.0))
	b := FromValue(float32(5.0))
	upstream := FromValue(float32(1.0))
	_, grads, err := VJP(add, []*Array{a, b}, []*Array{upstream})
	if err != nil {
		t.Fatalf("VJP failed: %v", err)
	}
	Materialize(grads...)
	for i, name := range []string{"dx", "dy"} {
		if g := grads[i].Float(); math.Abs(g-1.0) > 1e-5 {
			t.Errorf("%s = %f, want 1.0", name, g)
		}
	}
}
// TestVJP_MatmulGrad checks the gradient of sum(W @ x) with respect to the
// matrix W; the analytic answer is the outer product ones @ x^T.
func TestVJP_MatmulGrad(t *testing.T) {
	// f(W) = sum(W @ x) — gradient of sum(matmul) w.r.t. W
	// For W=[2,2], x=[2,1]: dL/dW = ones @ x^T
	x := FromValues([]float32{1.0, 2.0}, 2, 1)
	w := FromValues([]float32{1.0, 0.0, 0.0, 1.0}, 2, 2) // identity
	fn := func(inputs []*Array) []*Array {
		result := Matmul(inputs[0], x)
		return []*Array{SumAll(result)}
	}
	cotangent := FromValue(float32(1.0))
	outputs, grads, err := VJP(fn, []*Array{w}, []*Array{cotangent})
	if err != nil {
		t.Fatalf("VJP failed: %v", err)
	}
	Materialize(outputs[0], grads[0])
	// W @ x with W=I, x=[1,2]^T gives [1,2]^T, sum=3
	got := outputs[0].Float()
	if math.Abs(got-3.0) > 1e-5 {
		t.Errorf("output = %f, want 3.0", got)
	}
	// Gradient of sum(W@x) w.r.t. W is outer product: ones @ x^T
	// = [[1,2],[1,2]]
	gradFloats := grads[0].Floats()
	expected := []float32{1.0, 2.0, 1.0, 2.0}
	for i, exp := range expected {
		if math.Abs(float64(gradFloats[i]-exp)) > 1e-5 {
			t.Errorf("grad[%d] = %f, want %f", i, gradFloats[i], exp)
		}
	}
}
// TestJVP_SimpleSquare checks the Jacobian-vector (forward-mode) product of
// f(x) = x^2 against the analytic value 2x * v.
func TestJVP_SimpleSquare(t *testing.T) {
	// f(x) = x^2, JVP with tangent v: df = 2x * v
	// At x=3, v=1: df = 6
	fn := func(inputs []*Array) []*Array {
		x := inputs[0]
		return []*Array{Mul(x, x)}
	}
	x := FromValue(float32(3.0))
	tangent := FromValue(float32(1.0))
	outputs, jvps, err := JVP(fn, []*Array{x}, []*Array{tangent})
	if err != nil {
		t.Fatalf("JVP failed: %v", err)
	}
	Materialize(outputs[0], jvps[0])
	got := outputs[0].Float()
	if math.Abs(got-9.0) > 1e-5 {
		t.Errorf("output = %f, want 9.0", got)
	}
	jvp := jvps[0].Float()
	if math.Abs(jvp-6.0) > 1e-5 {
		t.Errorf("jvp = %f, want 6.0", jvp)
	}
}
// TestValueAndGrad_Quadratic checks that ValueAndGrad returns both the value
// and the derivative of a quadratic polynomial at a single point.
func TestValueAndGrad_Quadratic(t *testing.T) {
	// f(x) = x^2 + 2x + 1 = (x+1)^2
	// f'(x) = 2x + 2
	// At x=3: f(3) = 16, f'(3) = 8
	fn := func(inputs []*Array) []*Array {
		x := inputs[0]
		x2 := Mul(x, x)
		two_x := MulScalar(x, 2.0)
		one := FromValue(float32(1.0))
		return []*Array{Add(Add(x2, two_x), one)}
	}
	grad := ValueAndGrad(fn, 0)
	defer grad.Free()
	x := FromValue(float32(3.0))
	values, grads, err := grad.Apply(x)
	if err != nil {
		t.Fatalf("ValueAndGrad failed: %v", err)
	}
	Materialize(values[0], grads[0])
	val := values[0].Float()
	if math.Abs(val-16.0) > 1e-5 {
		t.Errorf("value = %f, want 16.0", val)
	}
	g := grads[0].Float()
	if math.Abs(g-8.0) > 1e-5 {
		t.Errorf("grad = %f, want 8.0", g)
	}
}
// TestValueAndGrad_MultiArg checks differentiation with respect to several
// arguments at once: for f(x, y) = x*y the partials are y and x.
func TestValueAndGrad_MultiArg(t *testing.T) {
	// f(x, y) = x*y, df/dx = y, df/dy = x
	// At x=3, y=4: f=12, dx=4, dy=3
	fn := func(inputs []*Array) []*Array {
		return []*Array{Mul(inputs[0], inputs[1])}
	}
	// Differentiate w.r.t. both arguments
	grad := ValueAndGrad(fn, 0, 1)
	defer grad.Free()
	x := FromValue(float32(3.0))
	y := FromValue(float32(4.0))
	values, grads, err := grad.Apply(x, y)
	if err != nil {
		t.Fatalf("ValueAndGrad failed: %v", err)
	}
	Materialize(values[0], grads[0], grads[1])
	val := values[0].Float()
	if math.Abs(val-12.0) > 1e-5 {
		t.Errorf("value = %f, want 12.0", val)
	}
	dx := grads[0].Float()
	if math.Abs(dx-4.0) > 1e-5 {
		t.Errorf("dx = %f, want 4.0 (y)", dx)
	}
	dy := grads[1].Float()
	if math.Abs(dy-3.0) > 1e-5 {
		t.Errorf("dy = %f, want 3.0 (x)", dy)
	}
}
// TestValueAndGrad_Reusable verifies the compiled GradFn can be applied
// repeatedly at different points, including zero and negative inputs.
func TestValueAndGrad_Reusable(t *testing.T) {
	// Verify GradFn can be called multiple times
	fn := func(inputs []*Array) []*Array {
		x := inputs[0]
		return []*Array{Mul(x, x)} // x^2, grad = 2x
	}
	grad := ValueAndGrad(fn)
	defer grad.Free()
	for _, tc := range []struct {
		x    float32
		want float64 // expected gradient
	}{
		{2.0, 4.0},
		{5.0, 10.0},
		{-3.0, -6.0},
		{0.0, 0.0},
	} {
		x := FromValue(tc.x)
		_, grads, err := grad.Apply(x)
		if err != nil {
			t.Fatalf("Apply failed for x=%f: %v", tc.x, err)
		}
		Materialize(grads[0])
		g := grads[0].Float()
		if math.Abs(g-tc.want) > 1e-5 {
			t.Errorf("x=%f: grad = %f, want %f", tc.x, g, tc.want)
		}
	}
}
// TestCrossEntropyLoss checks the loss for a single 3-class example against a
// hand-computed logsumexp value.
func TestCrossEntropyLoss(t *testing.T) {
	// Simple 3-class classification
	// logits = [1.0, 2.0, 3.0], target = 2 (class index)
	// Manual: logsumexp([1,2,3]) = 3 + log(exp(-2)+exp(-1)+1)
	//       = 3 + log(0.1353 + 0.3679 + 1.0) = 3 + log(1.5032) = 3.4076
	// loss = 3.4076 - 3.0 = 0.4076
	logits := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) // [1, 3]
	targets := FromValues([]int32{2}, 1)                 // [1]
	loss := CrossEntropyLoss(logits, targets)
	Materialize(loss)
	got := loss.Float()
	expected := 0.4076
	if math.Abs(got-expected) > 0.01 {
		t.Errorf("CrossEntropyLoss = %f, want ~%f", got, expected)
	}
}
// TestMSELoss checks mean-squared error on a vector with a uniform residual
// of 0.5 per element, giving an exact expected loss of 0.25.
func TestMSELoss(t *testing.T) {
	pred := FromValues([]float32{1.0, 2.0, 3.0}, 3)
	target := FromValues([]float32{1.5, 2.5, 3.5}, 3)
	loss := MSELoss(pred, target)
	Materialize(loss)
	// MSE = mean((0.5)^2, (0.5)^2, (0.5)^2) = mean(0.25, 0.25, 0.25) = 0.25
	got := loss.Float()
	if math.Abs(got-0.25) > 1e-5 {
		t.Errorf("MSELoss = %f, want 0.25", got)
	}
}
// TestLogSumExp checks the numerically stable log-sum-exp reduction along the
// last axis against a hand-computed value.
func TestLogSumExp(t *testing.T) {
	// logsumexp([1, 2, 3]) along axis -1
	a := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3)
	result := LogSumExp(a, -1, false)
	Materialize(result)
	// = 3 + log(exp(-2) + exp(-1) + 1) = 3 + log(1.5032) ≈ 3.4076
	got := result.Float()
	expected := 3.4076
	if math.Abs(got-expected) > 0.01 {
		t.Errorf("LogSumExp = %f, want ~%f", got, expected)
	}
}
// TestOnesLike verifies every element of the result is exactly 1 for a small
// input vector.
func TestOnesLike(t *testing.T) {
	a := FromValues([]float32{1.0, 2.0, 3.0}, 3)
	ones := OnesLike(a)
	Materialize(ones)
	floats := ones.Floats()
	for i, f := range floats {
		if f != 1.0 {
			t.Errorf("OnesLike[%d] = %f, want 1.0", i, f)
		}
	}
}
// TestCheckpoint verifies that wrapping a function in gradient checkpointing
// leaves its forward result unchanged.
func TestCheckpoint(t *testing.T) {
	// Checkpoint should produce the same result as the original function
	fn := func(inputs []*Array) []*Array {
		x := inputs[0]
		return []*Array{Mul(x, x)}
	}
	cpFn := Checkpoint(fn)
	x := FromValue(float32(5.0))
	result := cpFn([]*Array{x})
	Materialize(result[0])
	got := result[0].Float()
	if math.Abs(got-25.0) > 1e-5 {
		t.Errorf("Checkpoint result = %f, want 25.0", got)
	}
}
// TestSumAll checks the full-array summation reduction on a 2x2 matrix.
func TestSumAll(t *testing.T) {
	a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2)
	result := SumAll(a)
	Materialize(result)
	got := result.Float()
	if math.Abs(got-10.0) > 1e-5 {
		t.Errorf("SumAll = %f, want 10.0", got)
	}
}
// TestMeanAll checks the full-array mean reduction on a 2x2 matrix.
func TestMeanAll(t *testing.T) {
	a := FromValues([]float32{2.0, 4.0, 6.0, 8.0}, 2, 2)
	result := MeanAll(a)
	Materialize(result)
	got := result.Float()
	if math.Abs(got-5.0) > 1e-5 {
		t.Errorf("MeanAll = %f, want 5.0", got)
	}
}

View file

@ -1,63 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import (
"iter"
"runtime"
"unsafe"
)
// LoadSafetensors loads tensors from a .safetensors file, returning an iterator
// over (name, array) pairs. Tensors are loaded lazily on the CPU stream.
//
// Each yielded *Array has a finalizer attached, so callers need not free them.
// The iterator stops early when yield returns false.
func LoadSafetensors(path string) iter.Seq2[string, *Array] {
	Init()
	return func(yield func(string, *Array) bool) {
		string2array := C.mlx_map_string_to_array_new()
		defer C.mlx_map_string_to_array_free(string2array)
		string2string := C.mlx_map_string_to_string_new()
		defer C.mlx_map_string_to_string_free(string2string)
		cPath := C.CString(path)
		defer C.free(unsafe.Pointer(cPath))
		cpu := C.mlx_default_cpu_stream_new()
		defer C.mlx_stream_free(cpu)
		C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
		it := C.mlx_map_string_to_array_iterator_new(string2array)
		defer C.mlx_map_string_to_array_iterator_free(it)
		for {
			var key *C.char
			value := C.mlx_array_new()
			if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
				// Fix: the scratch handle allocated above was never handed to
				// an *Array finalizer on this path, so free it here to avoid
				// leaking one mlx_array per load.
				C.mlx_array_free(value)
				break
			}
			name := C.GoString(key)
			arr := &Array{ctx: value, name: name}
			runtime.SetFinalizer(arr, finalizeArray)
			if !yield(name, arr) {
				break
			}
		}
	}
}
// LoadAllSafetensors loads every tensor from a .safetensors file into a map
// keyed by tensor name.
func LoadAllSafetensors(path string) map[string]*Array {
	out := map[string]*Array{}
	for name, arr := range LoadSafetensors(path) {
		out[name] = arr
	}
	return out
}

View file

@ -1,237 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import (
"fmt"
"math"
"sort"
"unsafe"
)
// LoRALinear wraps a frozen Linear layer with low-rank trainable adapters.
//
// Forward: base(x) + scale * (x @ A^T) @ B^T
//
// A is [rank, in_features] — initialised with Kaiming normal
// B is [out_features, rank] — initialised to zero
// Scale = alpha / rank
//
// Only A and B are trainable. The base weights stay frozen.
type LoRALinear struct {
	Base  *Linear // Frozen base weights (may be quantized)
	A     *Array  // [rank, in_features] — trainable
	B     *Array  // [out_features, rank] — trainable
	Scale float32 // alpha / rank, applied to the adapter path in Forward
	Rank  int     // decomposition rank used to build A and B
	Alpha float32 // raw scaling factor; Scale is derived from it
}
// NewLoRALinear wraps an existing Linear layer with LoRA adapters.
// rank: decomposition rank (typically 4, 8, or 16)
// alpha: scaling factor (typically 2*rank)
//
// The adapter matrices are materialised immediately. Because B starts at
// zero, a freshly constructed LoRALinear computes exactly the same output as
// the base layer.
func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
	// Determine dimensions from the base weight.
	// Weight shape is [out_features, in_features] for standard linear,
	// or quantized shape which we handle via the base layer.
	var inFeatures, outFeatures int32
	if base.Scales != nil {
		// Quantized: weight is packed. Compute dims from scales.
		// scales shape is [out_features, in_features / group_size]
		scaleShape := base.Scales.Shape()
		outFeatures = scaleShape[0]
		inFeatures = scaleShape[1] * int32(base.GroupSize)
	} else {
		wShape := base.Weight.Shape()
		outFeatures = wShape[0]
		inFeatures = wShape[1]
	}
	// A: Kaiming normal initialisation — N(0, 1/sqrt(in_features))
	stddev := float32(1.0 / math.Sqrt(float64(inFeatures)))
	a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32)
	// B: zero initialisation — LoRA starts as identity (no change to base)
	b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32)
	Materialize(a, b)
	return &LoRALinear{
		Base:  base,
		A:     a,
		B:     b,
		Scale: alpha / float32(rank),
		Rank:  rank,
		Alpha: alpha,
	}
}
// Forward computes base(x) + scale * (x @ A^T) @ B^T.
//
// It goes through baseForward on the underlying Linear (not Forward) to avoid
// infinite recursion when the Linear's Forward dispatches through LoRA.
func (l *LoRALinear) Forward(x *Array) *Array {
	frozen := l.Base.baseForward(x)
	// Low-rank path: project down via A^T ([B, L, rank]), back up via B^T
	// ([B, L, out]), then apply the alpha/rank scaling.
	adapter := Matmul(Matmul(x, Transpose(l.A)), Transpose(l.B))
	adapter = MulScalar(adapter, l.Scale)
	return Add(frozen, adapter)
}
// TrainableParams returns the LoRA A and B arrays (in that order) for
// gradient computation.
func (l *LoRALinear) TrainableParams() []*Array {
	params := make([]*Array, 0, 2)
	params = append(params, l.A, l.B)
	return params
}
// SetParams updates the LoRA A and B arrays (used by the optimiser after a
// gradient step).
func (l *LoRALinear) SetParams(a, b *Array) {
	l.A, l.B = a, b
}
// ParamCount returns the number of trainable parameters: the element counts
// of A and B combined.
func (l *LoRALinear) ParamCount() int {
	count := int32(0)
	for _, p := range []*Array{l.A, l.B} {
		shape := p.Shape()
		count += shape[0] * shape[1]
	}
	return int(count)
}
// --- LoRA Application to Models ---

// LoRAConfig specifies which layers to apply LoRA to and with what parameters.
// Use DefaultLoRAConfig for the standard LLM fine-tuning setup.
type LoRAConfig struct {
	Rank       int      // Decomposition rank (default 8)
	Alpha      float32  // Scaling factor (default 16)
	TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
}
// DefaultLoRAConfig returns the standard LoRA configuration for LLM
// fine-tuning: rank 8, alpha 16, adapting the attention q/v projections.
func DefaultLoRAConfig() LoRAConfig {
	cfg := LoRAConfig{Rank: 8, Alpha: 16}
	cfg.TargetKeys = []string{"q_proj", "v_proj"}
	return cfg
}
// LoRAAdapter holds all LoRA layers applied to a model, together with the
// configuration that produced them.
type LoRAAdapter struct {
	Layers map[string]*LoRALinear // keyed by weight path prefix
	Config LoRAConfig             // settings the adapter was built with
}
// TotalParams returns the total number of trainable parameters across all
// LoRA layers in the adapter.
func (a *LoRAAdapter) TotalParams() int {
	sum := 0
	for _, layer := range a.Layers {
		sum += layer.ParamCount()
	}
	return sum
}
// SortedNames returns the adapter's layer names in deterministic sorted
// order, so parameter flattening is reproducible across runs.
func (a *LoRAAdapter) SortedNames() []string {
	out := make([]string, 0, len(a.Layers))
	for key := range a.Layers {
		out = append(out, key)
	}
	sort.Strings(out)
	return out
}
// AllTrainableParams returns all trainable arrays (A then B from every
// layer), in a deterministic order sorted by layer name.
func (a *LoRAAdapter) AllTrainableParams() []*Array {
	var params []*Array
	for _, name := range a.SortedNames() {
		layer := a.Layers[name]
		params = append(params, layer.A, layer.B)
	}
	return params
}
// SetAllParams updates all LoRA A and B arrays from a flat slice, in the same
// deterministic order as AllTrainableParams (A then B per layer).
func (a *LoRAAdapter) SetAllParams(params []*Array) {
	for i, name := range a.SortedNames() {
		layer := a.Layers[name]
		layer.A, layer.B = params[2*i], params[2*i+1]
	}
}
// Save writes the LoRA adapter weights to a safetensors file. Only the A and
// B matrices are stored — never the frozen base weights — under the keys
// "<layer>.lora_a" and "<layer>.lora_b".
func (a *LoRAAdapter) Save(path string) error {
	out := make(map[string]*Array, 2*len(a.Layers))
	for name, layer := range a.Layers {
		out[name+".lora_a"] = layer.A
		out[name+".lora_b"] = layer.B
	}
	return SaveSafetensors(path, out)
}
// --- Random Normal ---

// RandomNormal generates normal (Gaussian) random values with given mean and stddev.
//
// shape gives the output dimensions, dtype the element type. The result is
// lazy; call Materialize to force evaluation.
// NOTE(review): an empty shape slice would panic at &cShape[0] — confirm all
// callers pass at least one dimension.
// NOTE(review): the key argument is a freshly created empty mlx_array relied
// on to mean "use the default RNG" — confirm against mlx-c semantics, since
// the inline comment below says "null key".
func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
	Init()
	out := New("RANDOM_NORMAL")
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	key := C.mlx_array_new()
	defer C.mlx_array_free(key)
	C.mlx_random_normal(
		&out.ctx,
		&cShape[0], C.size_t(len(cShape)),
		C.mlx_dtype(dtype),
		C.float(mean), C.float(stddev),
		key, // null key = use default RNG
		DefaultStream().ctx,
	)
	return out
}
// --- SaveSafetensors ---

// SaveSafetensors saves a map of named arrays to a .safetensors file.
//
// On failure it logs the last MLX error (via checkError) and returns a
// wrapped error naming the destination path. Metadata is written empty.
func SaveSafetensors(path string, weights map[string]*Array) error {
	Init()
	// Build the map
	cMap := C.mlx_map_string_to_array_new()
	defer C.mlx_map_string_to_array_free(cMap)
	for name, arr := range weights {
		cName := C.CString(name)
		C.mlx_map_string_to_array_insert(cMap, cName, arr.ctx)
		// Freed immediately after insert — presumably mlx-c copies the key;
		// TODO confirm against the mlx-c map API.
		C.free(unsafe.Pointer(cName))
	}
	// Empty metadata
	cMeta := C.mlx_map_string_to_string_new()
	defer C.mlx_map_string_to_string_free(cMeta)
	cPath := C.CString(path)
	defer C.free(unsafe.Pointer(cPath))
	rc := C.mlx_save_safetensors(cPath, cMap, cMeta)
	if rc != 0 {
		checkError()
		return fmt.Errorf("mlx: save safetensors failed: %s", path)
	}
	return nil
}

View file

@ -1,316 +0,0 @@
//go:build darwin && arm64
package mlx
import (
"math"
"os"
"testing"
)
// TestNewLoRALinear checks the constructed adapter's shapes, its alpha/rank
// scale, and that B starts at zero so the adapter is initially a no-op.
func TestNewLoRALinear(t *testing.T) {
	// Create a simple base linear layer: [4, 8] weight
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 4, 8.0) // rank=4, alpha=8
	// Check dimensions
	aShape := lora.A.Shape()
	bShape := lora.B.Shape()
	if aShape[0] != 4 || aShape[1] != 8 {
		t.Errorf("A shape = %v, want [4, 8]", aShape)
	}
	if bShape[0] != 4 || bShape[1] != 4 {
		t.Errorf("B shape = %v, want [4, 4]", bShape)
	}
	// Scale should be alpha/rank = 8/4 = 2
	if math.Abs(float64(lora.Scale)-2.0) > 1e-5 {
		t.Errorf("Scale = %f, want 2.0", lora.Scale)
	}
	// B should be all zeros (LoRA starts as identity)
	Materialize(lora.B)
	bFloats := lora.B.Floats()
	for i, v := range bFloats {
		if v != 0 {
			t.Errorf("B[%d] = %f, want 0", i, v)
		}
	}
}
// TestLoRALinear_ForwardMatchesBase verifies that with B still zero, the LoRA
// forward pass is numerically identical to the frozen base layer.
func TestLoRALinear_ForwardMatchesBase(t *testing.T) {
	// With B=0, LoRA forward should equal base forward
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 4, 8.0)
	// Random input [1, 3, 8]
	x := RandomNormal(0, 1, []int32{1, 3, 8}, DTypeFloat32)
	Materialize(x)
	baseOut := base.Forward(x)
	loraOut := lora.Forward(x)
	Materialize(baseOut, loraOut)
	// Should be identical since B is zero
	baseFloats := baseOut.Floats()
	loraFloats := loraOut.Floats()
	if len(baseFloats) != len(loraFloats) {
		t.Fatalf("output sizes differ: base=%d, lora=%d", len(baseFloats), len(loraFloats))
	}
	for i := range baseFloats {
		diff := math.Abs(float64(baseFloats[i] - loraFloats[i]))
		if diff > 1e-4 {
			t.Errorf("output[%d] differs: base=%f, lora=%f", i, baseFloats[i], loraFloats[i])
		}
	}
}
// TestLoRALinear_ForwardWithAdapter injects known A/B matrices and verifies
// the forward output: zero base weights and zero A must yield all zeros even
// with a nonzero B.
func TestLoRALinear_ForwardWithAdapter(t *testing.T) {
	// Set A and B to known values and verify output changes
	w := Zeros([]int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 2, 4.0) // rank=2, alpha=4, scale=2
	// Set A to identity-like: [[1,0,0,...], [0,1,0,...]]
	a := Zeros([]int32{2, 8}, DTypeFloat32)
	// Set B to ones: [[1,1], [1,1], [1,1], [1,1]]
	b := FromValues([]float32{
		1, 1,
		1, 1,
		1, 1,
		1, 1,
	}, 4, 2)
	Materialize(a, b)
	lora.A = a
	lora.B = b
	// With base=0, A=0, output should also be 0 (scale * x@0@B^T = 0)
	x := FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 8)
	result := lora.Forward(x)
	Materialize(result)
	// base(x) = 0 (zero weights), lora = scale * (x @ A^T) @ B^T
	// A is zeros, so x @ A^T = [0, 0], then @ B^T = [0,0,0,0]
	for _, v := range result.Floats() {
		if v != 0 {
			t.Errorf("expected 0 with zero A, got %f", v)
		}
	}
}
// TestLoRALinear_ParamCount checks the trainable-parameter count formula
// rank*in + out*rank for a [64, 128] base layer at rank 8.
func TestLoRALinear_ParamCount(t *testing.T) {
	w := RandomNormal(0, 0.01, []int32{64, 128}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 8, 16.0) // rank=8
	// A: [8, 128] = 1024, B: [64, 8] = 512, total = 1536
	expected := 8*128 + 64*8
	if lora.ParamCount() != expected {
		t.Errorf("ParamCount = %d, want %d", lora.ParamCount(), expected)
	}
}
// TestLoRALinear_TrainableParams checks that exactly A and B are exposed for
// training, in that order, with the expected shapes.
func TestLoRALinear_TrainableParams(t *testing.T) {
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 4, 8.0)
	params := lora.TrainableParams()
	if len(params) != 2 {
		t.Fatalf("TrainableParams returned %d arrays, want 2", len(params))
	}
	// First is A, second is B
	if params[0].Shape()[0] != 4 || params[0].Shape()[1] != 8 {
		t.Errorf("param[0] (A) shape = %v, want [4, 8]", params[0].Shape())
	}
	if params[1].Shape()[0] != 4 || params[1].Shape()[1] != 4 {
		t.Errorf("param[1] (B) shape = %v, want [4, 4]", params[1].Shape())
	}
}
// TestLoRALinear_GradientFlows differentiates a sum-of-outputs loss with
// respect to A and B and asserts the B gradient is non-zero (the A gradient
// may legitimately be zero while B is still zero).
func TestLoRALinear_GradientFlows(t *testing.T) {
	// Verify that gradients flow through the LoRA path
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 4, 8.0)
	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	Materialize(x)
	// Loss function: sum of LoRA output (differentiating w.r.t. A and B)
	lossFn := func(inputs []*Array) []*Array {
		lora.A = inputs[0]
		lora.B = inputs[1]
		out := lora.Forward(x)
		return []*Array{SumAll(out)}
	}
	grad := ValueAndGrad(lossFn, 0, 1) // grad w.r.t. A and B
	defer grad.Free()
	values, grads, err := grad.Apply(lora.A, lora.B)
	if err != nil {
		t.Fatalf("ValueAndGrad failed: %v", err)
	}
	Materialize(append(values, grads...)...)
	// Loss should be a scalar
	loss := values[0].Float()
	t.Logf("loss = %f", loss)
	// Gradients should be non-zero (A has random init, B is zero but gets grad)
	gradA := grads[0]
	gradB := grads[1]
	aGradFloats := gradA.Floats()
	bGradFloats := gradB.Floats()
	hasNonZeroA := false
	for _, v := range aGradFloats {
		if v != 0 {
			hasNonZeroA = true
			break
		}
	}
	hasNonZeroB := false
	for _, v := range bGradFloats {
		if v != 0 {
			hasNonZeroB = true
			break
		}
	}
	// A gradient might be zero if B is zero (since dL/dA depends on B)
	// But B gradient should be non-zero since A is random
	if !hasNonZeroB {
		t.Error("gradient for B is all zeros — gradients not flowing")
	}
	t.Logf("gradA has non-zero: %v, gradB has non-zero: %v", hasNonZeroA, hasNonZeroB)
}
// TestRandomNormal draws 100 samples and sanity-checks the element count and
// that the sample mean is near zero (loose tolerance for the small sample).
func TestRandomNormal(t *testing.T) {
	arr := RandomNormal(0, 1, []int32{100}, DTypeFloat32)
	Materialize(arr)
	floats := arr.Floats()
	if len(floats) != 100 {
		t.Fatalf("RandomNormal returned %d elements, want 100", len(floats))
	}
	// Check rough statistics: mean should be near 0, values should have spread
	var sum float64
	for _, f := range floats {
		sum += float64(f)
	}
	mean := sum / 100
	if math.Abs(mean) > 0.5 { // generous tolerance for 100 samples
		t.Errorf("mean = %f, expected near 0", mean)
	}
}
// TestSaveSafetensors performs a save/load round-trip of two small tensors
// and verifies the file exists, is non-empty, and the values survive intact.
func TestSaveSafetensors(t *testing.T) {
	a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
	b := FromValues([]float32{5, 6, 7, 8, 9, 10}, 3, 2)
	Materialize(a, b)
	path := t.TempDir() + "/test.safetensors"
	err := SaveSafetensors(path, map[string]*Array{
		"layer.lora_a": a,
		"layer.lora_b": b,
	})
	if err != nil {
		t.Fatalf("SaveSafetensors failed: %v", err)
	}
	// Verify file exists
	info, err := os.Stat(path)
	if err != nil {
		t.Fatalf("saved file not found: %v", err)
	}
	if info.Size() == 0 {
		t.Error("saved file is empty")
	}
	// Load it back
	loaded := LoadAllSafetensors(path)
	Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"])
	aLoaded := loaded["layer.lora_a"].Floats()
	bLoaded := loaded["layer.lora_b"].Floats()
	expectedA := []float32{1, 2, 3, 4}
	expectedB := []float32{5, 6, 7, 8, 9, 10}
	for i, v := range expectedA {
		if aLoaded[i] != v {
			t.Errorf("loaded A[%d] = %f, want %f", i, aLoaded[i], v)
		}
	}
	for i, v := range expectedB {
		if bLoaded[i] != v {
			t.Errorf("loaded B[%d] = %f, want %f", i, bLoaded[i], v)
		}
	}
}
// TestLoRAAdapter_Save saves a one-layer adapter and checks the expected
// ".lora_a"/".lora_b" keys appear in the written safetensors file.
func TestLoRAAdapter_Save(t *testing.T) {
	w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	adapter := &LoRAAdapter{
		Layers: map[string]*LoRALinear{
			"model.layers.0.self_attn.q_proj": NewLoRALinear(base, 4, 8.0),
		},
		Config: DefaultLoRAConfig(),
	}
	path := t.TempDir() + "/adapter.safetensors"
	err := adapter.Save(path)
	if err != nil {
		t.Fatalf("Adapter.Save failed: %v", err)
	}
	// Load and verify
	loaded := LoadAllSafetensors(path)
	aKey := "model.layers.0.self_attn.q_proj.lora_a"
	bKey := "model.layers.0.self_attn.q_proj.lora_b"
	if _, ok := loaded[aKey]; !ok {
		t.Errorf("missing key %s in saved adapter", aKey)
	}
	if _, ok := loaded[bKey]; !ok {
		t.Errorf("missing key %s in saved adapter", bKey)
	}
}
// TestDefaultLoRAConfig pins the documented defaults: rank 8, alpha 16, and
// two target keys (q_proj, v_proj).
func TestDefaultLoRAConfig(t *testing.T) {
	cfg := DefaultLoRAConfig()
	if cfg.Rank != 8 {
		t.Errorf("Rank = %d, want 8", cfg.Rank)
	}
	if cfg.Alpha != 16 {
		t.Errorf("Alpha = %f, want 16", cfg.Alpha)
	}
	if len(cfg.TargetKeys) != 2 {
		t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", cfg.TargetKeys)
	}
}

View file

@ -1,115 +0,0 @@
//go:build darwin && arm64
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
//
// Build mlx-c before use:
//
// cd pkg/mlx && go generate ./...
//
// Build (MLX is auto-enabled on darwin/arm64):
//
// go build -o core .
package mlx
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
/*
#cgo CXXFLAGS: -std=c++17
#cgo CFLAGS: -mmacosx-version-min=26.0
#cgo CPPFLAGS: -I${SRCDIR}/dist/include
#cgo LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib
#include <stdio.h>
#include <stdlib.h>
#include "mlx/c/mlx.h"
static const char *last_mlx_error = NULL;
static void mlx_go_error_handler(const char *msg, void *data) {
fprintf(stderr, "MLX ERROR: %s\n", msg);
last_mlx_error = msg;
}
static void set_error_handler() {
mlx_set_error_handler(&mlx_go_error_handler, NULL, NULL);
}
static const char* get_last_error() {
return last_mlx_error;
}
*/
import "C"
import (
"log/slog"
"sync"
)
// initOnce guards one-time MLX initialisation across goroutines.
var initOnce sync.Once

// Init sets up the MLX error handler. Called automatically on first use.
// It is safe to call from multiple goroutines and runs at most once.
func Init() {
	initOnce.Do(func() {
		C.set_error_handler()
		slog.Debug("mlx: initialized with Metal backend")
	})
}
// checkError logs the last MLX error if any occurred.
// NOTE(review): last_mlx_error retains the handler's msg pointer after the
// callback returns — confirm mlx-c guarantees that string stays alive.
func checkError() {
	if msg := C.get_last_error(); msg != nil {
		slog.Error("mlx", "error", C.GoString(msg))
	}
}
// Materialize synchronously evaluates arrays, computing their values on the GPU.
// This is the MLX equivalent of forcing lazy computation to complete.
// Nil or invalidated entries are silently skipped.
func Materialize(outputs ...*Array) {
	doMaterialize(outputs, false)
}
// MaterializeAsync queues arrays for asynchronous GPU evaluation.
// It returns immediately; the computation completes in the background.
func MaterializeAsync(outputs ...*Array) {
	doMaterialize(outputs, true)
}
// doMaterialize evaluates the given arrays, either synchronously (mlx_eval)
// or asynchronously (mlx_async_eval). Nil and invalid arrays are skipped so
// callers can pass optional outputs without guarding.
func doMaterialize(outputs []*Array, async bool) {
	Init()
	vec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vec)
	for _, o := range outputs {
		if o == nil || !o.Valid() {
			continue
		}
		C.mlx_vector_array_append_value(vec, o.ctx)
	}
	if async {
		C.mlx_async_eval(vec)
		return
	}
	C.mlx_eval(vec)
}
// Collect gathers all valid arrays from a variadic list for batch Materialize.
// Nil entries and invalidated arrays are dropped; the result is nil when
// nothing survives the filter.
func Collect(arrays ...*Array) []*Array {
	var kept []*Array
	for _, arr := range arrays {
		if arr == nil || !arr.Valid() {
			continue
		}
		kept = append(kept, arr)
	}
	return kept
}
// MetalAvailable reports whether Metal GPU is available.
// Init runs first so the error handler is installed before the query.
func MetalAvailable() bool {
	Init()
	var available C.bool
	C.mlx_metal_is_available(&available)
	return bool(available)
}

View file

@ -1,10 +0,0 @@
//go:build !(darwin && arm64)
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
// This stub file is used on non-darwin/non-arm64 platforms or when the
// mlx build tag is not set. All operations report MLX as unavailable.
package mlx
// MetalAvailable reports whether the Metal GPU backend can be used.
// MLX only exists on darwin/arm64, so this stub unconditionally reports false.
func MetalAvailable() bool {
	return false
}

View file

@ -1,448 +0,0 @@
//go:build darwin && arm64
// Package model provides transformer model architectures for MLX inference.
package model
import (
	"encoding/json"
	"fmt"
	"log/slog"
	"math"
	"os"
	"path/filepath"
	"sync"

	"forge.lthn.ai/core/go-ai/mlx"
	"forge.lthn.ai/core/go-ai/mlx/cache"
	"forge.lthn.ai/core/go-ai/mlx/tokenizer"
)
// TextConfig holds Gemma 3 text model configuration.
//
// Fields are decoded from config.json (possibly nested under "text_config"
// for multimodal checkpoints). Quantization and Scale are excluded from JSON
// and filled in by parseConfig.
type TextConfig struct {
	HiddenSize            int32               `json:"hidden_size"`
	NumHiddenLayers       int32               `json:"num_hidden_layers"`
	IntermediateSize      int32               `json:"intermediate_size"`
	NumAttentionHeads     int32               `json:"num_attention_heads"`
	NumKeyValueHeads      int32               `json:"num_key_value_heads"`
	HeadDim               int32               `json:"head_dim"`
	VocabSize             int32               `json:"vocab_size"`
	RMSNormEps            float32             `json:"rms_norm_eps"`
	RopeTheta             float32             `json:"rope_theta"`
	RopeLocalBaseFreq     float32             `json:"rope_local_base_freq"`
	MaxPositionEmbeddings int32               `json:"max_position_embeddings"`
	SlidingWindow         int32               `json:"sliding_window"`
	SlidingWindowPattern  int32               `json:"sliding_window_pattern"`
	Quantization          *QuantizationConfig `json:"-"` // Parsed separately from top-level
	Scale                 float32             `json:"-"` // Computed: 1/sqrt(head_dim)
}
// GemmaModel is the Gemma 3 text model.
type GemmaModel struct {
	EmbedTokens *mlx.Embedding
	Layers      []*DecoderLayer
	Norm        *mlx.RMSNormModule
	Output      *mlx.Linear // Tied to EmbedTokens (unless a separate lm_head exists in the checkpoint)
	// Precomputed (1 + weight) for Gemma-style RMSNorm
	NormScaled *mlx.Array
	Tok        *tokenizer.Tokenizer
	Cfg        *TextConfig
}
// DecoderLayer is a single transformer block: attention and MLP, each wrapped
// in pre- and post-normalisation.
type DecoderLayer struct {
	InputNorm    *mlx.RMSNormModule
	Attention    *Attention
	PostAttnNorm *mlx.RMSNormModule
	PreFFNorm    *mlx.RMSNormModule
	MLP          *MLP
	PostFFNorm   *mlx.RMSNormModule
	// Precomputed scaled weights (1 + weight), filled by precomputeScaledWeights
	InputNormScaled    *mlx.Array
	PostAttnNormScaled *mlx.Array
	PreFFNormScaled    *mlx.Array
	PostFFNormScaled   *mlx.Array
	IsSliding          bool  // true when this layer uses sliding-window attention
	LayerIdx           int32 // position within the layer stack
}
// Attention implements Gemma 3 attention with Q/K normalization.
type Attention struct {
	QProj *mlx.Linear
	KProj *mlx.Linear
	VProj *mlx.Linear
	OProj *mlx.Linear
	QNorm *mlx.RMSNormModule
	KNorm *mlx.RMSNormModule
	// Precomputed (1 + weight) forms of QNorm/KNorm
	QNormScaled *mlx.Array
	KNormScaled *mlx.Array
}
// MLP is the feed-forward network (gated projection pair plus down-projection).
type MLP struct {
	GateProj *mlx.Linear
	UpProj   *mlx.Linear
	DownProj *mlx.Linear
}
// compiledGELU caches the shapeless-compiled GELU kernel so compilation
// happens at most once per process.
var (
	compiledGELU     *mlx.CompiledFunc
	compiledGELUOnce sync.Once
)

// getCompiledGELU returns the lazily compiled GELU function.
//
// Fix: the previous lazy nil-check was not safe for concurrent callers — two
// goroutines could both observe nil and race on the write while compiling
// twice. sync.Once makes initialisation race-free.
func getCompiledGELU() *mlx.CompiledFunc {
	compiledGELUOnce.Do(func() {
		compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
			return []*mlx.Array{geluApprox(inputs[0])}
		}, true)
	})
	return compiledGELU
}
// geluApprox computes GELU via the tanh approximation:
//
//	0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func geluApprox(x *mlx.Array) *mlx.Array {
	const (
		sqrt2OverPi = 0.7978845608028654
		coeff       = 0.044715
	)
	cubed := mlx.Mul(mlx.Mul(x, x), x)
	arg := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(cubed, coeff)), sqrt2OverPi)
	gate := mlx.AddScalar(mlx.Tanh(arg), 1.0)
	return mlx.Mul(mlx.MulScalar(x, 0.5), gate)
}
// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
//
// Multimodal checkpoints wrap the text settings under "text_config"; plain
// text checkpoints keep them at the top level. Quantization is always read
// from the top level. Missing values receive Gemma 3 defaults, and Scale is
// precomputed as 1/sqrt(head_dim) when head_dim is present (otherwise it is
// inferred later from the weight shapes in LoadGemma3).
func parseConfig(data []byte) (*TextConfig, error) {
	// Try parsing text_config from multimodal wrapper
	var wrapper struct {
		TextConfig   TextConfig          `json:"text_config"`
		ModelType    string              `json:"model_type"`
		Quantization *QuantizationConfig `json:"quantization"`
	}
	if err := json.Unmarshal(data, &wrapper); err != nil {
		return nil, err
	}
	cfg := wrapper.TextConfig
	// If text_config was empty, try top-level
	// NOTE(review): NumHiddenLayers == 0 is used as the emptiness sentinel; a
	// config missing the layer count entirely still parses without error.
	if cfg.NumHiddenLayers == 0 {
		if err := json.Unmarshal(data, &cfg); err != nil {
			return nil, err
		}
	}
	// Quantization is always top-level
	cfg.Quantization = wrapper.Quantization
	// Compute scale (head_dim may be inferred later from weights if not in config)
	if cfg.HeadDim > 0 {
		cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	}
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}
	if cfg.RopeLocalBaseFreq == 0 {
		cfg.RopeLocalBaseFreq = 10000
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}
	if cfg.SlidingWindowPattern == 0 {
		cfg.SlidingWindowPattern = 6
	}
	if cfg.VocabSize == 0 {
		cfg.VocabSize = 262208 // Gemma 3 default
	}
	return &cfg, nil
}
// LoadGemma3 loads a Gemma 3 text model from a directory.
//
// The directory must contain config.json, tokenizer.json, and one or more
// *.safetensors weight shards. Quantized checkpoints are detected via the
// top-level "quantization" config entry and loaded through quantized layers.
// Returns an error when config, tokenizer, or weight files are missing.
func LoadGemma3(modelPath string) (*GemmaModel, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: load config: %w", err)
	}
	cfg, err := parseConfig(data)
	if err != nil {
		return nil, fmt.Errorf("gemma3: parse config: %w", err)
	}
	// Load tokenizer
	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: load tokenizer: %w", err)
	}
	// Load weights from all safetensors files
	weights := make(map[string]*mlx.Array)
	matches, err := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: glob safetensors: %w", err)
	}
	if len(matches) == 0 {
		// Fix: the original silently continued with an empty weight map,
		// producing nil-weight layers and a confusing crash much later.
		return nil, fmt.Errorf("gemma3: no *.safetensors files in %s", modelPath)
	}
	for _, path := range matches {
		for name, arr := range mlx.LoadSafetensors(path) {
			weights[name] = arr
		}
	}
	// Helper to resolve weight with language_model. prefix fallback
	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
	// Infer head_dim from q_proj weight shape when not in config.
	// Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads.
	if cfg.HeadDim == 0 {
		qWeight := w("model.layers.0.self_attn.q_proj.weight")
		if qWeight != nil {
			qShape := qWeight.Shape()
			if len(qShape) > 0 {
				cfg.HeadDim = qShape[0] / cfg.NumAttentionHeads
				cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
				slog.Info("mlx: inferred head_dim from q_proj weight", "head_dim", cfg.HeadDim)
			}
		}
	}
	// Helper to create linear layer (quantized or dense)
	q := cfg.Quantization
	if q != nil {
		slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
	}
	linear := func(prefix string) *mlx.Linear {
		weight := w(prefix + ".weight")
		scales := w(prefix + ".scales")
		biases := w(prefix + ".biases")
		if scales != nil && q != nil {
			return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
		}
		return mlx.NewLinear(weight, nil)
	}
	// Create embedding (quantized or dense)
	embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
	if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
		embed.Scales = embedScales
		embed.Biases = w("model.embed_tokens.biases")
		embed.GroupSize = q.GroupSize
		embed.Bits = q.Bits
	}
	m := &GemmaModel{
		EmbedTokens: embed,
		Layers:      make([]*DecoderLayer, cfg.NumHiddenLayers),
		Norm:        &mlx.RMSNormModule{Weight: w("model.norm.weight")},
		Tok:         tok,
		Cfg:         cfg,
	}
	// Initialize layers
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)
		m.Layers[i] = &DecoderLayer{
			InputNorm:    &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
			PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
			PreFFNorm:    &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
			PostFFNorm:   &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
			Attention: &Attention{
				QProj: linear(prefix + ".self_attn.q_proj"),
				KProj: linear(prefix + ".self_attn.k_proj"),
				VProj: linear(prefix + ".self_attn.v_proj"),
				OProj: linear(prefix + ".self_attn.o_proj"),
				QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
				KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
			},
			MLP: &MLP{
				GateProj: linear(prefix + ".mlp.gate_proj"),
				UpProj:   linear(prefix + ".mlp.up_proj"),
				DownProj: linear(prefix + ".mlp.down_proj"),
			},
			LayerIdx:  i,
			IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
		}
	}
	// Output head — check for separate lm_head first, else tie to embeddings
	lmHeadWeight := w("lm_head.weight")
	if lmHeadWeight != nil {
		lmHeadScales := w("lm_head.scales")
		if lmHeadScales != nil && q != nil {
			m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
		} else {
			m.Output = mlx.NewLinear(lmHeadWeight, nil)
		}
	} else {
		// Tied embeddings — reuse embed_tokens weights (with quantization if present)
		m.Output = m.EmbedTokens.AsLinear()
	}
	// Materialize all weights
	var allArrays []*mlx.Array
	for _, a := range weights {
		allArrays = append(allArrays, a)
	}
	mlx.Materialize(allArrays...)
	// Precompute (1 + weight) for Gemma-style RMSNorm
	precomputeScaledWeights(m)
	return m, nil
}
// precomputeScaledWeights caches (1 + weight) for every Gemma-style
// RMSNorm in the model, then materializes all of them in one batch so
// forward passes reuse the scaled tensors instead of re-deriving them.
func precomputeScaledWeights(m *GemmaModel) {
	addOne := func(w *mlx.Array) *mlx.Array { return mlx.AddScalar(w, 1.0) }
	m.NormScaled = addOne(m.Norm.Weight)
	scaled := []*mlx.Array{m.NormScaled}
	for _, l := range m.Layers {
		l.InputNormScaled = addOne(l.InputNorm.Weight)
		l.PostAttnNormScaled = addOne(l.PostAttnNorm.Weight)
		l.PreFFNormScaled = addOne(l.PreFFNorm.Weight)
		l.PostFFNormScaled = addOne(l.PostFFNorm.Weight)
		l.Attention.QNormScaled = addOne(l.Attention.QNorm.Weight)
		l.Attention.KNormScaled = addOne(l.Attention.KNorm.Weight)
		scaled = append(scaled,
			l.InputNormScaled, l.PostAttnNormScaled,
			l.PreFFNormScaled, l.PostFFNormScaled,
			l.Attention.QNormScaled, l.Attention.KNormScaled)
	}
	mlx.Materialize(scaled...)
}
// isLayerSliding reports whether the given layer uses sliding-window
// attention. With a pattern of N, every N-th layer (counting from 1)
// is a global-attention layer and all others are sliding. A
// non-positive pattern disables sliding entirely.
func isLayerSliding(layerIdx, pattern int32) bool {
	if pattern <= 0 {
		return false
	}
	// Layers are counted 1-based for the pattern: the N-th, 2N-th, …
	// layers are global; everything in between slides.
	isGlobal := (layerIdx+1)%pattern == 0
	return !isGlobal
}
// Forward runs the text model forward pass.
// tokens is expected to be [B, L] (batch, sequence) token IDs; one KV
// cache per layer is consumed in lockstep with m.Layers. Returns the
// output-head logits after the final RMSNorm.
func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	shape := tokens.Shape()
	B, L := shape[0], shape[1]
	h := m.EmbedTokens.Forward(tokens)
	// Gemma scales embeddings by sqrt(hidden_size) before the first layer.
	h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
	for i, layer := range m.Layers {
		h = layer.forward(h, caches[i], B, L, m.Cfg)
	}
	// Final norm uses the precomputed (1 + weight) scale from
	// precomputeScaledWeights.
	return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
}
// forward applies one Gemma-3 transformer block with "sandwich"
// normalization: the attention and MLP outputs are each RMS-normalized
// again before being added back to the residual stream.
func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
	// Pre-attention norm -> attention -> post-attention norm -> residual add.
	normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
	attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg)
	attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
	h := mlx.Add(x, attnOut)
	// Pre-FFN norm -> MLP -> post-FFN norm -> residual add.
	normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
	mlpOut := l.MLP.forward(normed)
	mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
	return mlx.Add(h, mlpOut)
}
// forward computes grouped-query self-attention for one Gemma layer.
// Sliding-window layers use the local RoPE base frequency; global
// layers use the standard one. The KV cache supplies the position
// offset for RoPE and accumulates keys/values across steps.
func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)
	// Reshape to [B, num_heads, L, head_dim] — the strided view both
	// splits the packed head dimension and swaps the L/head axes
	// without a copy.
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	// Q/K normalization (uses the precomputed 1+weight scales).
	q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
	k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
	// RoPE with appropriate theta; the cache offset positions new
	// tokens after previously processed ones.
	ropeTheta := cfg.RopeTheta
	if isSliding {
		ropeTheta = cfg.RopeLocalBaseFreq
	}
	q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
	// Update cache
	k, v = c.Update(k, v, int(L))
	// GQA: repeat K/V heads to match the query head count.
	repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
	if repeatFactor > 1 {
		k = mlx.RepeatKV(k, repeatFactor)
		v = mlx.RepeatKV(v, repeatFactor)
	}
	// Scaled dot-product attention. NOTE(review): the L > 1 flag
	// presumably enables causal masking during prefill — confirm
	// against mlx.ScaledDotProductAttention.
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
// forward applies the gated feed-forward network:
// down(gelu(gate(x)) * up(x)), using the pre-compiled GELU kernel.
func (m *MLP) forward(x *mlx.Array) *mlx.Array {
	activated := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
	expanded := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.Mul(activated, expanded))
}
// NewCache creates per-layer caches for generation: a rotating cache
// bounded by the sliding window for sliding layers, an unbounded KV
// cache for global layers.
func (m *GemmaModel) NewCache() []cache.Cache {
	out := make([]cache.Cache, 0, len(m.Layers))
	for _, layer := range m.Layers {
		if layer.IsSliding {
			out = append(out, cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow)))
		} else {
			out = append(out, cache.NewKVCache())
		}
	}
	return out
}
// NumLayers returns the number of transformer layers.
func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the model's tokenizer.
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *GemmaModel) ModelType() string { return "gemma3" }
// ApplyLoRA wraps target projection layers with LoRA adapters.
// Every attention projection named in cfg.TargetKeys is wrapped in a
// LoRALinear; the returned adapter indexes each wrapped layer by its
// full weight path (e.g. "model.layers.0.self_attn.q_proj").
func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
	adapter := &mlx.LoRAAdapter{
		Layers: make(map[string]*mlx.LoRALinear),
		Config: cfg,
	}
	for i, layer := range m.Layers {
		prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
		// Dispatch table from target key to the projection it names.
		projections := map[string]*mlx.Linear{
			"q_proj": layer.Attention.QProj,
			"k_proj": layer.Attention.KProj,
			"v_proj": layer.Attention.VProj,
			"o_proj": layer.Attention.OProj,
		}
		for _, target := range cfg.TargetKeys {
			proj := projections[target]
			if proj == nil {
				continue // unknown target key — skip, matching prior behavior
			}
			lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
			proj.LoRA = lora
			adapter.Layers[prefix+"."+target] = lora
		}
	}
	return adapter
}

View file

@ -1,78 +0,0 @@
//go:build darwin && arm64
// Package model provides transformer model architectures for MLX inference.
package model
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"forge.lthn.ai/core/go-ai/mlx"
"forge.lthn.ai/core/go-ai/mlx/cache"
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
)
// Model is the common interface for all transformer model architectures.
type Model interface {
// Forward runs the model forward pass on token IDs with KV caches.
Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
// NewCache creates per-layer KV caches for generation.
NewCache() []cache.Cache
// NumLayers returns the number of transformer layers.
NumLayers() int
// Tokenizer returns the model's tokenizer.
Tokenizer() *tokenizer.Tokenizer
// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
ModelType() string
// ApplyLoRA wraps target projection layers with LoRA adapters for training.
// Returns the adapter which holds references to all LoRA layers.
ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter
}
// QuantizationConfig holds quantization parameters from config.json.
// The values are passed through to the affine quantized kernels
// (QuantizedMatmul / Dequantize).
type QuantizationConfig struct {
	// GroupSize is the number of consecutive weights sharing one
	// scale/bias pair.
	GroupSize int `json:"group_size"`
	// Bits is the bit width of each quantized weight.
	Bits int `json:"bits"`
}
// resolveWeight looks up a weight by name, also trying the
// "language_model." prefixed variant. Returns nil when neither key
// is present.
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
	for _, key := range []string{name, "language_model." + name} {
		if arr, ok := weights[key]; ok {
			return arr
		}
	}
	return nil
}
// LoadModel auto-detects the model architecture from config.json and
// dispatches to the matching loader.
func LoadModel(modelPath string) (Model, error) {
	raw, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("model: load config: %w", err)
	}
	// Only model_type is needed to pick a loader; each loader re-parses
	// the full config itself.
	var probe struct {
		ModelType string `json:"model_type"`
	}
	if err := json.Unmarshal(raw, &probe); err != nil {
		return nil, fmt.Errorf("model: parse model_type: %w", err)
	}
	switch probe.ModelType {
	case "gemma3", "gemma2":
		return LoadGemma3(modelPath)
	case "qwen3":
		return LoadQwen3(modelPath)
	}
	return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType)
}

View file

@ -1,337 +0,0 @@
//go:build darwin && arm64
package model
import (
"encoding/json"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"forge.lthn.ai/core/go-ai/mlx"
"forge.lthn.ai/core/go-ai/mlx/cache"
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
)
// Qwen3Config holds Qwen 3 model configuration, decoded from
// config.json. Fields tagged "-" are derived or attached after
// decoding by parseQwen3Config.
type Qwen3Config struct {
	HiddenSize       int32 `json:"hidden_size"`
	NumHiddenLayers  int32 `json:"num_hidden_layers"`
	IntermediateSize int32 `json:"intermediate_size"`
	NumAttentionHeads int32 `json:"num_attention_heads"`
	NumKeyValueHeads int32 `json:"num_key_value_heads"`
	// HeadDim may be 0 in the config, in which case it is derived as
	// hidden_size / num_attention_heads.
	HeadDim    int32   `json:"head_dim"`
	VocabSize  int32   `json:"vocab_size"`
	RMSNormEps float32 `json:"rms_norm_eps"`
	RopeTheta  float32 `json:"rope_theta"`
	MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
	// Quantization comes from the top-level "quantization" key; nil for
	// dense (unquantized) checkpoints.
	Quantization *QuantizationConfig `json:"-"`
	Scale        float32             `json:"-"` // 1/sqrt(head_dim)
}
// Qwen3Model is the Qwen 3 text model.
type Qwen3Model struct {
EmbedTokens *mlx.Embedding
Layers []*Qwen3DecoderLayer
Norm *mlx.RMSNormModule
Output *mlx.Linear
Tok *tokenizer.Tokenizer
Cfg *Qwen3Config
}
// Qwen3DecoderLayer is a single transformer block.
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
type Qwen3DecoderLayer struct {
InputNorm *mlx.RMSNormModule // Pre-attention norm
PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
Attention *Qwen3Attention
MLP *Qwen3MLP
}
// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
type Qwen3Attention struct {
QProj *mlx.Linear
KProj *mlx.Linear
VProj *mlx.Linear
OProj *mlx.Linear
QNorm *mlx.RMSNormModule
KNorm *mlx.RMSNormModule
}
// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
type Qwen3MLP struct {
GateProj *mlx.Linear
UpProj *mlx.Linear
DownProj *mlx.Linear
}
// parseQwen3Config decodes config.json bytes into a Qwen3Config,
// attaching the top-level quantization block, deriving the attention
// scale, and filling defaults for fields older checkpoints omit.
// Returns an error for malformed JSON or a config without attention
// heads (which would otherwise cause a divide-by-zero panic).
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
	var cfg Qwen3Config
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, err
	}
	// Quantization lives under a top-level "quantization" key that the
	// main struct deliberately skips (json:"-"), so decode it
	// separately. Previously this error was silently ignored; a type
	// mismatch here now surfaces instead of yielding a dense model.
	var wrapper struct {
		Quantization *QuantizationConfig `json:"quantization"`
	}
	if err := json.Unmarshal(data, &wrapper); err != nil {
		return nil, fmt.Errorf("qwen3: parse quantization: %w", err)
	}
	cfg.Quantization = wrapper.Quantization
	// head_dim defaults to hidden_size / num_heads when absent.
	if cfg.HeadDim == 0 {
		if cfg.NumAttentionHeads == 0 {
			return nil, fmt.Errorf("qwen3: config missing num_attention_heads")
		}
		cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
	}
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	// Defaults for fields some checkpoints leave unset.
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}
	if cfg.VocabSize == 0 {
		cfg.VocabSize = 151936
	}
	return &cfg, nil
}
// LoadQwen3 loads a Qwen 3 model from a safetensors directory.
// The directory must contain config.json, tokenizer.json, and at least
// one *.safetensors weight shard. Quantized checkpoints are detected
// via the config's quantization block and routed through quantized
// kernels.
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("qwen3: load config: %w", err)
	}
	cfg, err := parseQwen3Config(data)
	if err != nil {
		return nil, fmt.Errorf("qwen3: parse config: %w", err)
	}
	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
	}
	// Load weights from all safetensors shards. The glob pattern is a
	// constant, valid pattern, so Glob's only error path cannot fire —
	// but an empty match set previously slipped through and caused nil
	// weight lookups downstream; fail fast instead.
	matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
	if len(matches) == 0 {
		return nil, fmt.Errorf("qwen3: no *.safetensors files found in %q", modelPath)
	}
	weights := make(map[string]*mlx.Array)
	for _, path := range matches {
		for name, arr := range mlx.LoadSafetensors(path) {
			weights[name] = arr
		}
	}
	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }
	// Quantization setup
	q := cfg.Quantization
	if q != nil {
		slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
	}
	// linear builds a (possibly quantized) projection from the tensors
	// found under prefix; scales present + quantized config selects the
	// quantized path.
	linear := func(prefix string) *mlx.Linear {
		weight := w(prefix + ".weight")
		scales := w(prefix + ".scales")
		biases := w(prefix + ".biases")
		bias := w(prefix + ".bias")
		if scales != nil && q != nil {
			return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
		}
		return mlx.NewLinear(weight, bias)
	}
	// Embedding (dequantized on lookup when the checkpoint is quantized).
	embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
	if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
		embed.Scales = embedScales
		embed.Biases = w("model.embed_tokens.biases")
		embed.GroupSize = q.GroupSize
		embed.Bits = q.Bits
	}
	m := &Qwen3Model{
		EmbedTokens: embed,
		Layers:      make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
		Norm:        &mlx.RMSNormModule{Weight: w("model.norm.weight")},
		Tok:         tok,
		Cfg:         cfg,
	}
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		p := fmt.Sprintf("model.layers.%d", i)
		m.Layers[i] = &Qwen3DecoderLayer{
			InputNorm:    &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
			PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
			Attention: &Qwen3Attention{
				QProj: linear(p + ".self_attn.q_proj"),
				KProj: linear(p + ".self_attn.k_proj"),
				VProj: linear(p + ".self_attn.v_proj"),
				OProj: linear(p + ".self_attn.o_proj"),
				QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
				KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
			},
			MLP: &Qwen3MLP{
				GateProj: linear(p + ".mlp.gate_proj"),
				UpProj:   linear(p + ".mlp.up_proj"),
				DownProj: linear(p + ".mlp.down_proj"),
			},
		}
	}
	// Output head — Qwen 3 has tie_word_embeddings=false, so lm_head is
	// separate when present; otherwise fall back to tied embeddings.
	lmHeadWeight := w("lm_head.weight")
	if lmHeadWeight != nil {
		lmHeadScales := w("lm_head.scales")
		if lmHeadScales != nil && q != nil {
			m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
		} else {
			m.Output = mlx.NewLinear(lmHeadWeight, nil)
		}
	} else {
		m.Output = m.EmbedTokens.AsLinear()
	}
	// Materialise all weights onto Metal in one batch.
	var allArrays []*mlx.Array
	for _, a := range weights {
		allArrays = append(allArrays, a)
	}
	mlx.Materialize(allArrays...)
	slog.Info("qwen3: model loaded",
		"layers", cfg.NumHiddenLayers,
		"hidden", cfg.HiddenSize,
		"heads", cfg.NumAttentionHeads,
		"kv_heads", cfg.NumKeyValueHeads,
		"head_dim", cfg.HeadDim,
		"vocab", cfg.VocabSize,
	)
	return m, nil
}
// Forward runs the Qwen 3 forward pass.
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
// tokens is expected to be [B, L] token IDs; one KV cache per layer is
// consumed in lockstep with m.Layers. Returns the output-head logits
// after the final RMSNorm.
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	shape := tokens.Shape()
	B, L := shape[0], shape[1]
	h := m.EmbedTokens.Forward(tokens)
	for i, layer := range m.Layers {
		h = layer.forward(h, caches[i], B, L, m.Cfg)
	}
	return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
}
// forward applies one standard pre-norm transformer block:
// h = x + attn(norm(x)), then h + mlp(norm(h)).
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
	// Attention sub-block with residual connection.
	attnIn := l.InputNorm.Forward(x, cfg.RMSNormEps)
	h := mlx.Add(x, l.Attention.forward(attnIn, c, B, L, cfg))
	// Feed-forward sub-block with residual connection.
	mlpIn := l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
	return mlx.Add(h, l.MLP.forward(mlpIn))
}
// forward computes grouped-query self-attention with Q/K RMS
// normalization and RoPE position encoding, reading and extending the
// layer's KV cache.
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)
	// Reshape to [B, num_heads, L, head_dim] — the strided view both
	// splits the packed head dimension and swaps the L/head axes
	// without a copy.
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	// Q/K RMS normalization (Qwen 3 has this)
	q = a.QNorm.Forward(q, cfg.RMSNormEps)
	k = a.KNorm.Forward(k, cfg.RMSNormEps)
	// RoPE — single theta for all layers (no sliding window); the
	// cache offset positions new tokens after previous ones.
	q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
	// Update KV cache
	k, v = c.Update(k, v, int(L))
	// GQA: repeat K/V heads to match Q heads
	repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
	if repeatFactor > 1 {
		k = mlx.RepeatKV(k, repeatFactor)
		v = mlx.RepeatKV(v, repeatFactor)
	}
	// Scaled dot-product attention. NOTE(review): the L > 1 flag
	// presumably enables causal masking during prefill — confirm
	// against mlx.ScaledDotProductAttention.
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
	activated := mlx.SiLU(m.GateProj.Forward(x))
	expanded := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.Mul(activated, expanded))
}
// NewCache creates one unbounded KV cache per layer. Qwen 3 uses
// global attention throughout, so no rotating (sliding-window) caches
// are needed.
func (m *Qwen3Model) NewCache() []cache.Cache {
	out := make([]cache.Cache, 0, len(m.Layers))
	for range m.Layers {
		out = append(out, cache.NewKVCache())
	}
	return out
}
// NumLayers returns the number of transformer layers.
func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the model's tokenizer.
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
// ModelType returns the architecture identifier.
func (m *Qwen3Model) ModelType() string { return "qwen3" }
// ApplyLoRA wraps target projection layers with LoRA adapters.
// Every attention projection named in cfg.TargetKeys is wrapped in a
// LoRALinear; the returned adapter indexes each wrapped layer by its
// full weight path (e.g. "model.layers.0.self_attn.q_proj").
func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
	adapter := &mlx.LoRAAdapter{
		Layers: make(map[string]*mlx.LoRALinear),
		Config: cfg,
	}
	for i, layer := range m.Layers {
		prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
		// Dispatch table from target key to the projection it names.
		projections := map[string]*mlx.Linear{
			"q_proj": layer.Attention.QProj,
			"k_proj": layer.Attention.KProj,
			"v_proj": layer.Attention.VProj,
			"o_proj": layer.Attention.OProj,
		}
		for _, target := range cfg.TargetKeys {
			proj := projections[target]
			if proj == nil {
				continue // unknown target key — skip, matching prior behavior
			}
			lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
			proj.LoRA = lora
			adapter.Layers[prefix+"."+target] = lora
		}
	}
	return adapter
}

115
mlx/nn.go
View file

@ -1,115 +0,0 @@
//go:build darwin && arm64
package mlx
// Linear is a fully-connected layer: y = x @ W.T + bias.
// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
// Set LoRA to inject a low-rank adapter (training only).
type Linear struct {
	Weight *Array `weight:"weight"` // Dense weight stored [out, in], or packed quantized weight.
	Scales *Array `weight:"scales"` // Per-group quantization scales; nil selects the dense path.
	Biases *Array `weight:"biases"` // Per-group quantization biases; nil for dense.
	Bias *Array `weight:"bias"` // Optional additive bias term.
	GroupSize int // Quantization group size (used only when Scales is set).
	Bits int // Quantization bit width (used only when Scales is set).
	LoRA *LoRALinear // Optional LoRA adapter — if set, Forward routes through it
}
// NewLinear creates a dense Linear layer with optional bias.
func NewLinear(weight, bias *Array) *Linear {
return &Linear{Weight: weight, Bias: bias}
}
// NewQuantizedLinear creates a quantized Linear layer.
func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear {
return &Linear{
Weight: weight,
Scales: scales,
Biases: biases,
Bias: bias,
GroupSize: groupSize,
Bits: bits,
}
}
// Forward computes the linear transformation.
// A non-nil LoRA adapter takes precedence and routes through the
// base weights plus the low-rank delta; otherwise the plain (possibly
// quantized) projection is applied.
func (l *Linear) Forward(x *Array) *Array {
	if adapter := l.LoRA; adapter != nil {
		return adapter.Forward(x)
	}
	return l.baseForward(x)
}
// baseForward is the raw linear transformation without LoRA.
// Used internally by LoRALinear to avoid infinite recursion.
func (l *Linear) baseForward(x *Array) *Array {
	var out *Array
	// A non-nil Scales marks a quantized layer: the packed weight is
	// consumed directly by the quantized kernel (transpose=true since
	// weights are stored [out, in]).
	if l.Scales != nil {
		out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
	} else {
		out = Matmul(x, Transpose(l.Weight))
	}
	// Skip the bias unless it is both present and a valid handle.
	if l.Bias != nil && l.Bias.Valid() {
		out = Add(out, l.Bias)
	}
	return out
}
// Embedding is a lookup table for token embeddings.
// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup.
type Embedding struct {
	Weight *Array `weight:"weight"` // Embedding table (packed when quantized).
	Scales *Array `weight:"scales"` // Per-group scales; nil for a dense table.
	Biases *Array `weight:"biases"` // Per-group biases; nil for a dense table.
	GroupSize int // Quantization group size (used only when Scales is set).
	Bits int // Quantization bit width (used only when Scales is set).
}
// Forward looks up embeddings for the given token indices.
// Quantized tables are dequantized in full before the gather.
func (e *Embedding) Forward(indices *Array) *Array {
	table := e.Weight
	if e.Scales != nil {
		table = Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
	}
	return Take(table, indices, 0)
}
// AsLinear returns a Linear layer that shares the embedding weights,
// for tied output heads. Quantization metadata is carried over so the
// head uses the same kernel path as the table.
func (e *Embedding) AsLinear() *Linear {
	tied := &Linear{Weight: e.Weight}
	tied.Scales, tied.Biases = e.Scales, e.Biases
	tied.GroupSize, tied.Bits = e.GroupSize, e.Bits
	return tied
}
// RMSNormModule is an RMS normalization layer wrapping the fused kernel.
type RMSNormModule struct {
Weight *Array `weight:"weight"`
}
// Forward applies RMS normalization.
func (r *RMSNormModule) Forward(x *Array, eps float32) *Array {
return RMSNorm(x, r.Weight, eps)
}
// RepeatKV repeats key/value heads for grouped-query attention.
// Input shape [B, num_kv_heads, L, D] becomes
// [B, num_kv_heads * factor, L, D]; a factor of 1 or less is a no-op.
func RepeatKV(x *Array, factor int32) *Array {
	if factor <= 1 {
		return x
	}
	s := x.Shape()
	batch, heads, seqLen, dim := s[0], s[1], s[2], s[3]
	// Insert a repeat axis after the head axis, broadcast it up to the
	// factor, then fold it back into the head dimension.
	tiled := BroadcastTo(ExpandDims(x, 2), []int32{batch, heads, factor, seqLen, dim})
	return Reshape(tiled, batch, heads*factor, seqLen, dim)
}

View file

@ -1,353 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import "unsafe"
// --- Element-wise arithmetic ---
// Add returns element-wise a + b.
func Add(a, b *Array) *Array {
out := New("ADD", a, b)
C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// AddScalar returns a + scalar (broadcast).
func AddScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return Add(a, scalar)
}
// Mul returns element-wise a * b.
func Mul(a, b *Array) *Array {
out := New("MUL", a, b)
C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// MulScalar returns a * scalar (broadcast).
func MulScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return Mul(a, scalar)
}
// Divide returns element-wise a / b.
func Divide(a, b *Array) *Array {
out := New("DIV", a, b)
C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Subtract returns element-wise a - b.
func Subtract(a, b *Array) *Array {
out := New("SUB", a, b)
C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Negative returns element-wise -a.
func Negative(a *Array) *Array {
out := New("NEG", a)
C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// --- Math functions ---
// Exp returns element-wise exp(a).
func Exp(a *Array) *Array {
out := New("EXP", a)
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Sigmoid returns element-wise 1/(1+exp(-a)).
func Sigmoid(a *Array) *Array {
out := New("SIGMOID", a)
C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// SiLU returns element-wise x * sigmoid(x) (Swish activation).
func SiLU(a *Array) *Array {
return Mul(a, Sigmoid(a))
}
// Tanh returns element-wise tanh(a).
func Tanh(a *Array) *Array {
out := New("TANH", a)
C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Sqrt returns element-wise sqrt(a).
func Sqrt(a *Array) *Array {
out := New("SQRT", a)
C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Rsqrt returns element-wise 1/sqrt(a).
func Rsqrt(a *Array) *Array {
out := New("RSQRT", a)
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Reciprocal returns element-wise 1/a.
func Reciprocal(a *Array) *Array {
out := New("RECIPROCAL", a)
C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Square returns element-wise a^2.
func Square(a *Array) *Array {
out := New("SQUARE", a)
C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx)
return out
}
// Power returns element-wise a^b.
func Power(a, b *Array) *Array {
out := New("POWER", a, b)
C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Maximum returns element-wise max(a, b).
func Maximum(a, b *Array) *Array {
out := New("MAX", a, b)
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Minimum returns element-wise min(a, b).
func Minimum(a, b *Array) *Array {
out := New("MIN", a, b)
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// --- Matrix operations ---
// Matmul returns the matrix product of a and b.
func Matmul(a, b *Array) *Array {
out := New("MATMUL", a, b)
C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// QuantizedMatmul performs quantized matrix multiplication of x
// against the packed weight w using affine per-group scales/biases.
// transpose selects w.T (true for Linear layers storing [out, in]).
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
	out := New("QMATMUL", x, w, scales, biases)
	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
	// Mode string is owned by Go; free it after the C call copies it.
	mode := C.CString("affine")
	defer C.free(unsafe.Pointer(mode))
	C.mlx_quantized_matmul(
		&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
		C._Bool(transpose), gs, b, mode,
		DefaultStream().ctx,
	)
	return out
}
// --- Reductions ---
// Softmax returns softmax along the last axis.
func Softmax(a *Array) *Array {
out := New("SOFTMAX", a)
axis := []C.int{C.int(-1)}
C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
return out
}
// Argmax returns the index of the maximum value along an axis.
func Argmax(a *Array, axis int, keepDims bool) *Array {
out := New("ARGMAX", a)
C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}
// TopK returns the top k values along the last axis.
func TopK(a *Array, k int) *Array {
out := New("TOPK", a)
C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
return out
}
// Sum reduces by summation along the given axis.
func Sum(a *Array, axis int, keepDims bool) *Array {
out := New("SUM", a)
axes := []C.int{C.int(axis)}
C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
return out
}
// Mean reduces by averaging along the given axis.
func Mean(a *Array, axis int, keepDims bool) *Array {
out := New("MEAN", a)
axes := []C.int{C.int(axis)}
C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
return out
}
// --- Shape operations ---
// Reshape changes the shape of an array.
func Reshape(a *Array, shape ...int32) *Array {
out := New("RESHAPE", a)
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
return out
}
// Transpose permutes dimensions. If no axes given, reverses all dims.
func Transpose(a *Array, axes ...int) *Array {
out := New("TRANSPOSE", a)
if len(axes) == 0 {
C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
} else {
cAxes := make([]C.int, len(axes))
for i, ax := range axes {
cAxes[i] = C.int(ax)
}
C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
}
return out
}
// ExpandDims inserts a new axis at the given position.
func ExpandDims(a *Array, axis int) *Array {
out := New("EXPAND_DIMS", a)
C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// Squeeze removes dimensions of size 1.
func Squeeze(a *Array, axes ...int) *Array {
out := New("SQUEEZE", a)
cAxes := make([]C.int, len(axes))
for i, ax := range axes {
cAxes[i] = C.int(ax)
}
C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
return out
}
// Concatenate joins arrays along the given axis.
// The inputs are appended to a temporary mlx vector for the C call and
// also recorded as graph inputs of the result node.
func Concatenate(arrays []*Array, axis int) *Array {
	vector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vector)
	for _, a := range arrays {
		C.mlx_vector_array_append_value(vector, a.ctx)
	}
	// Pass arrays directly as the node's inputs — the previous code
	// built a redundant element-by-element copy of the slice first.
	out := New("CONCAT", arrays...)
	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}
// BroadcastTo broadcasts an array to the given shape.
func BroadcastTo(a *Array, shape []int32) *Array {
out := New("BROADCAST", a)
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
return out
}
// AsType casts an array to a different dtype.
func AsType(a *Array, dtype DType) *Array {
out := New("ASTYPE", a)
C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
return out
}
// AsStrided creates a view with custom strides.
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
out := New("AS_STRIDED", a)
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.int64_t(s)
}
C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
return out
}
// Take gathers elements from a along axis using indices.
func Take(a, indices *Array, axis int) *Array {
out := New("TAKE", a, indices)
C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// Where selects elements from a or b based on condition.
func Where(condition, a, b *Array) *Array {
out := New("WHERE", condition, a, b)
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
return out
}
// Argpartition partially sorts and returns indices for top-k selection.
func Argpartition(a *Array, kth, axis int) *Array {
out := New("ARGPARTITION", a)
C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
return out
}
// Dequantize restores an affine-quantized array to full precision
// using its per-group scales and biases.
func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
	out := New("DEQUANTIZE", w, scales, biases)
	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
	// Mode string is owned by Go; free it after the C call copies it.
	mode := C.CString("affine")
	defer C.free(unsafe.Pointer(mode))
	// Output dtype is left unspecified so the backend chooses its default.
	noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
	return out
}
// PutAlongAxis places values into array at indices along axis.
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
out := New("PUT_ALONG_AXIS", a, indices, values)
// Use scatter approach: src[indices] = values
C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// TakeAlongAxis gathers elements from a along axis using indices.
// Unlike Take, this uses the same number of dimensions for indices and input.
func TakeAlongAxis(a, indices *Array, axis int) *Array {
out := New("TAKE_ALONG_AXIS", a, indices)
C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
return out
}
// LogSumExp computes log(sum(exp(a))) along the given axis.
// Numerically stable reduction for cross-entropy loss.
func LogSumExp(a *Array, axis int, keepDims bool) *Array {
out := New("LOGSUMEXP", a)
C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
return out
}

View file

@ -1,106 +0,0 @@
//go:build darwin && arm64
package mlx
import "math"
// AdamW implements the AdamW optimiser (Adam with decoupled weight decay).
//
// Update rule per parameter:
//
// m = beta1 * m + (1 - beta1) * grad
// v = beta2 * v + (1 - beta2) * grad^2
// m_hat = m / (1 - beta1^t)
// v_hat = v / (1 - beta2^t)
// param = param * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
type AdamW struct {
LR float64 // Learning rate (default 1e-5)
Beta1 float64 // First moment decay (default 0.9)
Beta2 float64 // Second moment decay (default 0.999)
Eps float64 // Numerical stability (default 1e-8)
WeightDecay float64 // Decoupled weight decay (default 0.01)
step int // Number of updates performed
m []*Array // First moment estimates (positional, parallel to params)
v []*Array // Second moment estimates (positional, parallel to params)
}
// NewAdamW creates an AdamW optimiser with the given learning rate and
// default hyperparameters (beta1=0.9, beta2=0.999, eps=1e-8,
// weight_decay=0.01).
func NewAdamW(lr float64) *AdamW {
	opt := &AdamW{LR: lr}
	opt.Beta1, opt.Beta2 = 0.9, 0.999
	opt.Eps = 1e-8
	opt.WeightDecay = 0.01
	return opt
}
// Step performs one optimisation step: updates params using gradients.
// params and grads must be parallel slices of the same length.
// Returns freshly computed parameter arrays; the input params slice is
// NOT modified in place (the previous comment claiming in-place
// replacement was inaccurate) — callers must adopt the returned slice.
func (o *AdamW) Step(params []*Array, grads []*Array) []*Array {
	o.step++
	// Bias correction factors 1 - beta^t, de-biasing the zero-initialised
	// moment estimates.
	bc1 := 1.0 - math.Pow(o.Beta1, float64(o.step))
	bc2 := 1.0 - math.Pow(o.Beta2, float64(o.step))
	updated := make([]*Array, len(params))
	// Grow moment slices if needed (first call or param count increased)
	for len(o.m) < len(params) {
		o.m = append(o.m, nil)
		o.v = append(o.v, nil)
	}
	for i, param := range params {
		grad := grads[i]
		// Initialise moments on first use
		if o.m[i] == nil {
			shape := param.Shape()
			o.m[i] = Zeros(shape, param.Dtype())
			o.v[i] = Zeros(shape, param.Dtype())
		}
		// m = beta1 * m + (1 - beta1) * grad
		m := Add(
			MulScalar(o.m[i], float32(o.Beta1)),
			MulScalar(grad, float32(1.0-o.Beta1)),
		)
		// v = beta2 * v + (1 - beta2) * grad^2
		v := Add(
			MulScalar(o.v[i], float32(o.Beta2)),
			MulScalar(Square(grad), float32(1.0-o.Beta2)),
		)
		// Bias-corrected estimates
		mHat := MulScalar(m, float32(1.0/bc1))
		vHat := MulScalar(v, float32(1.0/bc2))
		// Weight decay: param = param * (1 - lr * weight_decay)
		// (decoupled — applied to the parameter itself, not the gradient)
		decayed := MulScalar(param, float32(1.0-o.LR*o.WeightDecay))
		// Update: param = decayed - lr * m_hat / (sqrt(v_hat) + eps)
		denom := AddScalar(Sqrt(vHat), float32(o.Eps))
		step := MulScalar(Divide(mHat, denom), float32(o.LR))
		newParam := Subtract(decayed, step)
		// Store updated moments
		o.m[i] = m
		o.v[i] = v
		updated[i] = newParam
	}
	return updated
}
// Reset clears all optimiser state: the step counter and both moment slices.
// The next Step call behaves as if the optimiser were freshly constructed.
func (o *AdamW) Reset() {
	o.step = 0
	o.m, o.v = nil, nil
}

View file

@ -1,175 +0,0 @@
//go:build darwin && arm64
package mlx
import (
"math"
"testing"
)
// TestAdamW_BasicStep minimises f(x) = x^2 starting from x=10 and checks the
// optimiser drives x near zero within 300 steps. Gradients come from the
// autodiff machinery (ValueAndGrad), not a hand-written derivative.
func TestAdamW_BasicStep(t *testing.T) {
	// Simple test: minimise f(x) = x^2, starting at x=10
	x := FromValue(float32(10.0))
	Materialize(x)
	opt := NewAdamW(0.1)
	for i := 0; i < 300; i++ {
		// Gradient of x^2 is 2x
		lossFn := func(inputs []*Array) []*Array {
			p := inputs[0]
			return []*Array{Mul(p, p)}
		}
		grad := ValueAndGrad(lossFn)
		_, grads, err := grad.Apply(x)
		grad.Free() // release the autodiff closure before error handling
		if err != nil {
			t.Fatalf("step %d: grad failed: %v", i, err)
		}
		updated := opt.Step([]*Array{x}, grads)
		x = updated[0]
		Materialize(x) // evaluate the updated parameter before the next step
	}
	final := x.Float()
	if math.Abs(final) > 0.5 {
		t.Errorf("after 300 steps, x = %f, want near 0", final)
	}
	t.Logf("final x = %f (started at 10.0)", final)
}
// TestAdamW_MultiParam minimises f(x, y) = x^2 + y^2 and checks that both
// parameters, updated together through one optimiser, converge toward zero.
func TestAdamW_MultiParam(t *testing.T) {
	// Minimise f(x, y) = x^2 + y^2
	x := FromValue(float32(5.0))
	y := FromValue(float32(-3.0))
	Materialize(x, y)
	opt := NewAdamW(0.1)
	for i := 0; i < 100; i++ {
		lossFn := func(inputs []*Array) []*Array {
			return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))}
		}
		// Differentiate with respect to both inputs (argument indices 0 and 1).
		grad := ValueAndGrad(lossFn, 0, 1)
		_, grads, err := grad.Apply(x, y)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", i, err)
		}
		updated := opt.Step([]*Array{x, y}, grads)
		x = updated[0]
		y = updated[1]
		Materialize(x, y)
	}
	xFinal := x.Float()
	yFinal := y.Float()
	if math.Abs(xFinal) > 0.1 || math.Abs(yFinal) > 0.1 {
		t.Errorf("x=%f, y=%f, want both near 0", xFinal, yFinal)
	}
	t.Logf("final x=%f, y=%f", xFinal, yFinal)
}
// TestAdamW_WeightDecay verifies that with a zero gradient and an aggressive
// decoupled weight decay, a parameter shrinks toward — but never past — zero.
func TestAdamW_WeightDecay(t *testing.T) {
	param := FromValue(float32(10.0))
	Materialize(param)
	opt := NewAdamW(0.01)
	opt.WeightDecay = 0.5 // aggressive decay
	zeroGrad := FromValue(float32(0.0))
	Materialize(zeroGrad)
	for step := 0; step < 10; step++ {
		param = opt.Step([]*Array{param}, []*Array{zeroGrad})[0]
		Materialize(param)
	}
	got := param.Float()
	switch {
	case got >= 10.0:
		t.Errorf("x = %f, should have decayed from 10.0", got)
	case got <= 0:
		t.Errorf("x = %f, decayed too much", got)
	}
	t.Logf("after 10 steps with weight_decay=0.5: x = %f (started at 10.0)", got)
}
// TestAdamW_Reset checks that Reset zeroes the step counter and drops the
// accumulated moment estimates after a step has been taken.
func TestAdamW_Reset(t *testing.T) {
	opt := NewAdamW(0.01)
	param := FromValue(float32(5.0))
	g := FromValue(float32(1.0))
	Materialize(param, g)
	opt.Step([]*Array{param}, []*Array{g})
	if got := opt.step; got != 1 {
		t.Errorf("step = %d, want 1", got)
	}
	opt.Reset()
	if got := opt.step; got != 0 {
		t.Errorf("after reset, step = %d, want 0", got)
	}
	if opt.m != nil {
		t.Error("after reset, moments should be nil")
	}
}
// TestAdamW_WithLoRA is an end-to-end check: a LoRA adapter over a linear
// layer is trained with AdamW against a fixed random target, and the MSE
// loss must decrease over 50 steps. Only the adapter matrices A and B are
// optimised; the base weight w is never updated.
func TestAdamW_WithLoRA(t *testing.T) {
	// End-to-end: create LoRA layer, compute gradients, update with AdamW
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)
	lora := NewLoRALinear(base, 4, 8.0)
	opt := NewAdamW(0.001)
	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	target := RandomNormal(0, 1, []int32{1, 2, 4}, DTypeFloat32)
	Materialize(x, target)
	var initialLoss, finalLoss float64
	for step := 0; step < 50; step++ {
		lossFn := func(inputs []*Array) []*Array {
			// Rebind the adapter matrices to the differentiated inputs so
			// gradients flow through Forward to A and B.
			lora.A = inputs[0]
			lora.B = inputs[1]
			pred := lora.Forward(x)
			return []*Array{MSELoss(pred, target)}
		}
		grad := ValueAndGrad(lossFn, 0, 1)
		values, grads, err := grad.Apply(lora.A, lora.B)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", step, err)
		}
		Materialize(append(values, grads...)...)
		loss := values[0].Float()
		if step == 0 {
			initialLoss = loss
		}
		if step == 49 {
			finalLoss = loss
		}
		updated := opt.Step([]*Array{lora.A, lora.B}, grads)
		lora.A = updated[0]
		lora.B = updated[1]
		Materialize(lora.A, lora.B)
	}
	t.Logf("loss: %.6f -> %.6f", initialLoss, finalLoss)
	if finalLoss >= initialLoss {
		t.Errorf("loss did not decrease: %f -> %f", initialLoss, finalLoss)
	}
}

View file

@ -1,46 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
// RandomCategorical samples token indices from the categorical distribution
// defined by logprobs along the last axis, using the library's default RNG.
func RandomCategorical(logprobs *Array) *Array {
	result := New("RANDOM_CATEGORICAL", logprobs)
	rngKey := C.mlx_array_new()
	defer C.mlx_array_free(rngKey)
	C.mlx_random_categorical(
		&result.ctx,
		logprobs.ctx,
		C.int(-1), // sample along the last axis
		rngKey,    // null key = use default RNG
		DefaultStream().ctx,
	)
	return result
}
// RandomUniform generates uniform random values in [low, high) with the given
// shape and dtype. An empty shape requests a scalar (0-dimensional) array.
func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
	out := New("RANDOM_UNIFORM")
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	// Guard the empty-shape case: indexing &cShape[0] on a zero-length slice
	// panics, so pass a nil pointer with length 0 instead.
	var shapePtr *C.int
	if len(cShape) > 0 {
		shapePtr = &cShape[0]
	}
	lo := FromValue(low)
	hi := FromValue(high)
	key := C.mlx_array_new()
	defer C.mlx_array_free(key)
	C.mlx_random_uniform(
		&out.ctx,
		lo.ctx, hi.ctx,
		shapePtr, C.size_t(len(cShape)),
		C.mlx_dtype(dtype),
		key, // null key = use default RNG
		DefaultStream().ctx,
	)
	return out
}

View file

@ -1,90 +0,0 @@
//go:build darwin && arm64
// Package sample provides composable token sampling strategies.
package sample
import (
"math"
"forge.lthn.ai/core/go-ai/mlx"
)
// Sampler transforms logits into a sampled token index. Implementations
// either reshape the logit distribution (filters such as top-k) or terminate
// a chain by producing a concrete token index.
type Sampler interface {
	Sample(logits *mlx.Array) *mlx.Array
}
// New builds a composable sampler chain from the given parameters.
// A temperature of zero selects greedy (argmax) decoding. Otherwise filters
// run in the order TopP -> MinP -> TopK -> Temperature, followed by a final
// categorical draw.
func New(temp, topP, minP float32, topK int) Sampler {
	if temp == 0 {
		return greedy{}
	}
	pipeline := make(chain, 0, 4)
	if topP > 0 && topP < 1 {
		pipeline = append(pipeline, TopP(topP))
	}
	if minP > 0 {
		pipeline = append(pipeline, MinPSampler(minP))
	}
	if topK > 0 {
		pipeline = append(pipeline, TopKSampler(topK))
	}
	return append(pipeline, Temperature(temp))
}
// chain runs each sampler's logit transform in sequence, then draws one
// token from the resulting log-probability distribution.
type chain []Sampler

func (c chain) Sample(logits *mlx.Array) *mlx.Array {
	transformed := logits
	for _, stage := range c {
		transformed = stage.Sample(transformed)
	}
	// Final categorical sample from log-probabilities.
	return mlx.RandomCategorical(transformed)
}
// greedy returns the argmax token; selected by New when temperature is zero.
type greedy struct{}

// Sample returns the index of the largest logit along the last axis.
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
	return mlx.Argmax(logits, -1, false)
}
// Temperature scales logits by 1/temp. Values below 1 sharpen the
// distribution; values above 1 flatten it.
type Temperature float32

// Sample divides every logit by the temperature.
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
	return mlx.MulScalar(logits, 1.0/float32(t))
}
// TopKSampler keeps only the k highest logits, masking every other position
// to negative infinity so it can never be sampled.
type TopKSampler int

func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
	// Argpartition over negated logits places the k largest entries first.
	order := mlx.Argpartition(mlx.Negative(logits), int(k)-1, -1)
	// Positions after index k are outside the top-k set.
	dropped := mlx.SliceAxis(order, -1, int32(k), int32(logits.Dim(-1)))
	negInf := mlx.FromValue(float32(math.Inf(-1)))
	return mlx.PutAlongAxis(logits, dropped, negInf, -1)
}
// TopP implements nucleus sampling (cumulative probability threshold).
type TopP float32

// Sample is currently a pass-through no-op.
// TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly.
// For now, pass through. TopK + Temperature covers most use cases.
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
	return logits
}
// MinPSampler masks tokens below min_p * max_prob.
type MinPSampler float32

// Sample is currently a pass-through no-op — MinP is an optimization over
// TopP. A full implementation requires finding the max probability and
// masking entries below the threshold.
func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array {
	return logits
}

View file

@ -1,63 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
// Slice extracts a sub-array bounded by starts (inclusive) and ends
// (exclusive) per dimension, always with unit stride. starts and ends must
// each have one entry per array dimension.
func Slice(a *Array, starts, ends []int32) *Array {
	result := New("SLICE", a)
	cStarts := make([]C.int, len(starts))
	cEnds := make([]C.int, len(ends))
	cStrides := make([]C.int, len(starts))
	// Single pass fills bounds and the implicit stride of 1 together.
	for i := range starts {
		cStarts[i] = C.int(starts[i])
		cEnds[i] = C.int(ends[i])
		cStrides[i] = 1
	}
	C.mlx_slice(&result.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &cStrides[0], C.size_t(len(cStrides)), DefaultStream().ctx)
	return result
}
// SliceAxis extracts the half-open range [start, end) along a single axis,
// leaving every other dimension at its full extent. A negative axis counts
// from the last dimension.
func SliceAxis(a *Array, axis int, start, end int32) *Array {
	ndim := a.NumDims()
	// Normalise a negative axis index.
	if axis < 0 {
		axis += ndim
	}
	// Default to the full extent everywhere, then narrow the requested axis.
	starts := make([]int32, ndim)
	ends := make([]int32, ndim)
	for d := 0; d < ndim; d++ {
		ends[d] = int32(a.Dim(d))
	}
	starts[axis] = start
	ends[axis] = end
	return Slice(a, starts, ends)
}
// SliceUpdateInplace writes update into the region of a bounded by starts
// (inclusive) and ends (exclusive) with unit stride, returning the result.
// This is the primitive behind KV cache updates.
func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
	result := New("SLICE_UPDATE", a, update)
	cStarts := make([]C.int, len(starts))
	cEnds := make([]C.int, len(ends))
	cStrides := make([]C.int, len(starts))
	// Single pass fills bounds and the implicit stride of 1 together.
	for i := range starts {
		cStarts[i] = C.int(starts[i])
		cEnds[i] = C.int(ends[i])
		cStrides[i] = 1
	}
	C.mlx_slice_update(&result.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &cStrides[0], C.size_t(len(cStrides)), DefaultStream().ctx)
	return result
}

View file

@ -1,79 +0,0 @@
//go:build darwin && arm64
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
import "sync"
// Stream wraps an mlx_stream handle for dispatching operations.
// The zero value is not usable; obtain streams via DefaultStream,
// DefaultGPUStream, or DefaultCPUStream.
type Stream struct {
	ctx C.mlx_stream // underlying mlx-c stream handle
}
var (
	defaultStream     *Stream   // process-wide default GPU stream, created lazily
	defaultStreamOnce sync.Once // guards one-time creation of defaultStream
)

// DefaultStream returns the default GPU stream, creating it (and initialising
// the MLX runtime via Init) on first use. Safe for concurrent callers thanks
// to the sync.Once guard.
func DefaultStream() *Stream {
	defaultStreamOnce.Do(func() {
		Init() // runtime must be initialised before creating streams
		defaultStream = &Stream{ctx: C.mlx_default_gpu_stream_new()}
	})
	return defaultStream
}
// DefaultGPUStream returns a new GPU stream handle. Unlike DefaultStream,
// each call constructs a fresh Stream wrapper; Init is called every time.
func DefaultGPUStream() *Stream {
	Init()
	return &Stream{ctx: C.mlx_default_gpu_stream_new()}
}
// DefaultCPUStream returns a new CPU stream handle; Init is called to ensure
// the MLX runtime is ready.
func DefaultCPUStream() *Stream {
	Init()
	return &Stream{ctx: C.mlx_default_cpu_stream_new()}
}
// Synchronize blocks until all operations queued on s have completed.
func Synchronize(s *Stream) {
	C.mlx_synchronize(s.ctx)
}
// SetMemoryLimit sets the Metal memory limit in bytes and returns the
// previously configured limit.
func SetMemoryLimit(limit uint64) uint64 {
	var previous C.size_t
	C.mlx_set_memory_limit(&previous, C.size_t(limit))
	return uint64(previous)
}
// SetCacheLimit sets the Metal cache limit in bytes and returns the
// previously configured limit.
func SetCacheLimit(limit uint64) uint64 {
	var previous C.size_t
	C.mlx_set_cache_limit(&previous, C.size_t(limit))
	return uint64(previous)
}
// GetActiveMemory returns the current Metal memory usage in bytes.
func GetActiveMemory() uint64 {
	var active C.size_t
	C.mlx_get_active_memory(&active)
	return uint64(active)
}
// GetPeakMemory returns the peak Metal memory usage in bytes.
func GetPeakMemory() uint64 {
	var peak C.size_t
	C.mlx_get_peak_memory(&peak)
	return uint64(peak)
}
// ClearCache releases Metal memory held in the MLX allocator cache.
// Useful between large workloads to reduce resident memory.
func ClearCache() {
	C.mlx_clear_cache()
}

View file

@ -1,324 +0,0 @@
//go:build darwin && arm64
// Package tokenizer provides BPE tokenization for transformer models.
package tokenizer
import (
"encoding/json"
"fmt"
"os"
"strings"
)
// Tokenizer handles text-to-token and token-to-text conversion.
// Two vocabulary styles are supported: SentencePiece ("▁" word-boundary
// markers) and GPT-2 byte-level BPE, detected by the presence of "Ġ" in
// the vocabulary.
type Tokenizer struct {
	vocab    map[string]int32 // token text → ID
	invVocab map[int32]string // ID → token text
	merges   []mergePair      // BPE merge rules in file order
	special  map[string]int32 // special tokens (BOS/EOS/chat markers) → ID
	bosToken int32            // ID prepended by Encode
	eosToken int32            // generation stop token ID

	// GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.)
	isGPT2BPE   bool
	gpt2Decoder map[rune]byte // Unicode char → original byte
	gpt2Encoder map[byte]rune // original byte → Unicode char
}

// mergePair is a single BPE merge rule: the adjacent pair (a, b) with rank
// equal to its position in the tokenizer's merges list.
type mergePair struct {
	a, b string
	rank int
}
// tokenizerJSON mirrors the subset of the HuggingFace tokenizer.json format
// that Load consumes. Vocab and Merges stay as json.RawMessage because their
// concrete shape varies between tokenizer versions and is decoded lazily.
type tokenizerJSON struct {
	Model struct {
		Type         string          `json:"type"`
		Vocab        json.RawMessage `json:"vocab"`
		Merges       json.RawMessage `json:"merges"`
		ByteFallback bool            `json:"byte_fallback"`
	} `json:"model"`
	AddedTokens []struct {
		ID      int32  `json:"id"`
		Content string `json:"content"`
		Special bool   `json:"special"`
	} `json:"added_tokens"`
}
// Load reads a HuggingFace tokenizer.json file and creates a Tokenizer.
//
// It parses the vocabulary, merge rules, and added (special) tokens, then
// detects the vocabulary style and the model family's BOS/EOS conventions.
// The BOS/EOS checks below are ordered generic → specific: later matches
// deliberately override earlier ones, so family-specific stop tokens win
// over a generic <bos>/<eos> pair when both exist.
func Load(path string) (*Tokenizer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
	}
	var tj tokenizerJSON
	if err := json.Unmarshal(data, &tj); err != nil {
		return nil, fmt.Errorf("tokenizer: parse: %w", err)
	}
	t := &Tokenizer{
		vocab:    make(map[string]int32),
		invVocab: make(map[int32]string),
		special:  make(map[string]int32),
	}
	// Parse vocab
	var vocab map[string]int32
	if err := json.Unmarshal(tj.Model.Vocab, &vocab); err != nil {
		return nil, fmt.Errorf("tokenizer: parse vocab: %w", err)
	}
	t.vocab = vocab
	for k, v := range vocab {
		t.invVocab[v] = k
	}
	// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
	if len(tj.Model.Merges) > 0 {
		var stringMerges []string
		if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
			for rank, merge := range stringMerges {
				parts := strings.SplitN(merge, " ", 2)
				if len(parts) == 2 {
					t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
				}
			}
		} else {
			// Fall back to the array-of-pairs form; merges that fail to
			// parse are silently ignored (the merges list stays empty).
			var arrayMerges [][]string
			if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
				for rank, pair := range arrayMerges {
					if len(pair) == 2 {
						t.merges = append(t.merges, mergePair{a: pair[0], b: pair[1], rank: rank})
					}
				}
			}
		}
	}
	// Parse special tokens; added tokens also join the regular vocab maps.
	for _, tok := range tj.AddedTokens {
		if tok.Special {
			t.special[tok.Content] = tok.ID
		}
		t.vocab[tok.Content] = tok.ID
		t.invVocab[tok.ID] = tok.Content
	}
	// Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
	if _, ok := t.vocab["Ġ"]; ok {
		t.isGPT2BPE = true
		t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
	}
	// Set BOS/EOS — detect model family from special tokens
	if id, ok := t.special["<bos>"]; ok {
		t.bosToken = id
	}
	if id, ok := t.special["<eos>"]; ok {
		t.eosToken = id
	}
	// Gemma: <end_of_turn> is the generation stop token
	if id, ok := t.special["<end_of_turn>"]; ok {
		t.eosToken = id
	}
	// Qwen3: <|im_end|> is the generation stop token
	if id, ok := t.special["<|im_end|>"]; ok {
		t.eosToken = id
	}
	// Qwen3 BOS: <|im_start|>
	if id, ok := t.special["<|im_start|>"]; ok {
		t.bosToken = id
	}
	return t, nil
}
// buildGPT2ByteMaps constructs the GPT-2 byte-level BPE byte<->rune maps.
// Every one of the 256 byte values gets a printable Unicode stand-in so the
// vocabulary never contains raw control characters: printable ASCII and most
// of the Latin-1 Supplement map to themselves, while the remaining bytes
// (controls, space, DEL, 127-160, 173) are assigned U+0100, U+0101, ... in
// ascending byte order.
func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
	encoder = make(map[byte]rune, 256)
	decoder = make(map[rune]byte, 256)
	// Self-mapping spans; the int loop variable avoids byte overflow at 255.
	for _, span := range [][2]int{{33, 126}, {161, 172}, {174, 255}} {
		for v := span[0]; v <= span[1]; v++ {
			encoder[byte(v)] = rune(v)
			decoder[rune(v)] = byte(v)
		}
	}
	// Every byte not yet mapped gets the next code point from U+0100 upward.
	next := rune(256)
	for v := 0; v < 256; v++ {
		if _, mapped := encoder[byte(v)]; !mapped {
			encoder[byte(v)] = next
			decoder[next] = byte(v)
			next++
		}
	}
	return decoder, encoder
}
// Encode converts text to token IDs, prepending the BOS token.
// GPT-2 byte-level vocabularies are dispatched to encodeGPT2; otherwise a
// simplified SentencePiece-style greedy scan is used.
func (t *Tokenizer) Encode(text string) []int32 {
	// Dispatch before allocating: encodeGPT2 builds its own slice (including
	// BOS), so allocating one here first would be immediately discarded.
	if t.isGPT2BPE {
		return t.encodeGPT2(text)
	}
	tokens := []int32{t.bosToken}
	// SentencePiece style encoding: match special tokens first, then try the
	// character with a "▁" word-boundary prefix, then the bare character.
	remaining := text
	for remaining != "" {
		found := false
		for tok, id := range t.special {
			if strings.HasPrefix(remaining, tok) {
				tokens = append(tokens, id)
				remaining = remaining[len(tok):]
				found = true
				break
			}
		}
		if !found {
			r := []rune(remaining)
			ch := "▁" + string(r[0])
			if id, ok := t.vocab[ch]; ok {
				tokens = append(tokens, id)
			} else if id, ok := t.vocab[string(r[0])]; ok {
				tokens = append(tokens, id)
			}
			// Characters absent from the vocabulary are silently dropped.
			remaining = string(r[1:])
		}
	}
	return tokens
}
// encodeGPT2 encodes text using GPT-2 byte-level BPE (simplified: per-rune
// vocabulary lookup, no merge application). The BOS token is prepended.
func (t *Tokenizer) encodeGPT2(text string) []int32 {
	tokens := []int32{t.bosToken}
	// Convert text bytes to GPT-2 printable-Unicode representation.
	var encoded strings.Builder
	for _, b := range []byte(text) {
		if r, ok := t.gpt2Encoder[b]; ok {
			encoded.WriteRune(r)
		}
	}
	// Special tokens are stored in t.special in their original form, so they
	// must be byte-encoded the same way as the input to match. Precompute the
	// encoded forms once — this work is loop-invariant and was previously
	// redone on every iteration of the scan below.
	encSpecial := make(map[string]int32, len(t.special))
	for tok, id := range t.special {
		var encTok strings.Builder
		for _, b := range []byte(tok) {
			if r, ok := t.gpt2Encoder[b]; ok {
				encTok.WriteRune(r)
			}
		}
		encSpecial[encTok.String()] = id
	}
	// Scan for special tokens and regular text.
	remaining := encoded.String()
	for remaining != "" {
		found := false
		for encStr, id := range encSpecial {
			if strings.HasPrefix(remaining, encStr) {
				tokens = append(tokens, id)
				remaining = remaining[len(encStr):]
				found = true
				break
			}
		}
		if !found {
			// Character-by-character lookup (simplified BPE); runes absent
			// from the vocabulary are silently dropped.
			r := []rune(remaining)
			if id, ok := t.vocab[string(r[0])]; ok {
				tokens = append(tokens, id)
			}
			remaining = string(r[1:])
		}
	}
	return tokens
}
// Decode converts token IDs back to text. Special tokens are skipped.
// For SentencePiece vocabularies, the "▁" word-boundary marker becomes a
// space and a single leading space is stripped from the full sequence.
func (t *Tokenizer) Decode(tokens []int32) string {
	var sb strings.Builder
	for _, id := range tokens {
		text, ok := t.invVocab[id]
		if !ok {
			continue // unknown IDs produce no output
		}
		// Skip special tokens in decode output.
		if _, isSpecial := t.special[text]; isSpecial {
			continue
		}
		sb.WriteString(text)
	}
	raw := sb.String()
	if t.isGPT2BPE {
		return t.decodeGPT2Bytes(raw)
	}
	// SentencePiece style: ▁ marks word boundaries.
	return strings.TrimPrefix(strings.ReplaceAll(raw, "▁", " "), " ")
}
// DecodeToken converts a single token ID to text for streaming output.
// Unlike Decode, a leading word-boundary space is preserved so that
// token-by-token emission keeps correct spacing between words.
func (t *Tokenizer) DecodeToken(id int32) string {
	piece, ok := t.invVocab[id]
	if !ok {
		return ""
	}
	// Special tokens never appear in decoded output.
	if _, isSpecial := t.special[piece]; isSpecial {
		return ""
	}
	if t.isGPT2BPE {
		return t.decodeGPT2Bytes(piece)
	}
	// SentencePiece: replace ▁ with a space but keep it — it is the word boundary.
	return strings.ReplaceAll(piece, "▁", " ")
}
// decodeGPT2Bytes maps GPT-2 byte-level BPE Unicode characters back to their
// original bytes; runes outside the decoder map pass through as UTF-8.
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
	out := make([]byte, 0, len(s))
	for _, r := range s {
		if b, ok := t.gpt2Decoder[r]; ok {
			out = append(out, b)
			continue
		}
		out = append(out, []byte(string(r))...)
	}
	return string(out)
}
// BOSToken returns the beginning-of-sequence token ID prepended by Encode.
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }

// EOSToken returns the end-of-sequence (generation stop) token ID.
func (t *Tokenizer) EOSToken() int32 { return t.eosToken }
// FormatGemmaPrompt wraps prompt in the Gemma 3 chat template: a complete
// user turn followed by the opening of the model turn.
func FormatGemmaPrompt(prompt string) string {
	const template = "<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n"
	return fmt.Sprintf(template, prompt)
}