diff --git a/go.mod b/go.mod index 43eeb76..a96a835 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.5 require ( forge.lthn.ai/core/go v0.0.0 + forge.lthn.ai/core/go-mlx v0.0.0 github.com/gorilla/websocket v1.5.3 github.com/marcboeker/go-duckdb v1.8.5 github.com/modelcontextprotocol/go-sdk v1.3.0 @@ -54,3 +55,5 @@ require ( ) replace forge.lthn.ai/core/go => ../core + +replace forge.lthn.ai/core/go-mlx => ../go-mlx diff --git a/ml/backend_mlx.go b/ml/backend_mlx.go index 8783f9d..d7596a0 100644 --- a/ml/backend_mlx.go +++ b/ml/backend_mlx.go @@ -10,11 +10,11 @@ import ( "strings" "sync" - "forge.lthn.ai/core/go-ai/mlx" - "forge.lthn.ai/core/go-ai/mlx/cache" - "forge.lthn.ai/core/go-ai/mlx/model" - "forge.lthn.ai/core/go-ai/mlx/sample" - "forge.lthn.ai/core/go-ai/mlx/tokenizer" + "forge.lthn.ai/core/go-mlx" + "forge.lthn.ai/core/go-mlx/cache" + "forge.lthn.ai/core/go-mlx/model" + "forge.lthn.ai/core/go-mlx/sample" + "forge.lthn.ai/core/go-mlx/tokenizer" ) // MLXBackend implements Backend and StreamingBackend for native Metal inference. diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt deleted file mode 100644 index e1cf221..0000000 --- a/mlx/CMakeLists.txt +++ /dev/null @@ -1,28 +0,0 @@ -cmake_minimum_required(VERSION 3.24) - -project(mlx) - -set(CMAKE_OSX_DEPLOYMENT_TARGET "26.0" CACHE STRING "Minimum macOS version") - -if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) - set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE) -endif() - -set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE) -set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE) -set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) -set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE) - -set(CMAKE_INSTALL_RPATH "@loader_path") - -include(FetchContent) - -set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "") - -FetchContent_Declare( - mlx-c - GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" - GIT_TAG ${MLX_C_GIT_TAG} -) - -FetchContent_MakeAvailable(mlx-c) diff --git a/mlx/array.go b/mlx/array.go deleted file mode 100644 index 0968bda..0000000 --- a/mlx/array.go +++ /dev/null @@ -1,261 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include -#include "mlx/c/mlx.h" -*/ -import "C" - -import ( - "encoding/binary" - "reflect" - "runtime" - "strings" - "unsafe" -) - -// Array wraps an mlx_array handle. -// Memory management relies on Go GC finalizers to call mlx_array_free, -// which decrements MLX-C's internal reference count. MLX-C handles all -// cross-array references internally — the Go wrapper does not track them. -type Array struct { - ctx C.mlx_array - name string // debug label -} - -// New creates a named Array and registers a GC finalizer. -// The inputs parameter is accepted for API compatibility but not stored — -// MLX-C tracks inter-array references via its own refcounting. -func New(name string, inputs ...*Array) *Array { - t := &Array{name: name} - runtime.SetFinalizer(t, finalizeArray) - return t -} - -// finalizeArray is called by Go GC to release the underlying C array handle. -func finalizeArray(t *Array) { - if t != nil && t.ctx.ctx != nil { - C.mlx_array_free(t.ctx) - t.ctx.ctx = nil - } -} - -type scalarTypes interface { - ~bool | ~int | ~float32 | ~float64 | ~complex64 -} - -// FromValue creates a scalar Array from a Go value. 
-func FromValue[T scalarTypes](t T) *Array { - Init() - tt := New("") - switch v := any(t).(type) { - case bool: - tt.ctx = C.mlx_array_new_bool(C.bool(v)) - case int: - tt.ctx = C.mlx_array_new_int(C.int(v)) - case float32: - tt.ctx = C.mlx_array_new_float32(C.float(v)) - case float64: - tt.ctx = C.mlx_array_new_float64(C.double(v)) - case complex64: - tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v))) - default: - panic("mlx: unsupported scalar type") - } - return tt -} - -type arrayTypes interface { - ~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 | - ~int8 | ~int16 | ~int32 | ~int64 | - ~float32 | ~float64 | - ~complex64 -} - -// FromValues creates an Array from a Go slice with the given shape. -func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array { - Init() - if len(shape) == 0 { - panic("mlx: shape required for non-scalar tensors") - } - - cShape := make([]C.int, len(shape)) - for i := range shape { - cShape[i] = C.int(shape[i]) - } - - var dtype DType - switch reflect.TypeOf(s).Elem().Kind() { - case reflect.Bool: - dtype = DTypeBool - case reflect.Uint8: - dtype = DTypeUint8 - case reflect.Uint16: - dtype = DTypeUint16 - case reflect.Uint32: - dtype = DTypeUint32 - case reflect.Uint64: - dtype = DTypeUint64 - case reflect.Int8: - dtype = DTypeInt8 - case reflect.Int16: - dtype = DTypeInt16 - case reflect.Int32: - dtype = DTypeInt32 - case reflect.Int64: - dtype = DTypeInt64 - case reflect.Float32: - dtype = DTypeFloat32 - case reflect.Float64: - dtype = DTypeFloat64 - case reflect.Complex64: - dtype = DTypeComplex64 - default: - panic("mlx: unsupported element type") - } - - bts := make([]byte, binary.Size(s)) - if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil { - panic(err) - } - - tt := New("") - tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype)) - return tt -} - -// Zeros creates a zero-filled Array with the given shape and dtype. -func Zeros(shape []int32, dtype DType) *Array { - Init() - cShape := make([]C.int, len(shape)) - for i, s := range shape { - cShape[i] = C.int(s) - } - tt := New("ZEROS") - C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx) - return tt -} - -// Set replaces this array's C handle with another's. -func (t *Array) Set(other *Array) { - C.mlx_array_set(&t.ctx, other.ctx) -} - -// Clone creates a new Go wrapper sharing the same C handle (increments C refcount). -func (t *Array) Clone() *Array { - tt := New(t.name) - C.mlx_array_set(&tt.ctx, t.ctx) - return tt -} - -// Valid reports whether this Array has a non-nil mlx handle. -func (t *Array) Valid() bool { - return t.ctx.ctx != nil -} - -// String returns a human-readable representation of the array. -func (t *Array) String() string { - str := C.mlx_string_new() - defer C.mlx_string_free(str) - C.mlx_array_tostring(&str, t.ctx) - return strings.TrimSpace(C.GoString(C.mlx_string_data(str))) -} - -// Shape returns the dimensions as int32 slice. -func (t *Array) Shape() []int32 { - dims := make([]int32, t.NumDims()) - for i := range dims { - dims[i] = int32(t.Dim(i)) - } - return dims -} - -// Size returns the total number of elements. -func (t Array) Size() int { return int(C.mlx_array_size(t.ctx)) } - -// NumBytes returns the total byte size. -func (t Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) } - -// NumDims returns the number of dimensions. 
-func (t Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) } - -// Dim returns the size of dimension i. -func (t Array) Dim(i int) int { return int(C.mlx_array_dim(t.ctx, C.int(i))) } - -// Dims returns all dimensions as int slice. -func (t Array) Dims() []int { - dims := make([]int, t.NumDims()) - for i := range dims { - dims[i] = t.Dim(i) - } - return dims -} - -// Dtype returns the array's data type. -func (t Array) Dtype() DType { return DType(C.mlx_array_dtype(t.ctx)) } - -// Int extracts a scalar int64 value. -func (t Array) Int() int { - var item C.int64_t - C.mlx_array_item_int64(&item, t.ctx) - return int(item) -} - -// Float extracts a scalar float64 value. -// Handles both float32 and float64 array dtypes. -func (t Array) Float() float64 { - switch t.Dtype() { - case DTypeFloat32: - var item C.float - C.mlx_array_item_float32(&item, t.ctx) - return float64(item) - default: - var item C.double - C.mlx_array_item_float64(&item, t.ctx) - return float64(item) - } -} - -// Ints extracts all elements as int slice (from int32 data). -func (t Array) Ints() []int { - ints := make([]int, t.Size()) - for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) { - ints[i] = int(f) - } - return ints -} - -// DataInt32 extracts all elements as int32 slice. -func (t Array) DataInt32() []int32 { - data := make([]int32, t.Size()) - for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) { - data[i] = int32(f) - } - return data -} - -// Floats extracts all elements as float32 slice. -func (t Array) Floats() []float32 { - floats := make([]float32, t.Size()) - for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) { - floats[i] = float32(f) - } - return floats -} - -// Free explicitly releases C array handles. Does not cascade — MLX-C's -// internal refcounting handles dependent arrays automatically. -func Free(s ...*Array) int { - var n int - for _, t := range s { - if t != nil && t.Valid() { - n += t.NumBytes() - C.mlx_array_free(t.ctx) - t.ctx.ctx = nil - runtime.SetFinalizer(t, nil) // cancel finalizer - } - } - return n -} diff --git a/mlx/cache/cache.go b/mlx/cache/cache.go deleted file mode 100644 index ced8b39..0000000 --- a/mlx/cache/cache.go +++ /dev/null @@ -1,201 +0,0 @@ -//go:build darwin && arm64 - -// Package cache provides KV cache implementations for transformer inference. -package cache - -import "forge.lthn.ai/core/go-ai/mlx" - -// Cache manages key-value pairs for transformer attention layers. -type Cache interface { - // Update adds new key/value tensors and returns the full cached K/V. - Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) - // Offset returns the total number of tokens processed. - Offset() int - // Len returns the number of cached tokens (may differ from Offset for rotating caches). - Len() int - // State returns the cached K/V arrays, or nil if empty. - State() []*mlx.Array - // Reset clears the cache for a new generation session. - Reset() -} - -// KVCache implements an unbounded cache that grows as needed. -// Pre-allocates in chunks of `step` tokens to reduce allocations. -type KVCache struct { - keys, values *mlx.Array - offset int - step int -} - -// NewKVCache creates a new unbounded KV cache with 256-token chunks. 
-func NewKVCache() *KVCache { - return &KVCache{step: 256} -} - -func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { - prev := c.offset - shape := k.Shape() - if len(shape) < 4 { - // K/V must be [B, H, L, D] — if not, pass through unchanged - if c.keys == nil { - c.keys, c.values = k, v - } - c.offset += seqLen - return c.keys, c.values - } - B, H, Dk := shape[0], shape[1], shape[3] - Dv := v.Shape()[3] - - // Grow buffer if needed. - if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) { - nSteps := (c.step + seqLen - 1) / c.step - newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype()) - newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype()) - - if c.keys != nil { - if prev%c.step != 0 { - c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk}) - c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv}) - } - c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) - c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) - } else { - c.keys, c.values = newK, newV - } - } - - c.offset += seqLen - c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk}) - c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv}) - - return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}), - mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv}) -} - -func (c *KVCache) State() []*mlx.Array { - if c.keys == nil { - return nil - } - return []*mlx.Array{c.keys, c.values} -} - -func (c *KVCache) Offset() int { return c.offset } -func (c *KVCache) Len() int { return c.offset } - -func (c *KVCache) Reset() { - c.keys = nil - c.values = nil - c.offset = 0 -} - -// RotatingKVCache implements a bounded sliding window cache. -type RotatingKVCache struct { - keys, values *mlx.Array - offset int - maxSize int - step int - idx int -} - -// NewRotatingKVCache creates a cache bounded to maxSize tokens. 
-func NewRotatingKVCache(maxSize int) *RotatingKVCache { - return &RotatingKVCache{maxSize: maxSize, step: 256} -} - -func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { - if seqLen > 1 { - return c.updateConcat(k, v, seqLen) - } - return c.updateInPlace(k, v) -} - -func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) { - shape := k.Shape() - if len(shape) < 4 { - if c.keys == nil { - c.keys, c.values = k, v - } - c.offset++ - return c.keys, c.values - } - B, H, Dk := shape[0], shape[1], shape[3] - Dv := v.Shape()[3] - - if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) { - var cap int - if c.keys != nil { - cap = int(c.keys.Shape()[2]) - } - newSize := min(c.step, c.maxSize-cap) - newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype()) - newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype()) - if c.keys != nil { - c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2) - c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2) - } else { - c.keys, c.values = newK, newV - } - } - - if c.idx >= c.maxSize { - c.idx = 0 - } - - c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk}) - c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv}) - - c.offset++ - c.idx++ - - validLen := int32(min(c.offset, c.maxSize)) - return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}), - mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv}) -} - -func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) { - shape := k.Shape() - if len(shape) < 4 { - // K/V must be [B, H, L, D] — if not, pass through unchanged - if c.keys == nil { - c.keys, c.values = k, v - } - c.offset += seqLen - return c.keys, c.values - } - B, H, Dk := shape[0], shape[1], shape[3] - Dv := v.Shape()[3] - - if c.keys == nil { - c.keys, c.values = k, v - } else { - c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2) - c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2) - } - c.offset += seqLen - - cap := int(c.keys.Shape()[2]) - if trim := cap - c.maxSize; trim > 0 { - c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk}) - c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv}) - } - - c.idx = int(c.keys.Shape()[2]) - return c.keys, c.values -} - -func (c *RotatingKVCache) State() []*mlx.Array { - if c.keys == nil { - return nil - } - return []*mlx.Array{c.keys, c.values} -} - -func (c *RotatingKVCache) Offset() int { return c.offset } -func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) } - -func (c *RotatingKVCache) Reset() { - c.keys = nil - c.values = nil - c.offset = 0 - c.idx = 0 -} diff --git a/mlx/compile.go b/mlx/compile.go deleted file mode 100644 index 8440f4b..0000000 --- a/mlx/compile.go +++ /dev/null @@ -1,90 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include "mlx/c/mlx.h" - -// Callback for compiled functions. -extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload); - -static mlx_closure new_closure(void *payload) { - return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL); -} -*/ -import "C" - -import ( - "sync" - "unsafe" -) - -// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls. 
-type CompiledFunc struct { - fn func([]*Array) []*Array - closure C.mlx_closure - mu sync.Mutex -} - -var compiledFuncs sync.Map - -//export goCompiledFunc -func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int { - id := uintptr(payload) - fnI, ok := compiledFuncs.Load(id) - if !ok { - return 1 - } - fn := fnI.(func([]*Array) []*Array) - - // Convert inputs - nInputs := int(C.mlx_vector_array_size(inputs)) - goInputs := make([]*Array, nInputs) - for i := 0; i < nInputs; i++ { - a := New("INPUT") - C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i)) - goInputs[i] = a - } - - // Call user function - goOutputs := fn(goInputs) - - // The output vector arrives with ctx=nullptr. Create a valid vector, - // fill it, then copy to the output pointer. - tmp := C.mlx_vector_array_new() - for _, out := range goOutputs { - C.mlx_vector_array_append_value(tmp, out.ctx) - } - C.mlx_vector_array_set(outputs, tmp) - C.mlx_vector_array_free(tmp) - return 0 -} - -var nextID uintptr -var nextIDMu sync.Mutex - -// CompileShapeless compiles a function for efficient repeated execution. -// The function must accept and return arrays of consistent shapes. -func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc { - nextIDMu.Lock() - nextID++ - id := nextID - nextIDMu.Unlock() - - compiledFuncs.Store(id, fn) - - cf := &CompiledFunc{fn: fn} - cf.closure = C.new_closure(unsafe.Pointer(id)) - return cf -} - -// Call executes the compiled function with the given inputs. -func (cf *CompiledFunc) Call(inputs ...*Array) []*Array { - cf.mu.Lock() - defer cf.mu.Unlock() - - // Fall back to direct call — compilation is an optimization. - // The compiled closure can be used via mlx_compiled but the - // direct path is simpler and still benefits from MLX's lazy evaluation. - return cf.fn(inputs) -} diff --git a/mlx/dtype.go b/mlx/dtype.go deleted file mode 100644 index eae583f..0000000 --- a/mlx/dtype.go +++ /dev/null @@ -1,83 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -// #include "mlx/c/mlx.h" -import "C" - -import "encoding/json" - -// DType represents an MLX array data type. 
-type DType C.mlx_dtype - -const ( - DTypeBool DType = C.MLX_BOOL - DTypeUint8 DType = C.MLX_UINT8 - DTypeUint16 DType = C.MLX_UINT16 - DTypeUint32 DType = C.MLX_UINT32 - DTypeUint64 DType = C.MLX_UINT64 - DTypeInt8 DType = C.MLX_INT8 - DTypeInt16 DType = C.MLX_INT16 - DTypeInt32 DType = C.MLX_INT32 - DTypeInt64 DType = C.MLX_INT64 - DTypeFloat16 DType = C.MLX_FLOAT16 - DTypeFloat32 DType = C.MLX_FLOAT32 - DTypeFloat64 DType = C.MLX_FLOAT64 - DTypeBFloat16 DType = C.MLX_BFLOAT16 - DTypeComplex64 DType = C.MLX_COMPLEX64 -) - -var dtypeNames = map[DType]string{ - DTypeBool: "bool", - DTypeUint8: "uint8", - DTypeUint16: "uint16", - DTypeUint32: "uint32", - DTypeUint64: "uint64", - DTypeInt8: "int8", - DTypeInt16: "int16", - DTypeInt32: "int32", - DTypeInt64: "int64", - DTypeFloat16: "float16", - DTypeFloat32: "float32", - DTypeFloat64: "float64", - DTypeBFloat16: "bfloat16", - DTypeComplex64: "complex64", -} - -func (d DType) String() string { - if s, ok := dtypeNames[d]; ok { - return s - } - return "unknown" -} - -var dtypeFromString = map[string]DType{ - "bool": DTypeBool, "BOOL": DTypeBool, - "uint8": DTypeUint8, "U8": DTypeUint8, - "uint16": DTypeUint16, "U16": DTypeUint16, - "uint32": DTypeUint32, "U32": DTypeUint32, - "uint64": DTypeUint64, "U64": DTypeUint64, - "int8": DTypeInt8, "I8": DTypeInt8, - "int16": DTypeInt16, "I16": DTypeInt16, - "int32": DTypeInt32, "I32": DTypeInt32, - "int64": DTypeInt64, "I64": DTypeInt64, - "float16": DTypeFloat16, "F16": DTypeFloat16, - "float32": DTypeFloat32, "F32": DTypeFloat32, - "float64": DTypeFloat64, "F64": DTypeFloat64, - "bfloat16": DTypeBFloat16, "BF16": DTypeBFloat16, - "complex64": DTypeComplex64, -} - -// UnmarshalJSON parses a DType from JSON strings like "F32", "BF16", etc. -func (d *DType) UnmarshalJSON(b []byte) error { - var s string - if err := json.Unmarshal(b, &s); err != nil { - return err - } - if dt, ok := dtypeFromString[s]; ok { - *d = dt - return nil - } - *d = DTypeFloat32 // default - return nil -} diff --git a/mlx/fast.go b/mlx/fast.go deleted file mode 100644 index e1abeba..0000000 --- a/mlx/fast.go +++ /dev/null @@ -1,79 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include -#include "mlx/c/mlx.h" -*/ -import "C" - -import "unsafe" - -// RMSNorm applies Root Mean Square normalization using a fused Metal kernel. -func RMSNorm(x, weight *Array, eps float32) *Array { - out := New("FAST_RMSNORM", x) - C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx) - return out -} - -// LayerNorm applies Layer normalization using a fused Metal kernel. -func LayerNorm(x, weight, bias *Array, eps float32) *Array { - out := New("FAST_LAYERNORM", x) - C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx) - return out -} - -// RoPE applies Rotary Position Embeddings using a fused Metal kernel. -func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array { - out := New("FAST_ROPE", x) - freqs := C.mlx_array_new() - defer C.mlx_array_free(freqs) - C.mlx_fast_rope( - &out.ctx, - x.ctx, - C.int(dims), - C._Bool(traditional), - C.mlx_optional_float{ - value: C.float(base), - has_value: C._Bool(base != 0), - }, - C.float(scale), - C.int(offset), - freqs, - DefaultStream().ctx, - ) - return out -} - -// ScaledDotProductAttention computes attention using a fused Metal kernel. 
-func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array { - mode := "" - if causal { - mode = "causal" - } - cMode := C.CString(mode) - defer C.free(unsafe.Pointer(cMode)) - - maskArr := C.mlx_array_new() - defer C.mlx_array_free(maskArr) - sinksArr := C.mlx_array_new() - defer C.mlx_array_free(sinksArr) - - out := New("FAST_SDPA", query, key, value) - C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx) - return out -} - -// ScaledDotProductAttentionWithMask computes attention with an explicit mask. -func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array { - cMode := C.CString("array") - defer C.free(unsafe.Pointer(cMode)) - - sinksArr := C.mlx_array_new() - defer C.mlx_array_free(sinksArr) - - out := New("FAST_SDPA", query, key, value, mask) - C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx) - return out -} diff --git a/mlx/grad.go b/mlx/grad.go deleted file mode 100644 index 12ca786..0000000 --- a/mlx/grad.go +++ /dev/null @@ -1,353 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include "mlx/c/mlx.h" - -// Callback for gradient closures — same signature as goCompiledFunc. -extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload); - -static mlx_closure new_grad_closure(void *payload) { - return mlx_closure_new_func_payload(&goGradFunc, payload, NULL); -} -*/ -import "C" - -import ( - "fmt" - "sync" - "sync/atomic" - "unsafe" -) - -// --- Closure registry (separate from compile.go's registry) --- - -var ( - gradFuncs sync.Map - gradNextID atomic.Uintptr -) - -//export goGradFunc -func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int { - id := uintptr(payload) - fnI, ok := gradFuncs.Load(id) - if !ok { - return 1 - } - fn := fnI.(func([]*Array) []*Array) - - nInputs := int(C.mlx_vector_array_size(inputs)) - goInputs := make([]*Array, nInputs) - for i := 0; i < nInputs; i++ { - a := New("GRAD_INPUT") - C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i)) - goInputs[i] = a - } - - goOutputs := fn(goInputs) - - // The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_). - // Create a valid vector, fill it, then copy to the output pointer. - tmp := C.mlx_vector_array_new() - for _, out := range goOutputs { - C.mlx_vector_array_append_value(tmp, out.ctx) - } - C.mlx_vector_array_set(outputs, tmp) - C.mlx_vector_array_free(tmp) - return 0 -} - -// newClosure registers a Go function as an MLX closure for autograd. -func newClosure(fn func([]*Array) []*Array) C.mlx_closure { - id := gradNextID.Add(1) - gradFuncs.Store(id, fn) - return C.new_grad_closure(unsafe.Pointer(id)) -} - -// --- VJP (Vector-Jacobian Product) — Reverse Mode --- - -// VJP computes the vector-Jacobian product (reverse-mode autodiff). -// Given a function fn, input primals, and output cotangents (upstream gradients), -// returns (outputs, gradients) where gradients are w.r.t. the primals. -// -// This is the fundamental backward pass operation. 
-func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) { - Init() - - closure := newClosure(fn) - defer C.mlx_closure_free(closure) - - // Pack primals into vector - primalsVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(primalsVec) - for _, p := range primals { - C.mlx_vector_array_append_value(primalsVec, p.ctx) - } - - // Pack cotangents into vector - cotangentsVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(cotangentsVec) - for _, c := range cotangents { - C.mlx_vector_array_append_value(cotangentsVec, c.ctx) - } - - // Call mlx_vjp - var outVec, vjpVec C.mlx_vector_array - outVec = C.mlx_vector_array_new() - vjpVec = C.mlx_vector_array_new() - defer C.mlx_vector_array_free(outVec) - defer C.mlx_vector_array_free(vjpVec) - - rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec) - if rc != 0 { - checkError() - return nil, nil, fmt.Errorf("mlx: vjp failed") - } - - outputs = vectorToArrays(outVec) - vjps = vectorToArrays(vjpVec) - return outputs, vjps, nil -} - -// --- JVP (Jacobian-Vector Product) — Forward Mode --- - -// JVP computes the Jacobian-vector product (forward-mode autodiff). -// Given a function fn, input primals, and input tangents (perturbation directions), -// returns (outputs, output_tangents). -// -// Useful for directional derivatives and Hessian-vector products. -func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) { - Init() - - closure := newClosure(fn) - defer C.mlx_closure_free(closure) - - primalsVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(primalsVec) - for _, p := range primals { - C.mlx_vector_array_append_value(primalsVec, p.ctx) - } - - tangentsVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(tangentsVec) - for _, t := range tangents { - C.mlx_vector_array_append_value(tangentsVec, t.ctx) - } - - var outVec, jvpVec C.mlx_vector_array - outVec = C.mlx_vector_array_new() - jvpVec = C.mlx_vector_array_new() - defer C.mlx_vector_array_free(outVec) - defer C.mlx_vector_array_free(jvpVec) - - rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec) - if rc != 0 { - checkError() - return nil, nil, fmt.Errorf("mlx: jvp failed") - } - - outputs = vectorToArrays(outVec) - jvps = vectorToArrays(jvpVec) - return outputs, jvps, nil -} - -// --- ValueAndGrad — Combined Forward + Backward --- - -// GradFn is a function that computes both the loss value and gradients -// with respect to specified arguments. Call it with inputs to get -// (values, gradients). -type GradFn struct { - cls C.mlx_closure_value_and_grad -} - -// ValueAndGrad creates a GradFn that computes both the function value and -// gradients with respect to the arguments at the given indices. -// -// The returned GradFn can be called repeatedly with different inputs. -// This is the primary API for training loops: -// -// lossFn := func(params []*Array) []*Array { return []*Array{loss} } -// grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg -// values, grads := grad.Apply(params...) -func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn { - Init() - - closure := newClosure(fn) - defer C.mlx_closure_free(closure) - - // Default: differentiate w.r.t. 
first argument - if len(argnums) == 0 { - argnums = []int{0} - } - - cArgs := make([]C.int, len(argnums)) - for i, a := range argnums { - cArgs[i] = C.int(a) - } - - g := &GradFn{} - C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs))) - return g -} - -// Apply calls the gradient function with the given inputs. -// Returns (values, gradients) where values are the function outputs -// and gradients are w.r.t. the arguments specified in ValueAndGrad. -func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) { - inputVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(inputVec) - for _, in := range inputs { - C.mlx_vector_array_append_value(inputVec, in.ctx) - } - - var valVec, gradVec C.mlx_vector_array - valVec = C.mlx_vector_array_new() - gradVec = C.mlx_vector_array_new() - defer C.mlx_vector_array_free(valVec) - defer C.mlx_vector_array_free(gradVec) - - rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec) - if rc != 0 { - checkError() - return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed") - } - - values = vectorToArrays(valVec) - grads = vectorToArrays(gradVec) - return values, grads, nil -} - -// Free releases the underlying C closure. -func (g *GradFn) Free() { - if g.cls.ctx != nil { - C.mlx_closure_value_and_grad_free(g.cls) - g.cls.ctx = nil - } -} - -// --- Checkpoint — Memory-Efficient Gradient Recomputation --- - -// Checkpoint wraps a function so that during backward pass, intermediate -// activations are recomputed rather than stored. Trades compute for memory. -// -// Use this for memory-constrained training (large models on limited VRAM). -func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array { - Init() - - closure := newClosure(fn) - defer C.mlx_closure_free(closure) - - var checkpointed C.mlx_closure - C.mlx_checkpoint(&checkpointed, closure) - - // Register the checkpointed closure so it stays alive - id := gradNextID.Add(1) - cpClosure := checkpointed // capture for the returned function - - // Return a Go function that calls through the checkpointed closure - return func(inputs []*Array) []*Array { - _ = id // prevent GC of closure reference - - inputVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(inputVec) - for _, in := range inputs { - C.mlx_vector_array_append_value(inputVec, in.ctx) - } - - outVec := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(outVec) - - C.mlx_closure_apply(&outVec, cpClosure, inputVec) - return vectorToArrays(outVec) - } -} - -// --- Loss Functions --- - -// CrossEntropyLoss computes cross-entropy loss between logits and integer targets. -// logits: [..., V] (raw model output, pre-softmax, last dim = vocab) -// targets: [...] (integer token IDs, same shape as logits minus last dim) -// Returns scalar loss averaged over all positions. -// -// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i] -func CrossEntropyLoss(logits, targets *Array) *Array { - Init() - // LogSumExp along last axis (vocab dimension) for numerical stability - lse := LogSumExp(logits, -1, false) - - // Gather the logit at the target index: logits[..., target] - // targets needs to be [..., 1] for take_along_axis - tgtExpanded := ExpandDims(targets, -1) - gathered := TakeAlongAxis(logits, tgtExpanded, -1) - gathered = Squeeze(gathered, -1) // back to [...] 
shape - - // Per-position loss: logsumexp - logit_at_target - perPos := Subtract(lse, gathered) - - // Mean over all positions - return MeanAll(perPos) -} - -// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions. -// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore) -// Returns scalar loss averaged over masked positions only. -func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array { - Init() - lse := LogSumExp(logits, -1, false) - tgtExpanded := ExpandDims(targets, -1) - gathered := TakeAlongAxis(logits, tgtExpanded, -1) - gathered = Squeeze(gathered, -1) - perPos := Subtract(lse, gathered) // [B, L] - - // Apply mask and average over masked positions - masked := Mul(perPos, mask) - return Divide(SumAll(masked), SumAll(mask)) -} - -// MSELoss computes the mean squared error loss: mean((predictions - targets)^2). -func MSELoss(predictions, targets *Array) *Array { - diff := Subtract(predictions, targets) - sq := Square(diff) - return MeanAll(sq) -} - -// --- Helpers --- - -// vectorToArrays extracts all arrays from an mlx_vector_array. -func vectorToArrays(vec C.mlx_vector_array) []*Array { - n := int(C.mlx_vector_array_size(vec)) - out := make([]*Array, n) - for i := 0; i < n; i++ { - a := New("VEC_OUT") - C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i)) - out[i] = a - } - return out -} - -// Log returns element-wise natural logarithm. -func Log(a *Array) *Array { - out := New("LOG", a) - C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// SumAll reduces by summation over all elements, returning a scalar. -func SumAll(a *Array) *Array { - flat := Reshape(a, -1) - return Sum(flat, 0, false) -} - -// MeanAll reduces by averaging over all elements, returning a scalar. -func MeanAll(a *Array) *Array { - flat := Reshape(a, -1) - return Mean(flat, 0, false) -} - -// OnesLike creates an array of ones with the same shape and type as the input. -func OnesLike(a *Array) *Array { - out := New("ONES_LIKE", a) - C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} diff --git a/mlx/grad_test.go b/mlx/grad_test.go deleted file mode 100644 index 0a00751..0000000 --- a/mlx/grad_test.go +++ /dev/null @@ -1,332 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -import ( - "math" - "testing" -) - -func TestVJP_SimpleSquare(t *testing.T) { - // f(x) = x^2, df/dx = 2x - // At x=3: f(3)=9, df/dx=6 - fn := func(inputs []*Array) []*Array { - x := inputs[0] - return []*Array{Mul(x, x)} - } - - x := FromValue(float32(3.0)) - cotangent := FromValue(float32(1.0)) // upstream grad = 1 - - outputs, grads, err := VJP(fn, []*Array{x}, []*Array{cotangent}) - if err != nil { - t.Fatalf("VJP failed: %v", err) - } - - Materialize(outputs[0], grads[0]) - - got := outputs[0].Float() - if math.Abs(got-9.0) > 1e-5 { - t.Errorf("output = %f, want 9.0", got) - } - - grad := grads[0].Float() - if math.Abs(grad-6.0) > 1e-5 { - t.Errorf("grad = %f, want 6.0", grad) - } -} - -func TestVJP_Addition(t *testing.T) { - // f(x, y) = x + y, df/dx = 1, df/dy = 1 - fn := func(inputs []*Array) []*Array { - return []*Array{Add(inputs[0], inputs[1])} - } - - x := FromValue(float32(2.0)) - y := FromValue(float32(5.0)) - cotangent := FromValue(float32(1.0)) - - _, grads, err := VJP(fn, []*Array{x, y}, []*Array{cotangent}) - if err != nil { - t.Fatalf("VJP failed: %v", err) - } - - Materialize(grads...) 
- - if math.Abs(grads[0].Float()-1.0) > 1e-5 { - t.Errorf("dx = %f, want 1.0", grads[0].Float()) - } - if math.Abs(grads[1].Float()-1.0) > 1e-5 { - t.Errorf("dy = %f, want 1.0", grads[1].Float()) - } -} - -func TestVJP_MatmulGrad(t *testing.T) { - // f(W) = sum(W @ x) — gradient of sum(matmul) w.r.t. W - // For W=[2,2], x=[2,1]: dL/dW = ones @ x^T - x := FromValues([]float32{1.0, 2.0}, 2, 1) - w := FromValues([]float32{1.0, 0.0, 0.0, 1.0}, 2, 2) // identity - - fn := func(inputs []*Array) []*Array { - result := Matmul(inputs[0], x) - return []*Array{SumAll(result)} - } - - cotangent := FromValue(float32(1.0)) - - outputs, grads, err := VJP(fn, []*Array{w}, []*Array{cotangent}) - if err != nil { - t.Fatalf("VJP failed: %v", err) - } - - Materialize(outputs[0], grads[0]) - - // W @ x with W=I, x=[1,2]^T gives [1,2]^T, sum=3 - got := outputs[0].Float() - if math.Abs(got-3.0) > 1e-5 { - t.Errorf("output = %f, want 3.0", got) - } - - // Gradient of sum(W@x) w.r.t. W is outer product: ones @ x^T - // = [[1,2],[1,2]] - gradFloats := grads[0].Floats() - expected := []float32{1.0, 2.0, 1.0, 2.0} - for i, exp := range expected { - if math.Abs(float64(gradFloats[i]-exp)) > 1e-5 { - t.Errorf("grad[%d] = %f, want %f", i, gradFloats[i], exp) - } - } -} - -func TestJVP_SimpleSquare(t *testing.T) { - // f(x) = x^2, JVP with tangent v: df = 2x * v - // At x=3, v=1: df = 6 - fn := func(inputs []*Array) []*Array { - x := inputs[0] - return []*Array{Mul(x, x)} - } - - x := FromValue(float32(3.0)) - tangent := FromValue(float32(1.0)) - - outputs, jvps, err := JVP(fn, []*Array{x}, []*Array{tangent}) - if err != nil { - t.Fatalf("JVP failed: %v", err) - } - - Materialize(outputs[0], jvps[0]) - - got := outputs[0].Float() - if math.Abs(got-9.0) > 1e-5 { - t.Errorf("output = %f, want 9.0", got) - } - - jvp := jvps[0].Float() - if math.Abs(jvp-6.0) > 1e-5 { - t.Errorf("jvp = %f, want 6.0", jvp) - } -} - -func TestValueAndGrad_Quadratic(t *testing.T) { - // f(x) = x^2 + 2x + 1 = (x+1)^2 - // f'(x) = 2x + 2 - // At x=3: f(3) = 16, f'(3) = 8 - fn := func(inputs []*Array) []*Array { - x := inputs[0] - x2 := Mul(x, x) - two_x := MulScalar(x, 2.0) - one := FromValue(float32(1.0)) - return []*Array{Add(Add(x2, two_x), one)} - } - - grad := ValueAndGrad(fn, 0) - defer grad.Free() - - x := FromValue(float32(3.0)) - values, grads, err := grad.Apply(x) - if err != nil { - t.Fatalf("ValueAndGrad failed: %v", err) - } - - Materialize(values[0], grads[0]) - - val := values[0].Float() - if math.Abs(val-16.0) > 1e-5 { - t.Errorf("value = %f, want 16.0", val) - } - - g := grads[0].Float() - if math.Abs(g-8.0) > 1e-5 { - t.Errorf("grad = %f, want 8.0", g) - } -} - -func TestValueAndGrad_MultiArg(t *testing.T) { - // f(x, y) = x*y, df/dx = y, df/dy = x - // At x=3, y=4: f=12, dx=4, dy=3 - fn := func(inputs []*Array) []*Array { - return []*Array{Mul(inputs[0], inputs[1])} - } - - // Differentiate w.r.t. 
both arguments - grad := ValueAndGrad(fn, 0, 1) - defer grad.Free() - - x := FromValue(float32(3.0)) - y := FromValue(float32(4.0)) - values, grads, err := grad.Apply(x, y) - if err != nil { - t.Fatalf("ValueAndGrad failed: %v", err) - } - - Materialize(values[0], grads[0], grads[1]) - - val := values[0].Float() - if math.Abs(val-12.0) > 1e-5 { - t.Errorf("value = %f, want 12.0", val) - } - - dx := grads[0].Float() - if math.Abs(dx-4.0) > 1e-5 { - t.Errorf("dx = %f, want 4.0 (y)", dx) - } - - dy := grads[1].Float() - if math.Abs(dy-3.0) > 1e-5 { - t.Errorf("dy = %f, want 3.0 (x)", dy) - } -} - -func TestValueAndGrad_Reusable(t *testing.T) { - // Verify GradFn can be called multiple times - fn := func(inputs []*Array) []*Array { - x := inputs[0] - return []*Array{Mul(x, x)} // x^2, grad = 2x - } - - grad := ValueAndGrad(fn) - defer grad.Free() - - for _, tc := range []struct { - x float32 - want float64 // expected gradient - }{ - {2.0, 4.0}, - {5.0, 10.0}, - {-3.0, -6.0}, - {0.0, 0.0}, - } { - x := FromValue(tc.x) - _, grads, err := grad.Apply(x) - if err != nil { - t.Fatalf("Apply failed for x=%f: %v", tc.x, err) - } - Materialize(grads[0]) - - g := grads[0].Float() - if math.Abs(g-tc.want) > 1e-5 { - t.Errorf("x=%f: grad = %f, want %f", tc.x, g, tc.want) - } - } -} - -func TestCrossEntropyLoss(t *testing.T) { - // Simple 3-class classification - // logits = [1.0, 2.0, 3.0], target = 2 (class index) - // Manual: logsumexp([1,2,3]) = 3 + log(exp(-2)+exp(-1)+1) - // = 3 + log(0.1353 + 0.3679 + 1.0) = 3 + log(1.5032) = 3.4076 - // loss = 3.4076 - 3.0 = 0.4076 - logits := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) // [1, 3] - targets := FromValues([]int32{2}, 1) // [1] - - loss := CrossEntropyLoss(logits, targets) - Materialize(loss) - - got := loss.Float() - expected := 0.4076 - if math.Abs(got-expected) > 0.01 { - t.Errorf("CrossEntropyLoss = %f, want ~%f", got, expected) - } -} - -func TestMSELoss(t *testing.T) { - pred := FromValues([]float32{1.0, 2.0, 3.0}, 3) - target := FromValues([]float32{1.5, 2.5, 3.5}, 3) - - loss := MSELoss(pred, target) - Materialize(loss) - - // MSE = mean((0.5)^2, (0.5)^2, (0.5)^2) = mean(0.25, 0.25, 0.25) = 0.25 - got := loss.Float() - if math.Abs(got-0.25) > 1e-5 { - t.Errorf("MSELoss = %f, want 0.25", got) - } -} - -func TestLogSumExp(t *testing.T) { - // logsumexp([1, 2, 3]) along axis -1 - a := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) - result := LogSumExp(a, -1, false) - Materialize(result) - - // = 3 + log(exp(-2) + exp(-1) + 1) = 3 + log(1.5032) ≈ 3.4076 - got := result.Float() - expected := 3.4076 - if math.Abs(got-expected) > 0.01 { - t.Errorf("LogSumExp = %f, want ~%f", got, expected) - } -} - -func TestOnesLike(t *testing.T) { - a := FromValues([]float32{1.0, 2.0, 3.0}, 3) - ones := OnesLike(a) - Materialize(ones) - - floats := ones.Floats() - for i, f := range floats { - if f != 1.0 { - t.Errorf("OnesLike[%d] = %f, want 1.0", i, f) - } - } -} - -func TestCheckpoint(t *testing.T) { - // Checkpoint should produce the same result as the original function - fn := func(inputs []*Array) []*Array { - x := inputs[0] - return []*Array{Mul(x, x)} - } - - cpFn := Checkpoint(fn) - - x := FromValue(float32(5.0)) - result := cpFn([]*Array{x}) - Materialize(result[0]) - - got := result[0].Float() - if math.Abs(got-25.0) > 1e-5 { - t.Errorf("Checkpoint result = %f, want 25.0", got) - } -} - -func TestSumAll(t *testing.T) { - a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2) - result := SumAll(a) - Materialize(result) - - got := result.Float() - if 
math.Abs(got-10.0) > 1e-5 { - t.Errorf("SumAll = %f, want 10.0", got) - } -} - -func TestMeanAll(t *testing.T) { - a := FromValues([]float32{2.0, 4.0, 6.0, 8.0}, 2, 2) - result := MeanAll(a) - Materialize(result) - - got := result.Float() - if math.Abs(got-5.0) > 1e-5 { - t.Errorf("MeanAll = %f, want 5.0", got) - } -} diff --git a/mlx/io.go b/mlx/io.go deleted file mode 100644 index 7e35773..0000000 --- a/mlx/io.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include -#include "mlx/c/mlx.h" -*/ -import "C" - -import ( - "iter" - "runtime" - "unsafe" -) - -// LoadSafetensors loads tensors from a .safetensors file, returning an iterator -// over (name, array) pairs. Tensors are loaded lazily on the CPU stream. -func LoadSafetensors(path string) iter.Seq2[string, *Array] { - Init() - return func(yield func(string, *Array) bool) { - string2array := C.mlx_map_string_to_array_new() - defer C.mlx_map_string_to_array_free(string2array) - - string2string := C.mlx_map_string_to_string_new() - defer C.mlx_map_string_to_string_free(string2string) - - cPath := C.CString(path) - defer C.free(unsafe.Pointer(cPath)) - - cpu := C.mlx_default_cpu_stream_new() - defer C.mlx_stream_free(cpu) - - C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu) - - it := C.mlx_map_string_to_array_iterator_new(string2array) - defer C.mlx_map_string_to_array_iterator_free(it) - - for { - var key *C.char - value := C.mlx_array_new() - if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 { - break - } - - name := C.GoString(key) - arr := &Array{ctx: value, name: name} - runtime.SetFinalizer(arr, finalizeArray) - if !yield(name, arr) { - break - } - } - } -} - -// LoadAllSafetensors loads all tensors from a .safetensors file into a map. -func LoadAllSafetensors(path string) map[string]*Array { - tensors := make(map[string]*Array) - for name, arr := range LoadSafetensors(path) { - tensors[name] = arr - } - return tensors -} diff --git a/mlx/lora.go b/mlx/lora.go deleted file mode 100644 index ac7107e..0000000 --- a/mlx/lora.go +++ /dev/null @@ -1,237 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include -#include "mlx/c/mlx.h" -*/ -import "C" - -import ( - "fmt" - "math" - "sort" - "unsafe" -) - -// LoRALinear wraps a frozen Linear layer with low-rank trainable adapters. -// -// Forward: base(x) + scale * (x @ A^T) @ B^T -// -// A is [rank, in_features] — initialised with Kaiming normal -// B is [out_features, rank] — initialised to zero -// Scale = alpha / rank -// -// Only A and B are trainable. The base weights stay frozen. -type LoRALinear struct { - Base *Linear // Frozen base weights (may be quantized) - A *Array // [rank, in_features] — trainable - B *Array // [out_features, rank] — trainable - Scale float32 // alpha / rank - Rank int - Alpha float32 -} - -// NewLoRALinear wraps an existing Linear layer with LoRA adapters. -// rank: decomposition rank (typically 4, 8, or 16) -// alpha: scaling factor (typically 2*rank) -func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear { - // Determine dimensions from the base weight. - // Weight shape is [out_features, in_features] for standard linear, - // or quantized shape which we handle via the base layer. - var inFeatures, outFeatures int32 - - if base.Scales != nil { - // Quantized: weight is packed. Compute dims from scales. 
- // scales shape is [out_features, in_features / group_size] - scaleShape := base.Scales.Shape() - outFeatures = scaleShape[0] - inFeatures = scaleShape[1] * int32(base.GroupSize) - } else { - wShape := base.Weight.Shape() - outFeatures = wShape[0] - inFeatures = wShape[1] - } - - // A: Kaiming normal initialisation — N(0, 1/sqrt(in_features)) - stddev := float32(1.0 / math.Sqrt(float64(inFeatures))) - a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32) - - // B: zero initialisation — LoRA starts as identity (no change to base) - b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32) - - Materialize(a, b) - - return &LoRALinear{ - Base: base, - A: a, - B: b, - Scale: alpha / float32(rank), - Rank: rank, - Alpha: alpha, - } -} - -// Forward computes: base(x) + scale * (x @ A^T) @ B^T -// Calls baseForward on the underlying Linear to avoid infinite recursion -// when the Linear's Forward method dispatches through LoRA. -func (l *LoRALinear) Forward(x *Array) *Array { - baseOut := l.Base.baseForward(x) - - // LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out] - loraOut := Matmul(x, Transpose(l.A)) - loraOut = Matmul(loraOut, Transpose(l.B)) - loraOut = MulScalar(loraOut, l.Scale) - - return Add(baseOut, loraOut) -} - -// TrainableParams returns the LoRA A and B arrays for gradient computation. -func (l *LoRALinear) TrainableParams() []*Array { - return []*Array{l.A, l.B} -} - -// SetParams updates the LoRA A and B arrays (used by optimiser after gradient step). -func (l *LoRALinear) SetParams(a, b *Array) { - l.A = a - l.B = b -} - -// ParamCount returns the number of trainable parameters. -func (l *LoRALinear) ParamCount() int { - aShape := l.A.Shape() - bShape := l.B.Shape() - return int(aShape[0]*aShape[1] + bShape[0]*bShape[1]) -} - -// --- LoRA Application to Models --- - -// LoRAConfig specifies which layers to apply LoRA to and with what parameters. -type LoRAConfig struct { - Rank int // Decomposition rank (default 8) - Alpha float32 // Scaling factor (default 16) - TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj) -} - -// DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning. -func DefaultLoRAConfig() LoRAConfig { - return LoRAConfig{ - Rank: 8, - Alpha: 16, - TargetKeys: []string{"q_proj", "v_proj"}, - } -} - -// LoRAAdapter holds all LoRA layers applied to a model. -type LoRAAdapter struct { - Layers map[string]*LoRALinear // keyed by weight path prefix - Config LoRAConfig -} - -// TotalParams returns the total number of trainable parameters across all LoRA layers. -func (a *LoRAAdapter) TotalParams() int { - total := 0 - for _, l := range a.Layers { - total += l.ParamCount() - } - return total -} - -// SortedNames returns layer names in deterministic sorted order. -func (a *LoRAAdapter) SortedNames() []string { - names := make([]string, 0, len(a.Layers)) - for name := range a.Layers { - names = append(names, name) - } - sort.Strings(names) - return names -} - -// AllTrainableParams returns all trainable arrays (A and B from every layer), -// in a deterministic order sorted by layer name. -func (a *LoRAAdapter) AllTrainableParams() []*Array { - names := a.SortedNames() - params := make([]*Array, 0, len(names)*2) - for _, name := range names { - l := a.Layers[name] - params = append(params, l.A, l.B) - } - return params -} - -// SetAllParams updates all LoRA A and B arrays from a flat slice, -// in the same deterministic order as AllTrainableParams. 
-func (a *LoRAAdapter) SetAllParams(params []*Array) { - names := a.SortedNames() - for i, name := range names { - l := a.Layers[name] - l.A = params[i*2] - l.B = params[i*2+1] - } -} - -// Save writes the LoRA adapter weights to a safetensors file. -// Only saves the A and B matrices — not the frozen base weights. -func (a *LoRAAdapter) Save(path string) error { - weights := make(map[string]*Array) - for name, l := range a.Layers { - weights[name+".lora_a"] = l.A - weights[name+".lora_b"] = l.B - } - return SaveSafetensors(path, weights) -} - -// --- Random Normal --- - -// RandomNormal generates normal (Gaussian) random values with given mean and stddev. -func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array { - Init() - out := New("RANDOM_NORMAL") - cShape := make([]C.int, len(shape)) - for i, s := range shape { - cShape[i] = C.int(s) - } - key := C.mlx_array_new() - defer C.mlx_array_free(key) - C.mlx_random_normal( - &out.ctx, - &cShape[0], C.size_t(len(cShape)), - C.mlx_dtype(dtype), - C.float(mean), C.float(stddev), - key, // null key = use default RNG - DefaultStream().ctx, - ) - return out -} - -// --- SaveSafetensors --- - -// SaveSafetensors saves a map of named arrays to a .safetensors file. -func SaveSafetensors(path string, weights map[string]*Array) error { - Init() - - // Build the map - cMap := C.mlx_map_string_to_array_new() - defer C.mlx_map_string_to_array_free(cMap) - - for name, arr := range weights { - cName := C.CString(name) - C.mlx_map_string_to_array_insert(cMap, cName, arr.ctx) - C.free(unsafe.Pointer(cName)) - } - - // Empty metadata - cMeta := C.mlx_map_string_to_string_new() - defer C.mlx_map_string_to_string_free(cMeta) - - cPath := C.CString(path) - defer C.free(unsafe.Pointer(cPath)) - - rc := C.mlx_save_safetensors(cPath, cMap, cMeta) - if rc != 0 { - checkError() - return fmt.Errorf("mlx: save safetensors failed: %s", path) - } - return nil -} diff --git a/mlx/lora_test.go b/mlx/lora_test.go deleted file mode 100644 index 729dffa..0000000 --- a/mlx/lora_test.go +++ /dev/null @@ -1,316 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -import ( - "math" - "os" - "testing" -) - -func TestNewLoRALinear(t *testing.T) { - // Create a simple base linear layer: [4, 8] weight - w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - lora := NewLoRALinear(base, 4, 8.0) // rank=4, alpha=8 - - // Check dimensions - aShape := lora.A.Shape() - bShape := lora.B.Shape() - - if aShape[0] != 4 || aShape[1] != 8 { - t.Errorf("A shape = %v, want [4, 8]", aShape) - } - if bShape[0] != 4 || bShape[1] != 4 { - t.Errorf("B shape = %v, want [4, 4]", bShape) - } - - // Scale should be alpha/rank = 8/4 = 2 - if math.Abs(float64(lora.Scale)-2.0) > 1e-5 { - t.Errorf("Scale = %f, want 2.0", lora.Scale) - } - - // B should be all zeros (LoRA starts as identity) - Materialize(lora.B) - bFloats := lora.B.Floats() - for i, v := range bFloats { - if v != 0 { - t.Errorf("B[%d] = %f, want 0", i, v) - } - } -} - -func TestLoRALinear_ForwardMatchesBase(t *testing.T) { - // With B=0, LoRA forward should equal base forward - w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - lora := NewLoRALinear(base, 4, 8.0) - - // Random input [1, 3, 8] - x := RandomNormal(0, 1, []int32{1, 3, 8}, DTypeFloat32) - Materialize(x) - - baseOut := base.Forward(x) - loraOut := lora.Forward(x) - Materialize(baseOut, loraOut) - - // Should be identical since B is zero - baseFloats := 
baseOut.Floats() - loraFloats := loraOut.Floats() - - if len(baseFloats) != len(loraFloats) { - t.Fatalf("output sizes differ: base=%d, lora=%d", len(baseFloats), len(loraFloats)) - } - - for i := range baseFloats { - diff := math.Abs(float64(baseFloats[i] - loraFloats[i])) - if diff > 1e-4 { - t.Errorf("output[%d] differs: base=%f, lora=%f", i, baseFloats[i], loraFloats[i]) - } - } -} - -func TestLoRALinear_ForwardWithAdapter(t *testing.T) { - // Set A and B to known values and verify output changes - w := Zeros([]int32{4, 8}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - lora := NewLoRALinear(base, 2, 4.0) // rank=2, alpha=4, scale=2 - - // Set A to identity-like: [[1,0,0,...], [0,1,0,...]] - a := Zeros([]int32{2, 8}, DTypeFloat32) - // Set B to ones: [[1,1], [1,1], [1,1], [1,1]] - b := FromValues([]float32{ - 1, 1, - 1, 1, - 1, 1, - 1, 1, - }, 4, 2) - Materialize(a, b) - lora.A = a - lora.B = b - - // With base=0, A=0, output should also be 0 (scale * x@0@B^T = 0) - x := FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 8) - result := lora.Forward(x) - Materialize(result) - - // base(x) = 0 (zero weights), lora = scale * (x @ A^T) @ B^T - // A is zeros, so x @ A^T = [0, 0], then @ B^T = [0,0,0,0] - for _, v := range result.Floats() { - if v != 0 { - t.Errorf("expected 0 with zero A, got %f", v) - } - } -} - -func TestLoRALinear_ParamCount(t *testing.T) { - w := RandomNormal(0, 0.01, []int32{64, 128}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - lora := NewLoRALinear(base, 8, 16.0) // rank=8 - // A: [8, 128] = 1024, B: [64, 8] = 512, total = 1536 - expected := 8*128 + 64*8 - if lora.ParamCount() != expected { - t.Errorf("ParamCount = %d, want %d", lora.ParamCount(), expected) - } -} - -func TestLoRALinear_TrainableParams(t *testing.T) { - w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - lora := NewLoRALinear(base, 4, 8.0) - params := lora.TrainableParams() - - if len(params) != 2 { - t.Fatalf("TrainableParams returned %d arrays, want 2", len(params)) - } - - // First is A, second is B - if params[0].Shape()[0] != 4 || params[0].Shape()[1] != 8 { - t.Errorf("param[0] (A) shape = %v, want [4, 8]", params[0].Shape()) - } - if params[1].Shape()[0] != 4 || params[1].Shape()[1] != 4 { - t.Errorf("param[1] (B) shape = %v, want [4, 4]", params[1].Shape()) - } -} - -func TestLoRALinear_GradientFlows(t *testing.T) { - // Verify that gradients flow through the LoRA path - w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - lora := NewLoRALinear(base, 4, 8.0) - x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32) - Materialize(x) - - // Loss function: sum of LoRA output (differentiating w.r.t. A and B) - lossFn := func(inputs []*Array) []*Array { - lora.A = inputs[0] - lora.B = inputs[1] - out := lora.Forward(x) - return []*Array{SumAll(out)} - } - - grad := ValueAndGrad(lossFn, 0, 1) // grad w.r.t. A and B - defer grad.Free() - - values, grads, err := grad.Apply(lora.A, lora.B) - if err != nil { - t.Fatalf("ValueAndGrad failed: %v", err) - } - - Materialize(append(values, grads...)...) 
- - // Loss should be a scalar - loss := values[0].Float() - t.Logf("loss = %f", loss) - - // Gradients should be non-zero (A has random init, B is zero but gets grad) - gradA := grads[0] - gradB := grads[1] - - aGradFloats := gradA.Floats() - bGradFloats := gradB.Floats() - - hasNonZeroA := false - for _, v := range aGradFloats { - if v != 0 { - hasNonZeroA = true - break - } - } - - hasNonZeroB := false - for _, v := range bGradFloats { - if v != 0 { - hasNonZeroB = true - break - } - } - - // A gradient might be zero if B is zero (since dL/dA depends on B) - // But B gradient should be non-zero since A is random - if !hasNonZeroB { - t.Error("gradient for B is all zeros — gradients not flowing") - } - t.Logf("gradA has non-zero: %v, gradB has non-zero: %v", hasNonZeroA, hasNonZeroB) -} - -func TestRandomNormal(t *testing.T) { - arr := RandomNormal(0, 1, []int32{100}, DTypeFloat32) - Materialize(arr) - - floats := arr.Floats() - if len(floats) != 100 { - t.Fatalf("RandomNormal returned %d elements, want 100", len(floats)) - } - - // Check rough statistics: mean should be near 0, values should have spread - var sum float64 - for _, f := range floats { - sum += float64(f) - } - mean := sum / 100 - if math.Abs(mean) > 0.5 { // generous tolerance for 100 samples - t.Errorf("mean = %f, expected near 0", mean) - } -} - -func TestSaveSafetensors(t *testing.T) { - a := FromValues([]float32{1, 2, 3, 4}, 2, 2) - b := FromValues([]float32{5, 6, 7, 8, 9, 10}, 3, 2) - Materialize(a, b) - - path := t.TempDir() + "/test.safetensors" - err := SaveSafetensors(path, map[string]*Array{ - "layer.lora_a": a, - "layer.lora_b": b, - }) - if err != nil { - t.Fatalf("SaveSafetensors failed: %v", err) - } - - // Verify file exists - info, err := os.Stat(path) - if err != nil { - t.Fatalf("saved file not found: %v", err) - } - if info.Size() == 0 { - t.Error("saved file is empty") - } - - // Load it back - loaded := LoadAllSafetensors(path) - Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"]) - - aLoaded := loaded["layer.lora_a"].Floats() - bLoaded := loaded["layer.lora_b"].Floats() - - expectedA := []float32{1, 2, 3, 4} - expectedB := []float32{5, 6, 7, 8, 9, 10} - - for i, v := range expectedA { - if aLoaded[i] != v { - t.Errorf("loaded A[%d] = %f, want %f", i, aLoaded[i], v) - } - } - for i, v := range expectedB { - if bLoaded[i] != v { - t.Errorf("loaded B[%d] = %f, want %f", i, bLoaded[i], v) - } - } -} - -func TestLoRAAdapter_Save(t *testing.T) { - w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - adapter := &LoRAAdapter{ - Layers: map[string]*LoRALinear{ - "model.layers.0.self_attn.q_proj": NewLoRALinear(base, 4, 8.0), - }, - Config: DefaultLoRAConfig(), - } - - path := t.TempDir() + "/adapter.safetensors" - err := adapter.Save(path) - if err != nil { - t.Fatalf("Adapter.Save failed: %v", err) - } - - // Load and verify - loaded := LoadAllSafetensors(path) - aKey := "model.layers.0.self_attn.q_proj.lora_a" - bKey := "model.layers.0.self_attn.q_proj.lora_b" - - if _, ok := loaded[aKey]; !ok { - t.Errorf("missing key %s in saved adapter", aKey) - } - if _, ok := loaded[bKey]; !ok { - t.Errorf("missing key %s in saved adapter", bKey) - } -} - -func TestDefaultLoRAConfig(t *testing.T) { - cfg := DefaultLoRAConfig() - if cfg.Rank != 8 { - t.Errorf("Rank = %d, want 8", cfg.Rank) - } - if cfg.Alpha != 16 { - t.Errorf("Alpha = %f, want 16", cfg.Alpha) - } - if len(cfg.TargetKeys) != 2 { - t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", 
cfg.TargetKeys) - } -} diff --git a/mlx/mlx.go b/mlx/mlx.go deleted file mode 100644 index 470e1f7..0000000 --- a/mlx/mlx.go +++ /dev/null @@ -1,115 +0,0 @@ -//go:build darwin && arm64 - -// Package mlx provides Go bindings for Apple's MLX framework via mlx-c. -// -// Build mlx-c before use: -// -// cd pkg/mlx && go generate ./... -// -// Build (MLX is auto-enabled on darwin/arm64): -// -// go build -o core . -package mlx - -//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release -//go:generate cmake --build build --parallel -//go:generate cmake --install build - -/* -#cgo CXXFLAGS: -std=c++17 -#cgo CFLAGS: -mmacosx-version-min=26.0 -#cgo CPPFLAGS: -I${SRCDIR}/dist/include -#cgo LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx -#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate -#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib - -#include -#include -#include "mlx/c/mlx.h" - -static const char *last_mlx_error = NULL; - -static void mlx_go_error_handler(const char *msg, void *data) { - fprintf(stderr, "MLX ERROR: %s\n", msg); - last_mlx_error = msg; -} - -static void set_error_handler() { - mlx_set_error_handler(&mlx_go_error_handler, NULL, NULL); -} - -static const char* get_last_error() { - return last_mlx_error; -} -*/ -import "C" - -import ( - "log/slog" - "sync" -) - -var initOnce sync.Once - -// Init sets up the MLX error handler. Called automatically on first use. -func Init() { - initOnce.Do(func() { - C.set_error_handler() - slog.Debug("mlx: initialized with Metal backend") - }) -} - -// checkError logs the last MLX error if any occurred. -func checkError() { - if msg := C.get_last_error(); msg != nil { - slog.Error("mlx", "error", C.GoString(msg)) - } -} - -// Materialize synchronously evaluates arrays, computing their values on the GPU. -// This is the MLX equivalent of forcing lazy computation to complete. -func Materialize(outputs ...*Array) { - doMaterialize(outputs, false) -} - -// MaterializeAsync queues arrays for asynchronous GPU evaluation. -func MaterializeAsync(outputs ...*Array) { - doMaterialize(outputs, true) -} - -func doMaterialize(outputs []*Array, async bool) { - Init() - vector := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(vector) - - for _, output := range outputs { - if output != nil && output.Valid() { - C.mlx_vector_array_append_value(vector, output.ctx) - } - } - - if async { - C.mlx_async_eval(vector) - } else { - C.mlx_eval(vector) - } -} - -// Collect gathers all valid arrays from a variadic list for batch Materialize. -func Collect(arrays ...*Array) []*Array { - var out []*Array - for _, a := range arrays { - if a != nil && a.Valid() { - out = append(out, a) - } - } - return out -} - -// MetalAvailable reports whether Metal GPU is available. -func MetalAvailable() bool { - Init() - var available C.bool - C.mlx_metal_is_available(&available) - return bool(available) -} diff --git a/mlx/mlx_stub.go b/mlx/mlx_stub.go deleted file mode 100644 index 281c8cb..0000000 --- a/mlx/mlx_stub.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !(darwin && arm64) - -// Package mlx provides Go bindings for Apple's MLX framework via mlx-c. -// This stub file is used on non-darwin/non-arm64 platforms or when the -// mlx build tag is not set. All operations report MLX as unavailable. -package mlx - -// MetalAvailable reports whether Metal GPU is available. -// Always returns false on non-Apple Silicon platforms. 
-func MetalAvailable() bool { return false } diff --git a/mlx/model/gemma3.go b/mlx/model/gemma3.go deleted file mode 100644 index 849df4c..0000000 --- a/mlx/model/gemma3.go +++ /dev/null @@ -1,448 +0,0 @@ -//go:build darwin && arm64 - -// Package model provides transformer model architectures for MLX inference. -package model - -import ( - "encoding/json" - "fmt" - "log/slog" - "math" - "os" - "path/filepath" - - "forge.lthn.ai/core/go-ai/mlx" - "forge.lthn.ai/core/go-ai/mlx/cache" - "forge.lthn.ai/core/go-ai/mlx/tokenizer" -) - -// TextConfig holds Gemma 3 text model configuration. -type TextConfig struct { - HiddenSize int32 `json:"hidden_size"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - IntermediateSize int32 `json:"intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumKeyValueHeads int32 `json:"num_key_value_heads"` - HeadDim int32 `json:"head_dim"` - VocabSize int32 `json:"vocab_size"` - RMSNormEps float32 `json:"rms_norm_eps"` - RopeTheta float32 `json:"rope_theta"` - RopeLocalBaseFreq float32 `json:"rope_local_base_freq"` - MaxPositionEmbeddings int32 `json:"max_position_embeddings"` - SlidingWindow int32 `json:"sliding_window"` - SlidingWindowPattern int32 `json:"sliding_window_pattern"` - - Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level - Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim) -} - -// GemmaModel is the Gemma 3 text model. -type GemmaModel struct { - EmbedTokens *mlx.Embedding - Layers []*DecoderLayer - Norm *mlx.RMSNormModule - Output *mlx.Linear // Tied to EmbedTokens - - // Precomputed (1 + weight) for Gemma-style RMSNorm - NormScaled *mlx.Array - - Tok *tokenizer.Tokenizer - Cfg *TextConfig -} - -// DecoderLayer is a single transformer block. -type DecoderLayer struct { - InputNorm *mlx.RMSNormModule - Attention *Attention - PostAttnNorm *mlx.RMSNormModule - PreFFNorm *mlx.RMSNormModule - MLP *MLP - PostFFNorm *mlx.RMSNormModule - - // Precomputed scaled weights - InputNormScaled *mlx.Array - PostAttnNormScaled *mlx.Array - PreFFNormScaled *mlx.Array - PostFFNormScaled *mlx.Array - - IsSliding bool - LayerIdx int32 -} - -// Attention implements Gemma 3 attention with Q/K normalization. -type Attention struct { - QProj *mlx.Linear - KProj *mlx.Linear - VProj *mlx.Linear - OProj *mlx.Linear - QNorm *mlx.RMSNormModule - KNorm *mlx.RMSNormModule - - QNormScaled *mlx.Array - KNormScaled *mlx.Array -} - -// MLP is the feed-forward network. -type MLP struct { - GateProj *mlx.Linear - UpProj *mlx.Linear - DownProj *mlx.Linear -} - -// compiledGELU is a singleton for the compiled GELU function. -var compiledGELU *mlx.CompiledFunc - -func getCompiledGELU() *mlx.CompiledFunc { - if compiledGELU == nil { - compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array { - return []*mlx.Array{geluApprox(inputs[0])} - }, true) - } - return compiledGELU -} - -// geluApprox computes GELU using the tanh approximation: -// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) -func geluApprox(x *mlx.Array) *mlx.Array { - const sqrt2OverPi = 0.7978845608028654 - const coeff = 0.044715 - - x3 := mlx.Mul(mlx.Mul(x, x), x) - inner := mlx.Add(x, mlx.MulScalar(x3, coeff)) - scaled := mlx.MulScalar(inner, sqrt2OverPi) - t := mlx.Tanh(scaled) - onePlusT := mlx.AddScalar(t, 1.0) - return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT) -} - -// parseConfig handles both flat and nested (text_config) Gemma 3 configs. 
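// Two layouts are accepted (abridged, hypothetical values):
//
//	{"text_config": {"hidden_size": 2560, ...}, "quantization": {"bits": 4, "group_size": 64}}
//	{"hidden_size": 2560, "num_hidden_layers": 34, ...}
//
// In both cases the quantization block is read from the top level.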
-func parseConfig(data []byte) (*TextConfig, error) { - // Try parsing text_config from multimodal wrapper - var wrapper struct { - TextConfig TextConfig `json:"text_config"` - ModelType string `json:"model_type"` - Quantization *QuantizationConfig `json:"quantization"` - } - if err := json.Unmarshal(data, &wrapper); err != nil { - return nil, err - } - - cfg := wrapper.TextConfig - - // If text_config was empty, try top-level - if cfg.NumHiddenLayers == 0 { - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, err - } - } - - // Quantization is always top-level - cfg.Quantization = wrapper.Quantization - - // Compute scale (head_dim may be inferred later from weights if not in config) - if cfg.HeadDim > 0 { - cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) - } - if cfg.RopeTheta == 0 { - cfg.RopeTheta = 1000000 - } - if cfg.RopeLocalBaseFreq == 0 { - cfg.RopeLocalBaseFreq = 10000 - } - if cfg.RMSNormEps == 0 { - cfg.RMSNormEps = 1e-6 - } - if cfg.SlidingWindowPattern == 0 { - cfg.SlidingWindowPattern = 6 - } - if cfg.VocabSize == 0 { - cfg.VocabSize = 262208 // Gemma 3 default - } - - return &cfg, nil -} - -// LoadGemma3 loads a Gemma 3 text model from a directory. -func LoadGemma3(modelPath string) (*GemmaModel, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("gemma3: load config: %w", err) - } - - cfg, err := parseConfig(data) - if err != nil { - return nil, fmt.Errorf("gemma3: parse config: %w", err) - } - - // Load tokenizer - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) - if err != nil { - return nil, fmt.Errorf("gemma3: load tokenizer: %w", err) - } - - // Load weights from all safetensors files - weights := make(map[string]*mlx.Array) - matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors")) - for _, path := range matches { - for name, arr := range mlx.LoadSafetensors(path) { - weights[name] = arr - } - } - - // Helper to resolve weight with language_model. prefix fallback - w := func(name string) *mlx.Array { return resolveWeight(weights, name) } - - // Infer head_dim from q_proj weight shape when not in config. - // Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads. 
- if cfg.HeadDim == 0 { - qWeight := w("model.layers.0.self_attn.q_proj.weight") - if qWeight != nil { - qShape := qWeight.Shape() - if len(qShape) > 0 { - cfg.HeadDim = qShape[0] / cfg.NumAttentionHeads - cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) - slog.Info("mlx: inferred head_dim from q_proj weight", "head_dim", cfg.HeadDim) - } - } - } - - // Helper to create linear layer (quantized or dense) - q := cfg.Quantization - if q != nil { - slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) - } - linear := func(prefix string) *mlx.Linear { - weight := w(prefix + ".weight") - scales := w(prefix + ".scales") - biases := w(prefix + ".biases") - if scales != nil && q != nil { - return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits) - } - return mlx.NewLinear(weight, nil) - } - - // Create embedding (quantized or dense) - embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} - if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { - embed.Scales = embedScales - embed.Biases = w("model.embed_tokens.biases") - embed.GroupSize = q.GroupSize - embed.Bits = q.Bits - } - - m := &GemmaModel{ - EmbedTokens: embed, - Layers: make([]*DecoderLayer, cfg.NumHiddenLayers), - Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, - Tok: tok, - Cfg: cfg, - } - - // Initialize layers - for i := int32(0); i < cfg.NumHiddenLayers; i++ { - prefix := fmt.Sprintf("model.layers.%d", i) - m.Layers[i] = &DecoderLayer{ - InputNorm: &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")}, - PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")}, - PreFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")}, - PostFFNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")}, - Attention: &Attention{ - QProj: linear(prefix + ".self_attn.q_proj"), - KProj: linear(prefix + ".self_attn.k_proj"), - VProj: linear(prefix + ".self_attn.v_proj"), - OProj: linear(prefix + ".self_attn.o_proj"), - QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")}, - KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")}, - }, - MLP: &MLP{ - GateProj: linear(prefix + ".mlp.gate_proj"), - UpProj: linear(prefix + ".mlp.up_proj"), - DownProj: linear(prefix + ".mlp.down_proj"), - }, - LayerIdx: i, - IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern), - } - } - - // Output head — check for separate lm_head first, else tie to embeddings - lmHeadWeight := w("lm_head.weight") - if lmHeadWeight != nil { - lmHeadScales := w("lm_head.scales") - if lmHeadScales != nil && q != nil { - m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) - } else { - m.Output = mlx.NewLinear(lmHeadWeight, nil) - } - } else { - // Tied embeddings — reuse embed_tokens weights (with quantization if present) - m.Output = m.EmbedTokens.AsLinear() - } - - // Materialize all weights - var allArrays []*mlx.Array - for _, a := range weights { - allArrays = append(allArrays, a) - } - mlx.Materialize(allArrays...) 
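	// Gemma's RMSNorm convention scales the normalised activations by (1 + weight)
	// rather than by the raw weight, so the shifted tensors are computed once up
	// front and reused on every forward pass. A minimal sketch of the idea
	// (hypothetical values, generic fused kernel):
	//
	//	w      := mlx.FromValues([]float32{-0.1, 0.0, 0.2}, 3)
	//	scaled := mlx.AddScalar(w, 1.0)        // {0.9, 1.0, 1.2}
	//	y      := mlx.RMSNorm(x, scaled, eps)  // == x_hat * (1 + w)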
- - // Precompute (1 + weight) for Gemma-style RMSNorm - precomputeScaledWeights(m) - - return m, nil -} - -func precomputeScaledWeights(m *GemmaModel) { - m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0) - - for _, layer := range m.Layers { - layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0) - layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0) - layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0) - layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0) - layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0) - layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0) - } - - var scaled []*mlx.Array - scaled = append(scaled, m.NormScaled) - for _, layer := range m.Layers { - scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled, - layer.PreFFNormScaled, layer.PostFFNormScaled, - layer.Attention.QNormScaled, layer.Attention.KNormScaled) - } - mlx.Materialize(scaled...) -} - -func isLayerSliding(layerIdx, pattern int32) bool { - if pattern <= 0 { - return false - } - return (layerIdx+1)%pattern != 0 -} - -// Forward runs the text model forward pass. -func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - shape := tokens.Shape() - B, L := shape[0], shape[1] - - h := m.EmbedTokens.Forward(tokens) - h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize)))) - - for i, layer := range m.Layers { - h = layer.forward(h, caches[i], B, L, m.Cfg) - } - - return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)) -} - -func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array { - normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) - attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg) - attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) - h := mlx.Add(x, attnOut) - - normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) - mlpOut := l.MLP.forward(normed) - mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) - return mlx.Add(h, mlpOut) -} - -func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array { - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) - - // Reshape to [B, num_heads, L, head_dim] - q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - - // Q/K normalization - q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps) - k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps) - - // RoPE with appropriate theta - ropeTheta := cfg.RopeTheta - if isSliding { - ropeTheta = cfg.RopeLocalBaseFreq - } - q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) - k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) - - // Update cache - k, v = c.Update(k, v, int(L)) - - // GQA: repeat K/V heads - repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads - if repeatFactor > 1 { - k = 
mlx.RepeatKV(k, repeatFactor) - v = mlx.RepeatKV(v, repeatFactor) - } - - // Scaled dot-product attention - out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - return a.OProj.Forward(out) -} - -func (m *MLP) forward(x *mlx.Array) *mlx.Array { - gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0] - return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) -} - -// NewCache creates per-layer caches for generation. -func (m *GemmaModel) NewCache() []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) - for i := range caches { - if m.Layers[i].IsSliding { - caches[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow)) - } else { - caches[i] = cache.NewKVCache() - } - } - return caches -} - -// NumLayers returns the number of transformer layers. -func (m *GemmaModel) NumLayers() int { return len(m.Layers) } - -// Tokenizer returns the model's tokenizer. -func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok } - -// ModelType returns the architecture identifier. -func (m *GemmaModel) ModelType() string { return "gemma3" } - -// ApplyLoRA wraps target projection layers with LoRA adapters. -func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { - adapter := &mlx.LoRAAdapter{ - Layers: make(map[string]*mlx.LoRALinear), - Config: cfg, - } - - for i, layer := range m.Layers { - prefix := fmt.Sprintf("model.layers.%d.self_attn", i) - for _, target := range cfg.TargetKeys { - var proj *mlx.Linear - switch target { - case "q_proj": - proj = layer.Attention.QProj - case "k_proj": - proj = layer.Attention.KProj - case "v_proj": - proj = layer.Attention.VProj - case "o_proj": - proj = layer.Attention.OProj - } - if proj != nil { - lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha) - proj.LoRA = lora - adapter.Layers[prefix+"."+target] = lora - } - } - } - - return adapter -} diff --git a/mlx/model/model.go b/mlx/model/model.go deleted file mode 100644 index faf1faa..0000000 --- a/mlx/model/model.go +++ /dev/null @@ -1,78 +0,0 @@ -//go:build darwin && arm64 - -// Package model provides transformer model architectures for MLX inference. -package model - -import ( - "encoding/json" - "fmt" - "os" - "path/filepath" - - "forge.lthn.ai/core/go-ai/mlx" - "forge.lthn.ai/core/go-ai/mlx/cache" - "forge.lthn.ai/core/go-ai/mlx/tokenizer" -) - -// Model is the common interface for all transformer model architectures. -type Model interface { - // Forward runs the model forward pass on token IDs with KV caches. - Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array - - // NewCache creates per-layer KV caches for generation. - NewCache() []cache.Cache - - // NumLayers returns the number of transformer layers. - NumLayers() int - - // Tokenizer returns the model's tokenizer. - Tokenizer() *tokenizer.Tokenizer - - // ModelType returns the architecture identifier (e.g. "gemma3", "qwen3"). - ModelType() string - - // ApplyLoRA wraps target projection layers with LoRA adapters for training. - // Returns the adapter which holds references to all LoRA layers. - ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter -} - -// QuantizationConfig holds quantization parameters from config.json. -type QuantizationConfig struct { - GroupSize int `json:"group_size"` - Bits int `json:"bits"` -} - -// resolveWeight looks up a weight with optional "language_model." prefix. 
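// For example, a lookup for "model.norm.weight" also matches a checkpoint that
// stores it as "language_model.model.norm.weight", as multimodal Gemma exports do.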
-func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array { - if w, ok := weights[name]; ok { - return w - } - if w, ok := weights["language_model."+name]; ok { - return w - } - return nil -} - -// LoadModel auto-detects the model architecture from config.json and loads it. -func LoadModel(modelPath string) (Model, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("model: load config: %w", err) - } - - var probe struct { - ModelType string `json:"model_type"` - } - if err := json.Unmarshal(data, &probe); err != nil { - return nil, fmt.Errorf("model: parse model_type: %w", err) - } - - switch probe.ModelType { - case "qwen3": - return LoadQwen3(modelPath) - case "gemma3", "gemma2": - return LoadGemma3(modelPath) - default: - return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType) - } -} diff --git a/mlx/model/qwen3.go b/mlx/model/qwen3.go deleted file mode 100644 index 469f9f7..0000000 --- a/mlx/model/qwen3.go +++ /dev/null @@ -1,337 +0,0 @@ -//go:build darwin && arm64 - -package model - -import ( - "encoding/json" - "fmt" - "log/slog" - "math" - "os" - "path/filepath" - - "forge.lthn.ai/core/go-ai/mlx" - "forge.lthn.ai/core/go-ai/mlx/cache" - "forge.lthn.ai/core/go-ai/mlx/tokenizer" -) - -// Qwen3Config holds Qwen 3 model configuration. -type Qwen3Config struct { - HiddenSize int32 `json:"hidden_size"` - NumHiddenLayers int32 `json:"num_hidden_layers"` - IntermediateSize int32 `json:"intermediate_size"` - NumAttentionHeads int32 `json:"num_attention_heads"` - NumKeyValueHeads int32 `json:"num_key_value_heads"` - HeadDim int32 `json:"head_dim"` - VocabSize int32 `json:"vocab_size"` - RMSNormEps float32 `json:"rms_norm_eps"` - RopeTheta float32 `json:"rope_theta"` - MaxPositionEmbeddings int32 `json:"max_position_embeddings"` - - Quantization *QuantizationConfig `json:"-"` - Scale float32 `json:"-"` // 1/sqrt(head_dim) -} - -// Qwen3Model is the Qwen 3 text model. -type Qwen3Model struct { - EmbedTokens *mlx.Embedding - Layers []*Qwen3DecoderLayer - Norm *mlx.RMSNormModule - Output *mlx.Linear - - Tok *tokenizer.Tokenizer - Cfg *Qwen3Config -} - -// Qwen3DecoderLayer is a single transformer block. -// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add. -type Qwen3DecoderLayer struct { - InputNorm *mlx.RMSNormModule // Pre-attention norm - PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm) - Attention *Qwen3Attention - MLP *Qwen3MLP -} - -// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization. -type Qwen3Attention struct { - QProj *mlx.Linear - KProj *mlx.Linear - VProj *mlx.Linear - OProj *mlx.Linear - QNorm *mlx.RMSNormModule - KNorm *mlx.RMSNormModule -} - -// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)). 
-type Qwen3MLP struct { - GateProj *mlx.Linear - UpProj *mlx.Linear - DownProj *mlx.Linear -} - -func parseQwen3Config(data []byte) (*Qwen3Config, error) { - var cfg Qwen3Config - if err := json.Unmarshal(data, &cfg); err != nil { - return nil, err - } - - // Top-level quantization - var wrapper struct { - Quantization *QuantizationConfig `json:"quantization"` - } - json.Unmarshal(data, &wrapper) - cfg.Quantization = wrapper.Quantization - - // Compute scale - if cfg.HeadDim == 0 { - cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads - } - cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim))) - - // Defaults - if cfg.RopeTheta == 0 { - cfg.RopeTheta = 1000000 - } - if cfg.RMSNormEps == 0 { - cfg.RMSNormEps = 1e-6 - } - if cfg.VocabSize == 0 { - cfg.VocabSize = 151936 - } - - return &cfg, nil -} - -// LoadQwen3 loads a Qwen 3 model from a safetensors directory. -func LoadQwen3(modelPath string) (*Qwen3Model, error) { - data, err := os.ReadFile(filepath.Join(modelPath, "config.json")) - if err != nil { - return nil, fmt.Errorf("qwen3: load config: %w", err) - } - - cfg, err := parseQwen3Config(data) - if err != nil { - return nil, fmt.Errorf("qwen3: parse config: %w", err) - } - - tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json")) - if err != nil { - return nil, fmt.Errorf("qwen3: load tokenizer: %w", err) - } - - // Load weights from all safetensors files - weights := make(map[string]*mlx.Array) - matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors")) - for _, path := range matches { - for name, arr := range mlx.LoadSafetensors(path) { - weights[name] = arr - } - } - - w := func(name string) *mlx.Array { return resolveWeight(weights, name) } - - // Quantization setup - q := cfg.Quantization - if q != nil { - slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize) - } - linear := func(prefix string) *mlx.Linear { - weight := w(prefix + ".weight") - scales := w(prefix + ".scales") - biases := w(prefix + ".biases") - bias := w(prefix + ".bias") - if scales != nil && q != nil { - return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits) - } - return mlx.NewLinear(weight, bias) - } - - // Embedding - embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")} - if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil { - embed.Scales = embedScales - embed.Biases = w("model.embed_tokens.biases") - embed.GroupSize = q.GroupSize - embed.Bits = q.Bits - } - - m := &Qwen3Model{ - EmbedTokens: embed, - Layers: make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers), - Norm: &mlx.RMSNormModule{Weight: w("model.norm.weight")}, - Tok: tok, - Cfg: cfg, - } - - for i := int32(0); i < cfg.NumHiddenLayers; i++ { - p := fmt.Sprintf("model.layers.%d", i) - m.Layers[i] = &Qwen3DecoderLayer{ - InputNorm: &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")}, - PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")}, - Attention: &Qwen3Attention{ - QProj: linear(p + ".self_attn.q_proj"), - KProj: linear(p + ".self_attn.k_proj"), - VProj: linear(p + ".self_attn.v_proj"), - OProj: linear(p + ".self_attn.o_proj"), - QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")}, - KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")}, - }, - MLP: &Qwen3MLP{ - GateProj: linear(p + ".mlp.gate_proj"), - UpProj: linear(p + ".mlp.up_proj"), - DownProj: linear(p + ".mlp.down_proj"), - }, - } - } - - // Output head — Qwen 3 has 
tie_word_embeddings=false, so lm_head is separate - lmHeadWeight := w("lm_head.weight") - if lmHeadWeight != nil { - lmHeadScales := w("lm_head.scales") - if lmHeadScales != nil && q != nil { - m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits) - } else { - m.Output = mlx.NewLinear(lmHeadWeight, nil) - } - } else { - m.Output = m.EmbedTokens.AsLinear() - } - - // Materialise all weights onto Metal - var allArrays []*mlx.Array - for _, a := range weights { - allArrays = append(allArrays, a) - } - mlx.Materialize(allArrays...) - - slog.Info("qwen3: model loaded", - "layers", cfg.NumHiddenLayers, - "hidden", cfg.HiddenSize, - "heads", cfg.NumAttentionHeads, - "kv_heads", cfg.NumKeyValueHeads, - "head_dim", cfg.HeadDim, - "vocab", cfg.VocabSize, - ) - - return m, nil -} - -// Forward runs the Qwen 3 forward pass. -// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size). -func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array { - shape := tokens.Shape() - B, L := shape[0], shape[1] - - h := m.EmbedTokens.Forward(tokens) - - for i, layer := range m.Layers { - h = layer.forward(h, caches[i], B, L, m.Cfg) - } - - return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps)) -} - -func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { - // Pre-attention norm → attention → residual add - normed := l.InputNorm.Forward(x, cfg.RMSNormEps) - attnOut := l.Attention.forward(normed, c, B, L, cfg) - h := mlx.Add(x, attnOut) - - // Pre-MLP norm → MLP → residual add - normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps) - mlpOut := l.MLP.forward(normed) - return mlx.Add(h, mlpOut) -} - -func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array { - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) - - // Reshape to [B, num_heads, L, head_dim] - q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, - []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - - // Q/K RMS normalization (Qwen 3 has this) - q = a.QNorm.Forward(q, cfg.RMSNormEps) - k = a.KNorm.Forward(k, cfg.RMSNormEps) - - // RoPE — single theta for all layers (no sliding window) - q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) - k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) - - // Update KV cache - k, v = c.Update(k, v, int(L)) - - // GQA: repeat K/V heads to match Q heads - repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads - if repeatFactor > 1 { - k = mlx.RepeatKV(k, repeatFactor) - v = mlx.RepeatKV(v, repeatFactor) - } - - // Scaled dot-product attention - out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) - out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - return a.OProj.Forward(out) -} - -// forward computes SwiGLU: down(silu(gate(x)) * up(x)). 
-func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array { - gate := mlx.SiLU(m.GateProj.Forward(x)) - return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x))) -} - -// NewCache creates per-layer KV caches. Qwen 3 uses global attention only. -func (m *Qwen3Model) NewCache() []cache.Cache { - caches := make([]cache.Cache, len(m.Layers)) - for i := range caches { - caches[i] = cache.NewKVCache() - } - return caches -} - -// NumLayers returns the number of transformer layers. -func (m *Qwen3Model) NumLayers() int { return len(m.Layers) } - -// Tokenizer returns the model's tokenizer. -func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok } - -// ModelType returns the architecture identifier. -func (m *Qwen3Model) ModelType() string { return "qwen3" } - -// ApplyLoRA wraps target projection layers with LoRA adapters. -func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter { - adapter := &mlx.LoRAAdapter{ - Layers: make(map[string]*mlx.LoRALinear), - Config: cfg, - } - - for i, layer := range m.Layers { - prefix := fmt.Sprintf("model.layers.%d.self_attn", i) - for _, target := range cfg.TargetKeys { - var proj *mlx.Linear - switch target { - case "q_proj": - proj = layer.Attention.QProj - case "k_proj": - proj = layer.Attention.KProj - case "v_proj": - proj = layer.Attention.VProj - case "o_proj": - proj = layer.Attention.OProj - } - if proj != nil { - lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha) - proj.LoRA = lora - adapter.Layers[prefix+"."+target] = lora - } - } - } - - return adapter -} diff --git a/mlx/nn.go b/mlx/nn.go deleted file mode 100644 index 4de9db9..0000000 --- a/mlx/nn.go +++ /dev/null @@ -1,115 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -// Linear is a fully-connected layer: y = x @ W.T + bias. -// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul. -// Set LoRA to inject a low-rank adapter (training only). -type Linear struct { - Weight *Array `weight:"weight"` - Scales *Array `weight:"scales"` - Biases *Array `weight:"biases"` - Bias *Array `weight:"bias"` - GroupSize int - Bits int - - LoRA *LoRALinear // Optional LoRA adapter — if set, Forward routes through it -} - -// NewLinear creates a dense Linear layer with optional bias. -func NewLinear(weight, bias *Array) *Linear { - return &Linear{Weight: weight, Bias: bias} -} - -// NewQuantizedLinear creates a quantized Linear layer. -func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear { - return &Linear{ - Weight: weight, - Scales: scales, - Biases: biases, - Bias: bias, - GroupSize: groupSize, - Bits: bits, - } -} - -// Forward computes the linear transformation. -// If a LoRA adapter is attached, routes through it instead (base + low-rank delta). -// Uses QuantizedMatmul when quantization parameters are present. -func (l *Linear) Forward(x *Array) *Array { - if l.LoRA != nil { - return l.LoRA.Forward(x) - } - return l.baseForward(x) -} - -// baseForward is the raw linear transformation without LoRA. -// Used internally by LoRALinear to avoid infinite recursion. -func (l *Linear) baseForward(x *Array) *Array { - var out *Array - if l.Scales != nil { - out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits) - } else { - out = Matmul(x, Transpose(l.Weight)) - } - if l.Bias != nil && l.Bias.Valid() { - out = Add(out, l.Bias) - } - return out -} - -// Embedding is a lookup table for token embeddings. -// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup. 
-type Embedding struct { - Weight *Array `weight:"weight"` - Scales *Array `weight:"scales"` - Biases *Array `weight:"biases"` - GroupSize int - Bits int -} - -// Forward looks up embeddings for the given token indices. -func (e *Embedding) Forward(indices *Array) *Array { - if e.Scales != nil { - w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits) - return Take(w, indices, 0) - } - return Take(e.Weight, indices, 0) -} - -// AsLinear returns a Linear layer using the embedding weights (for tied output). -func (e *Embedding) AsLinear() *Linear { - return &Linear{ - Weight: e.Weight, - Scales: e.Scales, - Biases: e.Biases, - GroupSize: e.GroupSize, - Bits: e.Bits, - } -} - -// RMSNormModule is an RMS normalization layer wrapping the fused kernel. -type RMSNormModule struct { - Weight *Array `weight:"weight"` -} - -// Forward applies RMS normalization. -func (r *RMSNormModule) Forward(x *Array, eps float32) *Array { - return RMSNorm(x, r.Weight, eps) -} - -// RepeatKV repeats key/value heads for grouped-query attention. -// Input shape: [B, num_kv_heads, L, D] -// Output shape: [B, num_kv_heads * factor, L, D] -func RepeatKV(x *Array, factor int32) *Array { - if factor <= 1 { - return x - } - shape := x.Shape() - B, H, L, D := shape[0], shape[1], shape[2], shape[3] - - // Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D] - expanded := ExpandDims(x, 2) - expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D}) - return Reshape(expanded, B, H*factor, L, D) -} diff --git a/mlx/ops.go b/mlx/ops.go deleted file mode 100644 index 99f953b..0000000 --- a/mlx/ops.go +++ /dev/null @@ -1,353 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include -#include "mlx/c/mlx.h" -*/ -import "C" - -import "unsafe" - -// --- Element-wise arithmetic --- - -// Add returns element-wise a + b. -func Add(a, b *Array) *Array { - out := New("ADD", a, b) - C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// AddScalar returns a + scalar (broadcast). -func AddScalar(a *Array, s float32) *Array { - scalar := FromValue(s) - return Add(a, scalar) -} - -// Mul returns element-wise a * b. -func Mul(a, b *Array) *Array { - out := New("MUL", a, b) - C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// MulScalar returns a * scalar (broadcast). -func MulScalar(a *Array, s float32) *Array { - scalar := FromValue(s) - return Mul(a, scalar) -} - -// Divide returns element-wise a / b. -func Divide(a, b *Array) *Array { - out := New("DIV", a, b) - C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// Subtract returns element-wise a - b. -func Subtract(a, b *Array) *Array { - out := New("SUB", a, b) - C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// Negative returns element-wise -a. -func Negative(a *Array) *Array { - out := New("NEG", a) - C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// --- Math functions --- - -// Exp returns element-wise exp(a). -func Exp(a *Array) *Array { - out := New("EXP", a) - C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// Sigmoid returns element-wise 1/(1+exp(-a)). -func Sigmoid(a *Array) *Array { - out := New("SIGMOID", a) - C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// SiLU returns element-wise x * sigmoid(x) (Swish activation). -func SiLU(a *Array) *Array { - return Mul(a, Sigmoid(a)) -} - -// Tanh returns element-wise tanh(a). 
-func Tanh(a *Array) *Array { - out := New("TANH", a) - C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// Sqrt returns element-wise sqrt(a). -func Sqrt(a *Array) *Array { - out := New("SQRT", a) - C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// Rsqrt returns element-wise 1/sqrt(a). -func Rsqrt(a *Array) *Array { - out := New("RSQRT", a) - C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// Reciprocal returns element-wise 1/a. -func Reciprocal(a *Array) *Array { - out := New("RECIPROCAL", a) - C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// Square returns element-wise a^2. -func Square(a *Array) *Array { - out := New("SQUARE", a) - C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx) - return out -} - -// Power returns element-wise a^b. -func Power(a, b *Array) *Array { - out := New("POWER", a, b) - C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// Maximum returns element-wise max(a, b). -func Maximum(a, b *Array) *Array { - out := New("MAX", a, b) - C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// Minimum returns element-wise min(a, b). -func Minimum(a, b *Array) *Array { - out := New("MIN", a, b) - C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// --- Matrix operations --- - -// Matmul returns the matrix product of a and b. -func Matmul(a, b *Array) *Array { - out := New("MATMUL", a, b) - C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// QuantizedMatmul performs quantized matrix multiplication. -func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array { - out := New("QMATMUL", x, w, scales, biases) - gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)} - b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)} - mode := C.CString("affine") - defer C.free(unsafe.Pointer(mode)) - C.mlx_quantized_matmul( - &out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx, - C._Bool(transpose), gs, b, mode, - DefaultStream().ctx, - ) - return out -} - -// --- Reductions --- - -// Softmax returns softmax along the last axis. -func Softmax(a *Array) *Array { - out := New("SOFTMAX", a) - axis := []C.int{C.int(-1)} - C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx) - return out -} - -// Argmax returns the index of the maximum value along an axis. -func Argmax(a *Array, axis int, keepDims bool) *Array { - out := New("ARGMAX", a) - C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) - return out -} - -// TopK returns the top k values along the last axis. -func TopK(a *Array, k int) *Array { - out := New("TOPK", a) - C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx) - return out -} - -// Sum reduces by summation along the given axis. -func Sum(a *Array, axis int, keepDims bool) *Array { - out := New("SUM", a) - axes := []C.int{C.int(axis)} - C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx) - return out -} - -// Mean reduces by averaging along the given axis. -func Mean(a *Array, axis int, keepDims bool) *Array { - out := New("MEAN", a) - axes := []C.int{C.int(axis)} - C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx) - return out -} - -// --- Shape operations --- - -// Reshape changes the shape of an array. 
-func Reshape(a *Array, shape ...int32) *Array { - out := New("RESHAPE", a) - cShape := make([]C.int, len(shape)) - for i, s := range shape { - cShape[i] = C.int(s) - } - C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx) - return out -} - -// Transpose permutes dimensions. If no axes given, reverses all dims. -func Transpose(a *Array, axes ...int) *Array { - out := New("TRANSPOSE", a) - if len(axes) == 0 { - C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx) - } else { - cAxes := make([]C.int, len(axes)) - for i, ax := range axes { - cAxes[i] = C.int(ax) - } - C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx) - } - return out -} - -// ExpandDims inserts a new axis at the given position. -func ExpandDims(a *Array, axis int) *Array { - out := New("EXPAND_DIMS", a) - C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx) - return out -} - -// Squeeze removes dimensions of size 1. -func Squeeze(a *Array, axes ...int) *Array { - out := New("SQUEEZE", a) - cAxes := make([]C.int, len(axes)) - for i, ax := range axes { - cAxes[i] = C.int(ax) - } - C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx) - return out -} - -// Concatenate joins arrays along the given axis. -func Concatenate(arrays []*Array, axis int) *Array { - vector := C.mlx_vector_array_new() - defer C.mlx_vector_array_free(vector) - - inputs := make([]*Array, len(arrays)) - for i, a := range arrays { - C.mlx_vector_array_append_value(vector, a.ctx) - inputs[i] = a - } - - out := New("CONCAT", inputs...) - C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx) - return out -} - -// BroadcastTo broadcasts an array to the given shape. -func BroadcastTo(a *Array, shape []int32) *Array { - out := New("BROADCAST", a) - cShape := make([]C.int, len(shape)) - for i, s := range shape { - cShape[i] = C.int(s) - } - C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx) - return out -} - -// AsType casts an array to a different dtype. -func AsType(a *Array, dtype DType) *Array { - out := New("ASTYPE", a) - C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx) - return out -} - -// AsStrided creates a view with custom strides. -func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array { - out := New("AS_STRIDED", a) - cShape := make([]C.int, len(shape)) - for i, s := range shape { - cShape[i] = C.int(s) - } - cStrides := make([]C.int64_t, len(strides)) - for i, s := range strides { - cStrides[i] = C.int64_t(s) - } - C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx) - return out -} - -// Take gathers elements from a along axis using indices. -func Take(a, indices *Array, axis int) *Array { - out := New("TAKE", a, indices) - C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) - return out -} - -// Where selects elements from a or b based on condition. -func Where(condition, a, b *Array) *Array { - out := New("WHERE", condition, a, b) - C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx) - return out -} - -// Argpartition partially sorts and returns indices for top-k selection. 
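// For example, Argpartition(Negative(logits), k-1, -1) leaves the indices of the
// k largest logits in the first k positions of the last axis (order within each
// partition is unspecified); this is how the top-k sampler uses it.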
-func Argpartition(a *Array, kth, axis int) *Array { - out := New("ARGPARTITION", a) - C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx) - return out -} - -// Dequantize restores a quantized array to full precision. -func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array { - out := New("DEQUANTIZE", w, scales, biases) - gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)} - b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)} - mode := C.CString("affine") - defer C.free(unsafe.Pointer(mode)) - noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)} - C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx) - return out -} - -// PutAlongAxis places values into array at indices along axis. -func PutAlongAxis(a, indices, values *Array, axis int) *Array { - out := New("PUT_ALONG_AXIS", a, indices, values) - // Use scatter approach: src[indices] = values - C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx) - return out -} - -// TakeAlongAxis gathers elements from a along axis using indices. -// Unlike Take, this uses the same number of dimensions for indices and input. -func TakeAlongAxis(a, indices *Array, axis int) *Array { - out := New("TAKE_ALONG_AXIS", a, indices) - C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx) - return out -} - -// LogSumExp computes log(sum(exp(a))) along the given axis. -// Numerically stable reduction for cross-entropy loss. -func LogSumExp(a *Array, axis int, keepDims bool) *Array { - out := New("LOGSUMEXP", a) - C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx) - return out -} diff --git a/mlx/optim.go b/mlx/optim.go deleted file mode 100644 index f54bcf5..0000000 --- a/mlx/optim.go +++ /dev/null @@ -1,106 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -import "math" - -// AdamW implements the AdamW optimiser (Adam with decoupled weight decay). -// -// Update rule per parameter: -// -// m = beta1 * m + (1 - beta1) * grad -// v = beta2 * v + (1 - beta2) * grad^2 -// m_hat = m / (1 - beta1^t) -// v_hat = v / (1 - beta2^t) -// param = param * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps) -type AdamW struct { - LR float64 // Learning rate (default 1e-5) - Beta1 float64 // First moment decay (default 0.9) - Beta2 float64 // Second moment decay (default 0.999) - Eps float64 // Numerical stability (default 1e-8) - WeightDecay float64 // Decoupled weight decay (default 0.01) - - step int // Number of updates performed - m []*Array // First moment estimates (positional, parallel to params) - v []*Array // Second moment estimates (positional, parallel to params) -} - -// NewAdamW creates an AdamW optimiser with default hyperparameters. -func NewAdamW(lr float64) *AdamW { - return &AdamW{ - LR: lr, - Beta1: 0.9, - Beta2: 0.999, - Eps: 1e-8, - WeightDecay: 0.01, - } -} - -// Step performs one optimisation step: updates params using gradients. -// params and grads must be parallel slices of the same length. -// Returns the updated parameter arrays (params are replaced in-place). 
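// A minimal usage sketch, mirroring the tests below (valueAndGrad and numSteps
// are placeholders):
//
//	opt := NewAdamW(1e-4)
//	for i := 0; i < numSteps; i++ {
//		_, grads, err := valueAndGrad.Apply(params...)
//		if err != nil {
//			return err
//		}
//		params = opt.Step(params, grads)
//		Materialize(params...)
//	}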
-func (o *AdamW) Step(params []*Array, grads []*Array) []*Array { - o.step++ - - // Bias correction factors - bc1 := 1.0 - math.Pow(o.Beta1, float64(o.step)) - bc2 := 1.0 - math.Pow(o.Beta2, float64(o.step)) - - updated := make([]*Array, len(params)) - - // Grow moment slices if needed (first call or param count increased) - for len(o.m) < len(params) { - o.m = append(o.m, nil) - o.v = append(o.v, nil) - } - - for i, param := range params { - grad := grads[i] - - // Initialise moments on first use - if o.m[i] == nil { - shape := param.Shape() - o.m[i] = Zeros(shape, param.Dtype()) - o.v[i] = Zeros(shape, param.Dtype()) - } - - // m = beta1 * m + (1 - beta1) * grad - m := Add( - MulScalar(o.m[i], float32(o.Beta1)), - MulScalar(grad, float32(1.0-o.Beta1)), - ) - - // v = beta2 * v + (1 - beta2) * grad^2 - v := Add( - MulScalar(o.v[i], float32(o.Beta2)), - MulScalar(Square(grad), float32(1.0-o.Beta2)), - ) - - // Bias-corrected estimates - mHat := MulScalar(m, float32(1.0/bc1)) - vHat := MulScalar(v, float32(1.0/bc2)) - - // Weight decay: param = param * (1 - lr * weight_decay) - decayed := MulScalar(param, float32(1.0-o.LR*o.WeightDecay)) - - // Update: param = decayed - lr * m_hat / (sqrt(v_hat) + eps) - denom := AddScalar(Sqrt(vHat), float32(o.Eps)) - step := MulScalar(Divide(mHat, denom), float32(o.LR)) - newParam := Subtract(decayed, step) - - // Store updated moments - o.m[i] = m - o.v[i] = v - - updated[i] = newParam - } - - return updated -} - -// Reset clears the optimiser state (moments and step counter). -func (o *AdamW) Reset() { - o.step = 0 - o.m = nil - o.v = nil -} diff --git a/mlx/optim_test.go b/mlx/optim_test.go deleted file mode 100644 index b2df911..0000000 --- a/mlx/optim_test.go +++ /dev/null @@ -1,175 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -import ( - "math" - "testing" -) - -func TestAdamW_BasicStep(t *testing.T) { - // Simple test: minimise f(x) = x^2, starting at x=10 - x := FromValue(float32(10.0)) - Materialize(x) - - opt := NewAdamW(0.1) - - for i := 0; i < 300; i++ { - // Gradient of x^2 is 2x - lossFn := func(inputs []*Array) []*Array { - p := inputs[0] - return []*Array{Mul(p, p)} - } - - grad := ValueAndGrad(lossFn) - _, grads, err := grad.Apply(x) - grad.Free() - if err != nil { - t.Fatalf("step %d: grad failed: %v", i, err) - } - - updated := opt.Step([]*Array{x}, grads) - x = updated[0] - Materialize(x) - } - - final := x.Float() - if math.Abs(final) > 0.5 { - t.Errorf("after 300 steps, x = %f, want near 0", final) - } - t.Logf("final x = %f (started at 10.0)", final) -} - -func TestAdamW_MultiParam(t *testing.T) { - // Minimise f(x, y) = x^2 + y^2 - x := FromValue(float32(5.0)) - y := FromValue(float32(-3.0)) - Materialize(x, y) - - opt := NewAdamW(0.1) - - for i := 0; i < 100; i++ { - lossFn := func(inputs []*Array) []*Array { - return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))} - } - - grad := ValueAndGrad(lossFn, 0, 1) - _, grads, err := grad.Apply(x, y) - grad.Free() - if err != nil { - t.Fatalf("step %d failed: %v", i, err) - } - - updated := opt.Step([]*Array{x, y}, grads) - x = updated[0] - y = updated[1] - Materialize(x, y) - } - - xFinal := x.Float() - yFinal := y.Float() - if math.Abs(xFinal) > 0.1 || math.Abs(yFinal) > 0.1 { - t.Errorf("x=%f, y=%f, want both near 0", xFinal, yFinal) - } - t.Logf("final x=%f, y=%f", xFinal, yFinal) -} - -func TestAdamW_WeightDecay(t *testing.T) { - // With large weight decay and zero gradient, param should decay toward 0 - x := FromValue(float32(10.0)) - Materialize(x) - - 
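	// With a zero gradient both moment estimates stay at zero, so each step reduces
	// to the decoupled decay alone: param *= (1 - lr*weight_decay). For the values
	// used below (lr=0.01, decay=0.5) the rough expectation is
	//
	//	10 * (1 - 0.01*0.5)^10 = 10 * 0.995^10 ≈ 9.51
	//
	// which is what the (0, 10) assertions at the end of this test allow for.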
opt := NewAdamW(0.01) - opt.WeightDecay = 0.5 // aggressive decay - - zeroGrad := FromValue(float32(0.0)) - Materialize(zeroGrad) - - for i := 0; i < 10; i++ { - updated := opt.Step([]*Array{x}, []*Array{zeroGrad}) - x = updated[0] - Materialize(x) - } - - final := x.Float() - if final >= 10.0 { - t.Errorf("x = %f, should have decayed from 10.0", final) - } - if final <= 0 { - t.Errorf("x = %f, decayed too much", final) - } - t.Logf("after 10 steps with weight_decay=0.5: x = %f (started at 10.0)", final) -} - -func TestAdamW_Reset(t *testing.T) { - opt := NewAdamW(0.01) - - x := FromValue(float32(5.0)) - grad := FromValue(float32(1.0)) - Materialize(x, grad) - - opt.Step([]*Array{x}, []*Array{grad}) - if opt.step != 1 { - t.Errorf("step = %d, want 1", opt.step) - } - - opt.Reset() - if opt.step != 0 { - t.Errorf("after reset, step = %d, want 0", opt.step) - } - if opt.m != nil { - t.Error("after reset, moments should be nil") - } -} - -func TestAdamW_WithLoRA(t *testing.T) { - // End-to-end: create LoRA layer, compute gradients, update with AdamW - w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32) - Materialize(w) - base := NewLinear(w, nil) - - lora := NewLoRALinear(base, 4, 8.0) - opt := NewAdamW(0.001) - - x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32) - target := RandomNormal(0, 1, []int32{1, 2, 4}, DTypeFloat32) - Materialize(x, target) - - var initialLoss, finalLoss float64 - - for step := 0; step < 50; step++ { - lossFn := func(inputs []*Array) []*Array { - lora.A = inputs[0] - lora.B = inputs[1] - pred := lora.Forward(x) - return []*Array{MSELoss(pred, target)} - } - - grad := ValueAndGrad(lossFn, 0, 1) - values, grads, err := grad.Apply(lora.A, lora.B) - grad.Free() - if err != nil { - t.Fatalf("step %d failed: %v", step, err) - } - - Materialize(append(values, grads...)...) - - loss := values[0].Float() - if step == 0 { - initialLoss = loss - } - if step == 49 { - finalLoss = loss - } - - updated := opt.Step([]*Array{lora.A, lora.B}, grads) - lora.A = updated[0] - lora.B = updated[1] - Materialize(lora.A, lora.B) - } - - t.Logf("loss: %.6f -> %.6f", initialLoss, finalLoss) - if finalLoss >= initialLoss { - t.Errorf("loss did not decrease: %f -> %f", initialLoss, finalLoss) - } -} diff --git a/mlx/random.go b/mlx/random.go deleted file mode 100644 index f7e09b2..0000000 --- a/mlx/random.go +++ /dev/null @@ -1,46 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include "mlx/c/mlx.h" -*/ -import "C" - -// RandomCategorical samples from a categorical distribution defined by logprobs. -// Returns indices sampled according to the log-probability distribution along the last axis. -func RandomCategorical(logprobs *Array) *Array { - out := New("RANDOM_CATEGORICAL", logprobs) - key := C.mlx_array_new() - defer C.mlx_array_free(key) - C.mlx_random_categorical( - &out.ctx, - logprobs.ctx, - C.int(-1), // axis - key, // null key = use default RNG - DefaultStream().ctx, - ) - return out -} - -// RandomUniform generates uniform random values in [low, high). 
-func RandomUniform(low, high float32, shape []int32, dtype DType) *Array { - out := New("RANDOM_UNIFORM") - cShape := make([]C.int, len(shape)) - for i, s := range shape { - cShape[i] = C.int(s) - } - lo := FromValue(low) - hi := FromValue(high) - key := C.mlx_array_new() - defer C.mlx_array_free(key) - C.mlx_random_uniform( - &out.ctx, - lo.ctx, hi.ctx, - &cShape[0], C.size_t(len(cShape)), - C.mlx_dtype(dtype), - key, - DefaultStream().ctx, - ) - return out -} diff --git a/mlx/sample/sample.go b/mlx/sample/sample.go deleted file mode 100644 index 8a49c04..0000000 --- a/mlx/sample/sample.go +++ /dev/null @@ -1,90 +0,0 @@ -//go:build darwin && arm64 - -// Package sample provides composable token sampling strategies. -package sample - -import ( - "math" - - "forge.lthn.ai/core/go-ai/mlx" -) - -// Sampler transforms logits into a sampled token index. -type Sampler interface { - Sample(logits *mlx.Array) *mlx.Array -} - -// New creates a composable sampler chain from the given parameters. -// Order: TopP -> MinP -> TopK -> Temperature -> categorical sample. -func New(temp, topP, minP float32, topK int) Sampler { - if temp == 0 { - return greedy{} - } - - var samplers []Sampler - if topP > 0 && topP < 1 { - samplers = append(samplers, TopP(topP)) - } - if minP > 0 { - samplers = append(samplers, MinPSampler(minP)) - } - if topK > 0 { - samplers = append(samplers, TopKSampler(topK)) - } - samplers = append(samplers, Temperature(temp)) - return chain(samplers) -} - -// chain applies a sequence of samplers, then samples from the result. -type chain []Sampler - -func (c chain) Sample(logits *mlx.Array) *mlx.Array { - for _, s := range c { - logits = s.Sample(logits) - } - // Final categorical sample from log-probabilities - return mlx.RandomCategorical(logits) -} - -// greedy returns the argmax token. -type greedy struct{} - -func (greedy) Sample(logits *mlx.Array) *mlx.Array { - return mlx.Argmax(logits, -1, false) -} - -// Temperature scales logits by 1/temp. -type Temperature float32 - -func (t Temperature) Sample(logits *mlx.Array) *mlx.Array { - return mlx.MulScalar(logits, 1.0/float32(t)) -} - -// TopKSampler masks all but the top-k logits. -type TopKSampler int - -func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array { - neg := mlx.Negative(logits) - mask := mlx.Argpartition(neg, int(k)-1, -1) - // Slice the indices beyond top-k - mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1))) - return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1) -} - -// TopP implements nucleus sampling (cumulative probability threshold). -type TopP float32 - -func (p TopP) Sample(logits *mlx.Array) *mlx.Array { - // TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly. - // For now, pass through. TopK + Temperature covers most use cases. - return logits -} - -// MinPSampler masks tokens below min_p * max_prob. -type MinPSampler float32 - -func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array { - // For now, pass through — MinP is an optimization over TopP. - // Full implementation requires finding max prob and masking below threshold. - return logits -} diff --git a/mlx/slice.go b/mlx/slice.go deleted file mode 100644 index 5bb7a66..0000000 --- a/mlx/slice.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include "mlx/c/mlx.h" -*/ -import "C" - -// Slice extracts a sub-array using start and end indices for each dimension. -// starts and ends must have the same length as the array's dimensions. 
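// For example, Slice(a, []int32{0, 0}, []int32{2, 3}) takes the leading 2x3 block
// of a 2-D array; strides are fixed at 1.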
-func Slice(a *Array, starts, ends []int32) *Array { - out := New("SLICE", a) - cStarts := make([]C.int, len(starts)) - cEnds := make([]C.int, len(ends)) - for i := range starts { - cStarts[i] = C.int(starts[i]) - cEnds[i] = C.int(ends[i]) - } - strides := make([]C.int, len(starts)) - for i := range strides { - strides[i] = 1 - } - C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx) - return out -} - -// SliceAxis extracts a sub-array along a single axis. -func SliceAxis(a *Array, axis int, start, end int32) *Array { - // Build full slice parameters - ndim := a.NumDims() - starts := make([]int32, ndim) - ends := make([]int32, ndim) - for i := 0; i < ndim; i++ { - starts[i] = 0 - ends[i] = int32(a.Dim(i)) - } - ax := axis - if ax < 0 { - ax = ndim + ax - } - starts[ax] = start - ends[ax] = end - return Slice(a, starts, ends) -} - -// SliceUpdateInplace updates a slice of the array in-place. -// This is critical for KV cache updates. -func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array { - out := New("SLICE_UPDATE", a, update) - cStarts := make([]C.int, len(starts)) - cEnds := make([]C.int, len(ends)) - for i := range starts { - cStarts[i] = C.int(starts[i]) - cEnds[i] = C.int(ends[i]) - } - strides := make([]C.int, len(starts)) - for i := range strides { - strides[i] = 1 - } - C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx) - return out -} diff --git a/mlx/stream.go b/mlx/stream.go deleted file mode 100644 index 248a224..0000000 --- a/mlx/stream.go +++ /dev/null @@ -1,79 +0,0 @@ -//go:build darwin && arm64 - -package mlx - -/* -#include "mlx/c/mlx.h" -*/ -import "C" - -import "sync" - -// Stream wraps an mlx_stream handle for dispatching operations. -type Stream struct { - ctx C.mlx_stream -} - -var ( - defaultStream *Stream - defaultStreamOnce sync.Once -) - -// DefaultStream returns the default GPU stream, creating it on first use. -func DefaultStream() *Stream { - defaultStreamOnce.Do(func() { - Init() - defaultStream = &Stream{ctx: C.mlx_default_gpu_stream_new()} - }) - return defaultStream -} - -// DefaultGPUStream returns a new GPU stream. -func DefaultGPUStream() *Stream { - Init() - return &Stream{ctx: C.mlx_default_gpu_stream_new()} -} - -// DefaultCPUStream returns a new CPU stream. -func DefaultCPUStream() *Stream { - Init() - return &Stream{ctx: C.mlx_default_cpu_stream_new()} -} - -// Synchronize waits for all operations on the stream to complete. -func Synchronize(s *Stream) { - C.mlx_synchronize(s.ctx) -} - -// SetMemoryLimit sets the Metal memory limit. Returns the previous limit. -func SetMemoryLimit(limit uint64) uint64 { - var prev C.size_t - C.mlx_set_memory_limit(&prev, C.size_t(limit)) - return uint64(prev) -} - -// SetCacheLimit sets the Metal cache limit. Returns the previous limit. -func SetCacheLimit(limit uint64) uint64 { - var prev C.size_t - C.mlx_set_cache_limit(&prev, C.size_t(limit)) - return uint64(prev) -} - -// GetActiveMemory returns the current Metal memory usage in bytes. -func GetActiveMemory() uint64 { - var mem C.size_t - C.mlx_get_active_memory(&mem) - return uint64(mem) -} - -// GetPeakMemory returns the peak Metal memory usage in bytes. 
-func GetPeakMemory() uint64 { - var mem C.size_t - C.mlx_get_peak_memory(&mem) - return uint64(mem) -} - -// ClearCache releases Metal memory held in the MLX allocator cache. -func ClearCache() { - C.mlx_clear_cache() -} diff --git a/mlx/tokenizer/tokenizer.go b/mlx/tokenizer/tokenizer.go deleted file mode 100644 index 947a1a5..0000000 --- a/mlx/tokenizer/tokenizer.go +++ /dev/null @@ -1,324 +0,0 @@ -//go:build darwin && arm64 - -// Package tokenizer provides BPE tokenization for transformer models. -package tokenizer - -import ( - "encoding/json" - "fmt" - "os" - "strings" -) - -// Tokenizer handles text-to-token and token-to-text conversion. -type Tokenizer struct { - vocab map[string]int32 - invVocab map[int32]string - merges []mergePair - special map[string]int32 - - bosToken int32 - eosToken int32 - - // GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.) - isGPT2BPE bool - gpt2Decoder map[rune]byte // Unicode char → original byte - gpt2Encoder map[byte]rune // original byte → Unicode char -} - -type mergePair struct { - a, b string - rank int -} - -// tokenizerJSON is the HuggingFace tokenizer.json format. -type tokenizerJSON struct { - Model struct { - Type string `json:"type"` - Vocab json.RawMessage `json:"vocab"` - Merges json.RawMessage `json:"merges"` - ByteFallback bool `json:"byte_fallback"` - } `json:"model"` - AddedTokens []struct { - ID int32 `json:"id"` - Content string `json:"content"` - Special bool `json:"special"` - } `json:"added_tokens"` -} - -// Load reads a tokenizer.json file and creates a Tokenizer. -func Load(path string) (*Tokenizer, error) { - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("tokenizer: read %s: %w", path, err) - } - - var tj tokenizerJSON - if err := json.Unmarshal(data, &tj); err != nil { - return nil, fmt.Errorf("tokenizer: parse: %w", err) - } - - t := &Tokenizer{ - vocab: make(map[string]int32), - invVocab: make(map[int32]string), - special: make(map[string]int32), - } - - // Parse vocab - var vocab map[string]int32 - if err := json.Unmarshal(tj.Model.Vocab, &vocab); err != nil { - return nil, fmt.Errorf("tokenizer: parse vocab: %w", err) - } - t.vocab = vocab - for k, v := range vocab { - t.invVocab[v] = k - } - - // Parse merges — supports both ["a b", ...] and [["a","b"], ...] 
formats
-	if len(tj.Model.Merges) > 0 {
-		var stringMerges []string
-		if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
-			for rank, merge := range stringMerges {
-				parts := strings.SplitN(merge, " ", 2)
-				if len(parts) == 2 {
-					t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
-				}
-			}
-		} else {
-			var arrayMerges [][]string
-			if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
-				for rank, pair := range arrayMerges {
-					if len(pair) == 2 {
-						t.merges = append(t.merges, mergePair{a: pair[0], b: pair[1], rank: rank})
-					}
-				}
-			}
-		}
-	}
-
-	// Parse special tokens
-	for _, tok := range tj.AddedTokens {
-		if tok.Special {
-			t.special[tok.Content] = tok.ID
-		}
-		t.vocab[tok.Content] = tok.ID
-		t.invVocab[tok.ID] = tok.Content
-	}
-
-	// Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
-	if _, ok := t.vocab["Ġ"]; ok {
-		t.isGPT2BPE = true
-		t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
-	}
-
-	// Set BOS/EOS — detect model family from special tokens
-	if id, ok := t.special["<s>"]; ok {
-		t.bosToken = id
-	}
-	if id, ok := t.special["</s>"]; ok {
-		t.eosToken = id
-	}
-	// Gemma: <end_of_turn> is the generation stop token
-	if id, ok := t.special["<end_of_turn>"]; ok {
-		t.eosToken = id
-	}
-	// Qwen3: <|im_end|> is the generation stop token
-	if id, ok := t.special["<|im_end|>"]; ok {
-		t.eosToken = id
-	}
-	// Qwen3 BOS: <|im_start|>
-	if id, ok := t.special["<|im_start|>"]; ok {
-		t.bosToken = id
-	}
-
-	return t, nil
-}
-
-// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps.
-// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control chars
-// in the vocabulary. Printable ASCII + Latin-1 Supplement map to themselves;
-// everything else (0-32, 127-160, 173) maps to U+0100 onwards.
-func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
-	encoder = make(map[byte]rune, 256)
-	decoder = make(map[rune]byte, 256)
-
-	// Self-mapping ranges: printable ASCII + Latin-1 Supplement.
-	// Use an int loop variable to avoid byte overflow at 255.
-	selfMap := func(lo, hi int) {
-		for b := lo; b <= hi; b++ {
-			encoder[byte(b)] = rune(b)
-			decoder[rune(b)] = byte(b)
-		}
-	}
-	selfMap(33, 126)  // ! through ~
-	selfMap(161, 172) // ¡ through ¬
-	selfMap(174, 255) // ® through ÿ
-
-	// Non-self-mapping: control chars, space, DEL, and gaps
-	n := 0
-	for b := 0; b < 256; b++ {
-		if _, ok := encoder[byte(b)]; !ok {
-			r := rune(256 + n)
-			encoder[byte(b)] = r
-			decoder[r] = byte(b)
-			n++
-		}
-	}
-	return
-}
-
-// Encode converts text to token IDs. Prepends the BOS token.
-func (t *Tokenizer) Encode(text string) []int32 {
-	tokens := []int32{t.bosToken}
-
-	if t.isGPT2BPE {
-		return t.encodeGPT2(text)
-	}
-
-	// SentencePiece style encoding
-	remaining := text
-	for remaining != "" {
-		found := false
-		for tok, id := range t.special {
-			if strings.HasPrefix(remaining, tok) {
-				tokens = append(tokens, id)
-				remaining = remaining[len(tok):]
-				found = true
-				break
-			}
-		}
-		if !found {
-			r := []rune(remaining)
-			ch := "▁" + string(r[0])
-			if id, ok := t.vocab[ch]; ok {
-				tokens = append(tokens, id)
-			} else if id, ok := t.vocab[string(r[0])]; ok {
-				tokens = append(tokens, id)
-			}
-			remaining = string(r[1:])
-		}
-	}
-
-	return tokens
-}
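// Illustrative sketch (not part of the deleted file; the helper name is ours):
// a quick check of the byte maps built above. Byte 0x20 (space) is the 33rd
// byte without a self-mapping, so it is assigned rune 256+32 = U+0120 ('Ġ'),
// which is why GPT-2-style vocabularies spell a word-initial space as "Ġ" and
// why Load probes t.vocab["Ġ"] to detect this tokenizer family.
func exampleByteMapRoundTrip() {
	dec, enc := buildGPT2ByteMaps()
	fmt.Printf("space encodes to %q\n", enc[' ']) // 'Ġ'
	for b := 0; b < 256; b++ {
		if dec[enc[byte(b)]] != byte(b) {
			panic("tokenizer: byte maps are not mutual inverses")
		}
	}
}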
-
-// encodeGPT2 encodes text using GPT-2 byte-level BPE.
-func (t *Tokenizer) encodeGPT2(text string) []int32 {
-	tokens := []int32{t.bosToken}
-
-	// Convert text bytes to the GPT-2 Unicode representation
-	var encoded strings.Builder
-	for _, b := range []byte(text) {
-		if r, ok := t.gpt2Encoder[b]; ok {
-			encoded.WriteRune(r)
-		}
-	}
-	gpt2Text := encoded.String()
-
-	// Scan for special tokens and regular text
-	remaining := gpt2Text
-	for remaining != "" {
-		// Check special tokens (these are stored as-is, not byte-encoded)
-		found := false
-		for tok, id := range t.special {
-			// Special tokens in GPT-2 tokenizers are stored in their original form.
-			// Convert the special token to GPT-2 encoding for matching.
-			var encTok strings.Builder
-			for _, b := range []byte(tok) {
-				if r, ok := t.gpt2Encoder[b]; ok {
-					encTok.WriteRune(r)
-				}
-			}
-			encStr := encTok.String()
-			if strings.HasPrefix(remaining, encStr) {
-				tokens = append(tokens, id)
-				remaining = remaining[len(encStr):]
-				found = true
-				break
-			}
-		}
-		if !found {
-			// Character-by-character lookup (simplified BPE)
-			r := []rune(remaining)
-			ch := string(r[0])
-			if id, ok := t.vocab[ch]; ok {
-				tokens = append(tokens, id)
-			}
-			remaining = string(r[1:])
-		}
-	}
-
-	return tokens
-}
-
-// Decode converts token IDs back to text.
-// For full-sequence decoding, the SentencePiece leading space is stripped.
-func (t *Tokenizer) Decode(tokens []int32) string {
-	var sb strings.Builder
-	for _, id := range tokens {
-		if text, ok := t.invVocab[id]; ok {
-			// Skip special tokens in decode output
-			if _, isSpecial := t.special[text]; isSpecial {
-				continue
-			}
-			sb.WriteString(text)
-		}
-	}
-	raw := sb.String()
-
-	if t.isGPT2BPE {
-		return t.decodeGPT2Bytes(raw)
-	}
-
-	// SentencePiece style
-	result := strings.ReplaceAll(raw, "▁", " ")
-	if strings.HasPrefix(result, " ") {
-		result = result[1:]
-	}
-	return result
-}
-
-// DecodeToken converts a single token ID to text for streaming.
-// Unlike Decode, it preserves the leading space (word boundary) so that
-// token-by-token output maintains correct spacing between words.
-func (t *Tokenizer) DecodeToken(id int32) string {
-	text, ok := t.invVocab[id]
-	if !ok {
-		return ""
-	}
-	if _, isSpecial := t.special[text]; isSpecial {
-		return ""
-	}
-
-	if t.isGPT2BPE {
-		return t.decodeGPT2Bytes(text)
-	}
-
-	// SentencePiece: replace ▁ with a space but keep it (it marks the word boundary)
-	return strings.ReplaceAll(text, "▁", " ")
-}
-
-// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
-func (t *Tokenizer) decodeGPT2Bytes(s string) string {
-	var buf []byte
-	for _, r := range s {
-		if b, ok := t.gpt2Decoder[r]; ok {
-			buf = append(buf, b)
-		} else {
-			// Non-mapped runes pass through as UTF-8
-			buf = append(buf, []byte(string(r))...)
-		}
-	}
-	return string(buf)
-}
-
-// BOSToken returns the beginning-of-sequence token ID.
-func (t *Tokenizer) BOSToken() int32 { return t.bosToken }
-
-// EOSToken returns the end-of-sequence token ID.
-func (t *Tokenizer) EOSToken() int32 { return t.eosToken }
-
-// FormatGemmaPrompt applies the Gemma 3 chat template.
-func FormatGemmaPrompt(prompt string) string {
-	return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
-}
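// Usage sketch (illustrative, not part of this change; assumes the new
// forge.lthn.ai/core/go-mlx/tokenizer import path added in go.mod above, a
// caller-supplied tokenizer.json path, and a darwin/arm64 build per the
// package build tag): load the tokenizer, wrap a prompt in the Gemma chat
// template, and round-trip it through Encode/Decode. DecodeToken is what a
// streaming loop would call per generated token.
package main

import (
	"fmt"

	"forge.lthn.ai/core/go-mlx/tokenizer"
)

func main() {
	tok, err := tokenizer.Load("tokenizer.json") // placeholder path
	if err != nil {
		panic(err)
	}

	prompt := tokenizer.FormatGemmaPrompt("Why is the sky blue?")
	ids := tok.Encode(prompt)
	fmt.Printf("%d tokens, BOS=%d, EOS=%d\n", len(ids), tok.BOSToken(), tok.EOSToken())

	// Full-sequence decode strips special tokens and the SentencePiece
	// leading space; per-token decode keeps word boundaries for streaming.
	fmt.Println(tok.Decode(ids))
	for _, id := range ids {
		fmt.Print(tok.DecodeToken(id))
	}
	fmt.Println()
}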