refactor: extract mlx/ to standalone core/go-mlx module
Remove mlx/ directory (now lives at forge.lthn.ai/core/go-mlx). Update ml/backend_mlx.go imports to reference the new module. Add replace directive for local development. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
906a535899
commit
34d0f9ce41
27 changed files with 8 additions and 4387 deletions
3
go.mod
3
go.mod
|
|
@ -4,6 +4,7 @@ go 1.25.5
|
|||
|
||||
require (
|
||||
forge.lthn.ai/core/go v0.0.0
|
||||
forge.lthn.ai/core/go-mlx v0.0.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/marcboeker/go-duckdb v1.8.5
|
||||
github.com/modelcontextprotocol/go-sdk v1.3.0
|
||||
|
|
@ -54,3 +55,5 @@ require (
|
|||
)
|
||||
|
||||
replace forge.lthn.ai/core/go => ../core
|
||||
|
||||
replace forge.lthn.ai/core/go-mlx => ../go-mlx
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mlx"
|
||||
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||
"forge.lthn.ai/core/go-ai/mlx/model"
|
||||
"forge.lthn.ai/core/go-ai/mlx/sample"
|
||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
"forge.lthn.ai/core/go-mlx"
|
||||
"forge.lthn.ai/core/go-mlx/cache"
|
||||
"forge.lthn.ai/core/go-mlx/model"
|
||||
"forge.lthn.ai/core/go-mlx/sample"
|
||||
"forge.lthn.ai/core/go-mlx/tokenizer"
|
||||
)
|
||||
|
||||
// MLXBackend implements Backend and StreamingBackend for native Metal inference.
|
||||
|
|
|
|||
|
|
@ -1,28 +0,0 @@
|
|||
cmake_minimum_required(VERSION 3.24)
|
||||
|
||||
project(mlx)
|
||||
|
||||
set(CMAKE_OSX_DEPLOYMENT_TARGET "26.0" CACHE STRING "Minimum macOS version")
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
|
||||
endif()
|
||||
|
||||
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
|
||||
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
|
||||
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG}
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
261
mlx/array.go
261
mlx/array.go
|
|
@ -1,261 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Array wraps an mlx_array handle.
|
||||
// Memory management relies on Go GC finalizers to call mlx_array_free,
|
||||
// which decrements MLX-C's internal reference count. MLX-C handles all
|
||||
// cross-array references internally — the Go wrapper does not track them.
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
name string // debug label
|
||||
}
|
||||
|
||||
// New creates a named Array and registers a GC finalizer.
|
||||
// The inputs parameter is accepted for API compatibility but not stored —
|
||||
// MLX-C tracks inter-array references via its own refcounting.
|
||||
func New(name string, inputs ...*Array) *Array {
|
||||
t := &Array{name: name}
|
||||
runtime.SetFinalizer(t, finalizeArray)
|
||||
return t
|
||||
}
|
||||
|
||||
// finalizeArray is called by Go GC to release the underlying C array handle.
|
||||
func finalizeArray(t *Array) {
|
||||
if t != nil && t.ctx.ctx != nil {
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
|
||||
type scalarTypes interface {
|
||||
~bool | ~int | ~float32 | ~float64 | ~complex64
|
||||
}
|
||||
|
||||
// FromValue creates a scalar Array from a Go value.
|
||||
func FromValue[T scalarTypes](t T) *Array {
|
||||
Init()
|
||||
tt := New("")
|
||||
switch v := any(t).(type) {
|
||||
case bool:
|
||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||
case int:
|
||||
tt.ctx = C.mlx_array_new_int(C.int(v))
|
||||
case float32:
|
||||
tt.ctx = C.mlx_array_new_float32(C.float(v))
|
||||
case float64:
|
||||
tt.ctx = C.mlx_array_new_float64(C.double(v))
|
||||
case complex64:
|
||||
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
|
||||
default:
|
||||
panic("mlx: unsupported scalar type")
|
||||
}
|
||||
return tt
|
||||
}
|
||||
|
||||
type arrayTypes interface {
|
||||
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
~int8 | ~int16 | ~int32 | ~int64 |
|
||||
~float32 | ~float64 |
|
||||
~complex64
|
||||
}
|
||||
|
||||
// FromValues creates an Array from a Go slice with the given shape.
|
||||
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
|
||||
Init()
|
||||
if len(shape) == 0 {
|
||||
panic("mlx: shape required for non-scalar tensors")
|
||||
}
|
||||
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i := range shape {
|
||||
cShape[i] = C.int(shape[i])
|
||||
}
|
||||
|
||||
var dtype DType
|
||||
switch reflect.TypeOf(s).Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
dtype = DTypeBool
|
||||
case reflect.Uint8:
|
||||
dtype = DTypeUint8
|
||||
case reflect.Uint16:
|
||||
dtype = DTypeUint16
|
||||
case reflect.Uint32:
|
||||
dtype = DTypeUint32
|
||||
case reflect.Uint64:
|
||||
dtype = DTypeUint64
|
||||
case reflect.Int8:
|
||||
dtype = DTypeInt8
|
||||
case reflect.Int16:
|
||||
dtype = DTypeInt16
|
||||
case reflect.Int32:
|
||||
dtype = DTypeInt32
|
||||
case reflect.Int64:
|
||||
dtype = DTypeInt64
|
||||
case reflect.Float32:
|
||||
dtype = DTypeFloat32
|
||||
case reflect.Float64:
|
||||
dtype = DTypeFloat64
|
||||
case reflect.Complex64:
|
||||
dtype = DTypeComplex64
|
||||
default:
|
||||
panic("mlx: unsupported element type")
|
||||
}
|
||||
|
||||
bts := make([]byte, binary.Size(s))
|
||||
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tt := New("")
|
||||
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
|
||||
return tt
|
||||
}
|
||||
|
||||
// Zeros creates a zero-filled Array with the given shape and dtype.
|
||||
func Zeros(shape []int32, dtype DType) *Array {
|
||||
Init()
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
tt := New("ZEROS")
|
||||
C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.size_t(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
// Set replaces this array's C handle with another's.
|
||||
func (t *Array) Set(other *Array) {
|
||||
C.mlx_array_set(&t.ctx, other.ctx)
|
||||
}
|
||||
|
||||
// Clone creates a new Go wrapper sharing the same C handle (increments C refcount).
|
||||
func (t *Array) Clone() *Array {
|
||||
tt := New(t.name)
|
||||
C.mlx_array_set(&tt.ctx, t.ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
// Valid reports whether this Array has a non-nil mlx handle.
|
||||
func (t *Array) Valid() bool {
|
||||
return t.ctx.ctx != nil
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of the array.
|
||||
func (t *Array) String() string {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_array_tostring(&str, t.ctx)
|
||||
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
// Shape returns the dimensions as int32 slice.
|
||||
func (t *Array) Shape() []int32 {
|
||||
dims := make([]int32, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = int32(t.Dim(i))
|
||||
}
|
||||
return dims
|
||||
}
|
||||
|
||||
// Size returns the total number of elements.
|
||||
func (t Array) Size() int { return int(C.mlx_array_size(t.ctx)) }
|
||||
|
||||
// NumBytes returns the total byte size.
|
||||
func (t Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) }
|
||||
|
||||
// NumDims returns the number of dimensions.
|
||||
func (t Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) }
|
||||
|
||||
// Dim returns the size of dimension i.
|
||||
func (t Array) Dim(i int) int { return int(C.mlx_array_dim(t.ctx, C.int(i))) }
|
||||
|
||||
// Dims returns all dimensions as int slice.
|
||||
func (t Array) Dims() []int {
|
||||
dims := make([]int, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = t.Dim(i)
|
||||
}
|
||||
return dims
|
||||
}
|
||||
|
||||
// Dtype returns the array's data type.
|
||||
func (t Array) Dtype() DType { return DType(C.mlx_array_dtype(t.ctx)) }
|
||||
|
||||
// Int extracts a scalar int64 value.
|
||||
func (t Array) Int() int {
|
||||
var item C.int64_t
|
||||
C.mlx_array_item_int64(&item, t.ctx)
|
||||
return int(item)
|
||||
}
|
||||
|
||||
// Float extracts a scalar float64 value.
|
||||
// Handles both float32 and float64 array dtypes.
|
||||
func (t Array) Float() float64 {
|
||||
switch t.Dtype() {
|
||||
case DTypeFloat32:
|
||||
var item C.float
|
||||
C.mlx_array_item_float32(&item, t.ctx)
|
||||
return float64(item)
|
||||
default:
|
||||
var item C.double
|
||||
C.mlx_array_item_float64(&item, t.ctx)
|
||||
return float64(item)
|
||||
}
|
||||
}
|
||||
|
||||
// Ints extracts all elements as int slice (from int32 data).
|
||||
func (t Array) Ints() []int {
|
||||
ints := make([]int, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
||||
ints[i] = int(f)
|
||||
}
|
||||
return ints
|
||||
}
|
||||
|
||||
// DataInt32 extracts all elements as int32 slice.
|
||||
func (t Array) DataInt32() []int32 {
|
||||
data := make([]int32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) {
|
||||
data[i] = int32(f)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// Floats extracts all elements as float32 slice.
|
||||
func (t Array) Floats() []float32 {
|
||||
floats := make([]float32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
||||
floats[i] = float32(f)
|
||||
}
|
||||
return floats
|
||||
}
|
||||
|
||||
// Free explicitly releases C array handles. Does not cascade — MLX-C's
|
||||
// internal refcounting handles dependent arrays automatically.
|
||||
func Free(s ...*Array) int {
|
||||
var n int
|
||||
for _, t := range s {
|
||||
if t != nil && t.Valid() {
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
runtime.SetFinalizer(t, nil) // cancel finalizer
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
201
mlx/cache/cache.go
vendored
201
mlx/cache/cache.go
vendored
|
|
@ -1,201 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package cache provides KV cache implementations for transformer inference.
|
||||
package cache
|
||||
|
||||
import "forge.lthn.ai/core/go-ai/mlx"
|
||||
|
||||
// Cache manages key-value pairs for transformer attention layers.
|
||||
type Cache interface {
|
||||
// Update adds new key/value tensors and returns the full cached K/V.
|
||||
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
|
||||
// Offset returns the total number of tokens processed.
|
||||
Offset() int
|
||||
// Len returns the number of cached tokens (may differ from Offset for rotating caches).
|
||||
Len() int
|
||||
// State returns the cached K/V arrays, or nil if empty.
|
||||
State() []*mlx.Array
|
||||
// Reset clears the cache for a new generation session.
|
||||
Reset()
|
||||
}
|
||||
|
||||
// KVCache implements an unbounded cache that grows as needed.
|
||||
// Pre-allocates in chunks of `step` tokens to reduce allocations.
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
step int
|
||||
}
|
||||
|
||||
// NewKVCache creates a new unbounded KV cache with 256-token chunks.
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
prev := c.offset
|
||||
shape := k.Shape()
|
||||
if len(shape) < 4 {
|
||||
// K/V must be [B, H, L, D] — if not, pass through unchanged
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
}
|
||||
c.offset += seqLen
|
||||
return c.keys, c.values
|
||||
}
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if needed.
|
||||
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
|
||||
nSteps := (c.step + seqLen - 1) / c.step
|
||||
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
}
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += seqLen
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
func (c *KVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
// RotatingKVCache implements a bounded sliding window cache.
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
maxSize int
|
||||
step int
|
||||
idx int
|
||||
}
|
||||
|
||||
// NewRotatingKVCache creates a cache bounded to maxSize tokens.
|
||||
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, step: 256}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
if seqLen > 1 {
|
||||
return c.updateConcat(k, v, seqLen)
|
||||
}
|
||||
return c.updateInPlace(k, v)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
if len(shape) < 4 {
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
}
|
||||
c.offset++
|
||||
return c.keys, c.values
|
||||
}
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
|
||||
var cap int
|
||||
if c.keys != nil {
|
||||
cap = int(c.keys.Shape()[2])
|
||||
}
|
||||
newSize := min(c.step, c.maxSize-cap)
|
||||
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
if c.keys != nil {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
if c.idx >= c.maxSize {
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
|
||||
c.offset++
|
||||
c.idx++
|
||||
|
||||
validLen := int32(min(c.offset, c.maxSize))
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
if len(shape) < 4 {
|
||||
// K/V must be [B, H, L, D] — if not, pass through unchanged
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
}
|
||||
c.offset += seqLen
|
||||
return c.keys, c.values
|
||||
}
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
} else {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
|
||||
}
|
||||
c.offset += seqLen
|
||||
|
||||
cap := int(c.keys.Shape()[2])
|
||||
if trim := cap - c.maxSize; trim > 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
|
||||
func (c *RotatingKVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
c.idx = 0
|
||||
}
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
|
||||
// Callback for compiled functions.
|
||||
extern int goCompiledFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
|
||||
|
||||
static mlx_closure new_closure(void *payload) {
|
||||
return mlx_closure_new_func_payload(&goCompiledFunc, payload, NULL);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
|
||||
type CompiledFunc struct {
|
||||
fn func([]*Array) []*Array
|
||||
closure C.mlx_closure
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var compiledFuncs sync.Map
|
||||
|
||||
//export goCompiledFunc
|
||||
func goCompiledFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
|
||||
id := uintptr(payload)
|
||||
fnI, ok := compiledFuncs.Load(id)
|
||||
if !ok {
|
||||
return 1
|
||||
}
|
||||
fn := fnI.(func([]*Array) []*Array)
|
||||
|
||||
// Convert inputs
|
||||
nInputs := int(C.mlx_vector_array_size(inputs))
|
||||
goInputs := make([]*Array, nInputs)
|
||||
for i := 0; i < nInputs; i++ {
|
||||
a := New("INPUT")
|
||||
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
|
||||
goInputs[i] = a
|
||||
}
|
||||
|
||||
// Call user function
|
||||
goOutputs := fn(goInputs)
|
||||
|
||||
// The output vector arrives with ctx=nullptr. Create a valid vector,
|
||||
// fill it, then copy to the output pointer.
|
||||
tmp := C.mlx_vector_array_new()
|
||||
for _, out := range goOutputs {
|
||||
C.mlx_vector_array_append_value(tmp, out.ctx)
|
||||
}
|
||||
C.mlx_vector_array_set(outputs, tmp)
|
||||
C.mlx_vector_array_free(tmp)
|
||||
return 0
|
||||
}
|
||||
|
||||
var nextID uintptr
|
||||
var nextIDMu sync.Mutex
|
||||
|
||||
// CompileShapeless compiles a function for efficient repeated execution.
|
||||
// The function must accept and return arrays of consistent shapes.
|
||||
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
|
||||
nextIDMu.Lock()
|
||||
nextID++
|
||||
id := nextID
|
||||
nextIDMu.Unlock()
|
||||
|
||||
compiledFuncs.Store(id, fn)
|
||||
|
||||
cf := &CompiledFunc{fn: fn}
|
||||
cf.closure = C.new_closure(unsafe.Pointer(id))
|
||||
return cf
|
||||
}
|
||||
|
||||
// Call executes the compiled function with the given inputs.
|
||||
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
||||
cf.mu.Lock()
|
||||
defer cf.mu.Unlock()
|
||||
|
||||
// Fall back to direct call — compilation is an optimization.
|
||||
// The compiled closure can be used via mlx_compiled but the
|
||||
// direct path is simpler and still benefits from MLX's lazy evaluation.
|
||||
return cf.fn(inputs)
|
||||
}
|
||||
83
mlx/dtype.go
83
mlx/dtype.go
|
|
@ -1,83 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
// #include "mlx/c/mlx.h"
|
||||
import "C"
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// DType represents an MLX array data type.
|
||||
type DType C.mlx_dtype
|
||||
|
||||
const (
|
||||
DTypeBool DType = C.MLX_BOOL
|
||||
DTypeUint8 DType = C.MLX_UINT8
|
||||
DTypeUint16 DType = C.MLX_UINT16
|
||||
DTypeUint32 DType = C.MLX_UINT32
|
||||
DTypeUint64 DType = C.MLX_UINT64
|
||||
DTypeInt8 DType = C.MLX_INT8
|
||||
DTypeInt16 DType = C.MLX_INT16
|
||||
DTypeInt32 DType = C.MLX_INT32
|
||||
DTypeInt64 DType = C.MLX_INT64
|
||||
DTypeFloat16 DType = C.MLX_FLOAT16
|
||||
DTypeFloat32 DType = C.MLX_FLOAT32
|
||||
DTypeFloat64 DType = C.MLX_FLOAT64
|
||||
DTypeBFloat16 DType = C.MLX_BFLOAT16
|
||||
DTypeComplex64 DType = C.MLX_COMPLEX64
|
||||
)
|
||||
|
||||
var dtypeNames = map[DType]string{
|
||||
DTypeBool: "bool",
|
||||
DTypeUint8: "uint8",
|
||||
DTypeUint16: "uint16",
|
||||
DTypeUint32: "uint32",
|
||||
DTypeUint64: "uint64",
|
||||
DTypeInt8: "int8",
|
||||
DTypeInt16: "int16",
|
||||
DTypeInt32: "int32",
|
||||
DTypeInt64: "int64",
|
||||
DTypeFloat16: "float16",
|
||||
DTypeFloat32: "float32",
|
||||
DTypeFloat64: "float64",
|
||||
DTypeBFloat16: "bfloat16",
|
||||
DTypeComplex64: "complex64",
|
||||
}
|
||||
|
||||
func (d DType) String() string {
|
||||
if s, ok := dtypeNames[d]; ok {
|
||||
return s
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
var dtypeFromString = map[string]DType{
|
||||
"bool": DTypeBool, "BOOL": DTypeBool,
|
||||
"uint8": DTypeUint8, "U8": DTypeUint8,
|
||||
"uint16": DTypeUint16, "U16": DTypeUint16,
|
||||
"uint32": DTypeUint32, "U32": DTypeUint32,
|
||||
"uint64": DTypeUint64, "U64": DTypeUint64,
|
||||
"int8": DTypeInt8, "I8": DTypeInt8,
|
||||
"int16": DTypeInt16, "I16": DTypeInt16,
|
||||
"int32": DTypeInt32, "I32": DTypeInt32,
|
||||
"int64": DTypeInt64, "I64": DTypeInt64,
|
||||
"float16": DTypeFloat16, "F16": DTypeFloat16,
|
||||
"float32": DTypeFloat32, "F32": DTypeFloat32,
|
||||
"float64": DTypeFloat64, "F64": DTypeFloat64,
|
||||
"bfloat16": DTypeBFloat16, "BF16": DTypeBFloat16,
|
||||
"complex64": DTypeComplex64,
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a DType from JSON strings like "F32", "BF16", etc.
|
||||
func (d *DType) UnmarshalJSON(b []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(b, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
if dt, ok := dtypeFromString[s]; ok {
|
||||
*d = dt
|
||||
return nil
|
||||
}
|
||||
*d = DTypeFloat32 // default
|
||||
return nil
|
||||
}
|
||||
79
mlx/fast.go
79
mlx/fast.go
|
|
@ -1,79 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// RMSNorm applies Root Mean Square normalization using a fused Metal kernel.
|
||||
func RMSNorm(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM", x)
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// LayerNorm applies Layer normalization using a fused Metal kernel.
|
||||
func LayerNorm(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM", x)
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// RoPE applies Rotary Position Embeddings using a fused Metal kernel.
|
||||
func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
|
||||
out := New("FAST_ROPE", x)
|
||||
freqs := C.mlx_array_new()
|
||||
defer C.mlx_array_free(freqs)
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C._Bool(traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C._Bool(base != 0),
|
||||
},
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
freqs,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention computes attention using a fused Metal kernel.
|
||||
func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
|
||||
mode := ""
|
||||
if causal {
|
||||
mode = "causal"
|
||||
}
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
maskArr := C.mlx_array_new()
|
||||
defer C.mlx_array_free(maskArr)
|
||||
sinksArr := C.mlx_array_new()
|
||||
defer C.mlx_array_free(sinksArr)
|
||||
|
||||
out := New("FAST_SDPA", query, key, value)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, maskArr, sinksArr, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttentionWithMask computes attention with an explicit mask.
|
||||
func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
|
||||
cMode := C.CString("array")
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
sinksArr := C.mlx_array_new()
|
||||
defer C.mlx_array_free(sinksArr)
|
||||
|
||||
out := New("FAST_SDPA", query, key, value, mask)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinksArr, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
353
mlx/grad.go
353
mlx/grad.go
|
|
@ -1,353 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
|
||||
// Callback for gradient closures — same signature as goCompiledFunc.
|
||||
extern int goGradFunc(mlx_vector_array *outputs, const mlx_vector_array inputs, void *payload);
|
||||
|
||||
static mlx_closure new_grad_closure(void *payload) {
|
||||
return mlx_closure_new_func_payload(&goGradFunc, payload, NULL);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// --- Closure registry (separate from compile.go's registry) ---
|
||||
|
||||
var (
|
||||
gradFuncs sync.Map
|
||||
gradNextID atomic.Uintptr
|
||||
)
|
||||
|
||||
//export goGradFunc
|
||||
func goGradFunc(outputs *C.mlx_vector_array, inputs C.mlx_vector_array, payload unsafe.Pointer) C.int {
|
||||
id := uintptr(payload)
|
||||
fnI, ok := gradFuncs.Load(id)
|
||||
if !ok {
|
||||
return 1
|
||||
}
|
||||
fn := fnI.(func([]*Array) []*Array)
|
||||
|
||||
nInputs := int(C.mlx_vector_array_size(inputs))
|
||||
goInputs := make([]*Array, nInputs)
|
||||
for i := 0; i < nInputs; i++ {
|
||||
a := New("GRAD_INPUT")
|
||||
C.mlx_vector_array_get(&a.ctx, inputs, C.size_t(i))
|
||||
goInputs[i] = a
|
||||
}
|
||||
|
||||
goOutputs := fn(goInputs)
|
||||
|
||||
// The output vector arrives with ctx=nullptr (from internal mlx_vector_array_new_).
|
||||
// Create a valid vector, fill it, then copy to the output pointer.
|
||||
tmp := C.mlx_vector_array_new()
|
||||
for _, out := range goOutputs {
|
||||
C.mlx_vector_array_append_value(tmp, out.ctx)
|
||||
}
|
||||
C.mlx_vector_array_set(outputs, tmp)
|
||||
C.mlx_vector_array_free(tmp)
|
||||
return 0
|
||||
}
|
||||
|
||||
// newClosure registers a Go function as an MLX closure for autograd.
|
||||
func newClosure(fn func([]*Array) []*Array) C.mlx_closure {
|
||||
id := gradNextID.Add(1)
|
||||
gradFuncs.Store(id, fn)
|
||||
return C.new_grad_closure(unsafe.Pointer(id))
|
||||
}
|
||||
|
||||
// --- VJP (Vector-Jacobian Product) — Reverse Mode ---
|
||||
|
||||
// VJP computes the vector-Jacobian product (reverse-mode autodiff).
|
||||
// Given a function fn, input primals, and output cotangents (upstream gradients),
|
||||
// returns (outputs, gradients) where gradients are w.r.t. the primals.
|
||||
//
|
||||
// This is the fundamental backward pass operation.
|
||||
func VJP(fn func([]*Array) []*Array, primals []*Array, cotangents []*Array) (outputs []*Array, vjps []*Array, err error) {
|
||||
Init()
|
||||
|
||||
closure := newClosure(fn)
|
||||
defer C.mlx_closure_free(closure)
|
||||
|
||||
// Pack primals into vector
|
||||
primalsVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(primalsVec)
|
||||
for _, p := range primals {
|
||||
C.mlx_vector_array_append_value(primalsVec, p.ctx)
|
||||
}
|
||||
|
||||
// Pack cotangents into vector
|
||||
cotangentsVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(cotangentsVec)
|
||||
for _, c := range cotangents {
|
||||
C.mlx_vector_array_append_value(cotangentsVec, c.ctx)
|
||||
}
|
||||
|
||||
// Call mlx_vjp
|
||||
var outVec, vjpVec C.mlx_vector_array
|
||||
outVec = C.mlx_vector_array_new()
|
||||
vjpVec = C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
defer C.mlx_vector_array_free(vjpVec)
|
||||
|
||||
rc := C.mlx_vjp(&outVec, &vjpVec, closure, primalsVec, cotangentsVec)
|
||||
if rc != 0 {
|
||||
checkError()
|
||||
return nil, nil, fmt.Errorf("mlx: vjp failed")
|
||||
}
|
||||
|
||||
outputs = vectorToArrays(outVec)
|
||||
vjps = vectorToArrays(vjpVec)
|
||||
return outputs, vjps, nil
|
||||
}
|
||||
|
||||
// --- JVP (Jacobian-Vector Product) — Forward Mode ---
|
||||
|
||||
// JVP computes the Jacobian-vector product (forward-mode autodiff).
|
||||
// Given a function fn, input primals, and input tangents (perturbation directions),
|
||||
// returns (outputs, output_tangents).
|
||||
//
|
||||
// Useful for directional derivatives and Hessian-vector products.
|
||||
func JVP(fn func([]*Array) []*Array, primals []*Array, tangents []*Array) (outputs []*Array, jvps []*Array, err error) {
|
||||
Init()
|
||||
|
||||
closure := newClosure(fn)
|
||||
defer C.mlx_closure_free(closure)
|
||||
|
||||
primalsVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(primalsVec)
|
||||
for _, p := range primals {
|
||||
C.mlx_vector_array_append_value(primalsVec, p.ctx)
|
||||
}
|
||||
|
||||
tangentsVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(tangentsVec)
|
||||
for _, t := range tangents {
|
||||
C.mlx_vector_array_append_value(tangentsVec, t.ctx)
|
||||
}
|
||||
|
||||
var outVec, jvpVec C.mlx_vector_array
|
||||
outVec = C.mlx_vector_array_new()
|
||||
jvpVec = C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
defer C.mlx_vector_array_free(jvpVec)
|
||||
|
||||
rc := C.mlx_jvp(&outVec, &jvpVec, closure, primalsVec, tangentsVec)
|
||||
if rc != 0 {
|
||||
checkError()
|
||||
return nil, nil, fmt.Errorf("mlx: jvp failed")
|
||||
}
|
||||
|
||||
outputs = vectorToArrays(outVec)
|
||||
jvps = vectorToArrays(jvpVec)
|
||||
return outputs, jvps, nil
|
||||
}
|
||||
|
||||
// --- ValueAndGrad — Combined Forward + Backward ---
|
||||
|
||||
// GradFn is a function that computes both the loss value and gradients
|
||||
// with respect to specified arguments. Call it with inputs to get
|
||||
// (values, gradients).
|
||||
type GradFn struct {
|
||||
cls C.mlx_closure_value_and_grad
|
||||
}
|
||||
|
||||
// ValueAndGrad creates a GradFn that computes both the function value and
|
||||
// gradients with respect to the arguments at the given indices.
|
||||
//
|
||||
// The returned GradFn can be called repeatedly with different inputs.
|
||||
// This is the primary API for training loops:
|
||||
//
|
||||
// lossFn := func(params []*Array) []*Array { return []*Array{loss} }
|
||||
// grad := mlx.ValueAndGrad(lossFn, 0) // differentiate w.r.t. first arg
|
||||
// values, grads := grad.Apply(params...)
|
||||
func ValueAndGrad(fn func([]*Array) []*Array, argnums ...int) *GradFn {
|
||||
Init()
|
||||
|
||||
closure := newClosure(fn)
|
||||
defer C.mlx_closure_free(closure)
|
||||
|
||||
// Default: differentiate w.r.t. first argument
|
||||
if len(argnums) == 0 {
|
||||
argnums = []int{0}
|
||||
}
|
||||
|
||||
cArgs := make([]C.int, len(argnums))
|
||||
for i, a := range argnums {
|
||||
cArgs[i] = C.int(a)
|
||||
}
|
||||
|
||||
g := &GradFn{}
|
||||
C.mlx_value_and_grad(&g.cls, closure, &cArgs[0], C.size_t(len(cArgs)))
|
||||
return g
|
||||
}
|
||||
|
||||
// Apply calls the gradient function with the given inputs.
|
||||
// Returns (values, gradients) where values are the function outputs
|
||||
// and gradients are w.r.t. the arguments specified in ValueAndGrad.
|
||||
func (g *GradFn) Apply(inputs ...*Array) (values []*Array, grads []*Array, err error) {
|
||||
inputVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(inputVec)
|
||||
for _, in := range inputs {
|
||||
C.mlx_vector_array_append_value(inputVec, in.ctx)
|
||||
}
|
||||
|
||||
var valVec, gradVec C.mlx_vector_array
|
||||
valVec = C.mlx_vector_array_new()
|
||||
gradVec = C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(valVec)
|
||||
defer C.mlx_vector_array_free(gradVec)
|
||||
|
||||
rc := C.mlx_closure_value_and_grad_apply(&valVec, &gradVec, g.cls, inputVec)
|
||||
if rc != 0 {
|
||||
checkError()
|
||||
return nil, nil, fmt.Errorf("mlx: value_and_grad apply failed")
|
||||
}
|
||||
|
||||
values = vectorToArrays(valVec)
|
||||
grads = vectorToArrays(gradVec)
|
||||
return values, grads, nil
|
||||
}
|
||||
|
||||
// Free releases the underlying C closure.
|
||||
func (g *GradFn) Free() {
|
||||
if g.cls.ctx != nil {
|
||||
C.mlx_closure_value_and_grad_free(g.cls)
|
||||
g.cls.ctx = nil
|
||||
}
|
||||
}
|
||||
|
||||
// --- Checkpoint — Memory-Efficient Gradient Recomputation ---
|
||||
|
||||
// Checkpoint wraps a function so that during backward pass, intermediate
|
||||
// activations are recomputed rather than stored. Trades compute for memory.
|
||||
//
|
||||
// Use this for memory-constrained training (large models on limited VRAM).
|
||||
func Checkpoint(fn func([]*Array) []*Array) func([]*Array) []*Array {
|
||||
Init()
|
||||
|
||||
closure := newClosure(fn)
|
||||
defer C.mlx_closure_free(closure)
|
||||
|
||||
var checkpointed C.mlx_closure
|
||||
C.mlx_checkpoint(&checkpointed, closure)
|
||||
|
||||
// Register the checkpointed closure so it stays alive
|
||||
id := gradNextID.Add(1)
|
||||
cpClosure := checkpointed // capture for the returned function
|
||||
|
||||
// Return a Go function that calls through the checkpointed closure
|
||||
return func(inputs []*Array) []*Array {
|
||||
_ = id // prevent GC of closure reference
|
||||
|
||||
inputVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(inputVec)
|
||||
for _, in := range inputs {
|
||||
C.mlx_vector_array_append_value(inputVec, in.ctx)
|
||||
}
|
||||
|
||||
outVec := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(outVec)
|
||||
|
||||
C.mlx_closure_apply(&outVec, cpClosure, inputVec)
|
||||
return vectorToArrays(outVec)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Loss Functions ---
|
||||
|
||||
// CrossEntropyLoss computes cross-entropy loss between logits and integer targets.
|
||||
// logits: [..., V] (raw model output, pre-softmax, last dim = vocab)
|
||||
// targets: [...] (integer token IDs, same shape as logits minus last dim)
|
||||
// Returns scalar loss averaged over all positions.
|
||||
//
|
||||
// Numerically stable: loss_i = logsumexp(logits_i) - logits_i[target_i]
|
||||
func CrossEntropyLoss(logits, targets *Array) *Array {
|
||||
Init()
|
||||
// LogSumExp along last axis (vocab dimension) for numerical stability
|
||||
lse := LogSumExp(logits, -1, false)
|
||||
|
||||
// Gather the logit at the target index: logits[..., target]
|
||||
// targets needs to be [..., 1] for take_along_axis
|
||||
tgtExpanded := ExpandDims(targets, -1)
|
||||
gathered := TakeAlongAxis(logits, tgtExpanded, -1)
|
||||
gathered = Squeeze(gathered, -1) // back to [...] shape
|
||||
|
||||
// Per-position loss: logsumexp - logit_at_target
|
||||
perPos := Subtract(lse, gathered)
|
||||
|
||||
// Mean over all positions
|
||||
return MeanAll(perPos)
|
||||
}
|
||||
|
||||
// MaskedCrossEntropyLoss computes cross-entropy loss only on masked positions.
|
||||
// logits: [B, L, V], targets: [B, L], mask: [B, L] (1.0 = compute loss, 0.0 = ignore)
|
||||
// Returns scalar loss averaged over masked positions only.
|
||||
func MaskedCrossEntropyLoss(logits, targets, mask *Array) *Array {
|
||||
Init()
|
||||
lse := LogSumExp(logits, -1, false)
|
||||
tgtExpanded := ExpandDims(targets, -1)
|
||||
gathered := TakeAlongAxis(logits, tgtExpanded, -1)
|
||||
gathered = Squeeze(gathered, -1)
|
||||
perPos := Subtract(lse, gathered) // [B, L]
|
||||
|
||||
// Apply mask and average over masked positions
|
||||
masked := Mul(perPos, mask)
|
||||
return Divide(SumAll(masked), SumAll(mask))
|
||||
}
|
||||
|
||||
// MSELoss computes the mean squared error loss: mean((predictions - targets)^2).
|
||||
func MSELoss(predictions, targets *Array) *Array {
|
||||
diff := Subtract(predictions, targets)
|
||||
sq := Square(diff)
|
||||
return MeanAll(sq)
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
// vectorToArrays extracts all arrays from an mlx_vector_array.
|
||||
func vectorToArrays(vec C.mlx_vector_array) []*Array {
|
||||
n := int(C.mlx_vector_array_size(vec))
|
||||
out := make([]*Array, n)
|
||||
for i := 0; i < n; i++ {
|
||||
a := New("VEC_OUT")
|
||||
C.mlx_vector_array_get(&a.ctx, vec, C.size_t(i))
|
||||
out[i] = a
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Log returns element-wise natural logarithm.
|
||||
func Log(a *Array) *Array {
|
||||
out := New("LOG", a)
|
||||
C.mlx_log(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// SumAll reduces by summation over all elements, returning a scalar.
|
||||
func SumAll(a *Array) *Array {
|
||||
flat := Reshape(a, -1)
|
||||
return Sum(flat, 0, false)
|
||||
}
|
||||
|
||||
// MeanAll reduces by averaging over all elements, returning a scalar.
|
||||
func MeanAll(a *Array) *Array {
|
||||
flat := Reshape(a, -1)
|
||||
return Mean(flat, 0, false)
|
||||
}
|
||||
|
||||
// OnesLike creates an array of ones with the same shape and type as the input.
|
||||
func OnesLike(a *Array) *Array {
|
||||
out := New("ONES_LIKE", a)
|
||||
C.mlx_ones_like(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
332
mlx/grad_test.go
332
mlx/grad_test.go
|
|
@ -1,332 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestVJP_SimpleSquare(t *testing.T) {
|
||||
// f(x) = x^2, df/dx = 2x
|
||||
// At x=3: f(3)=9, df/dx=6
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
x := inputs[0]
|
||||
return []*Array{Mul(x, x)}
|
||||
}
|
||||
|
||||
x := FromValue(float32(3.0))
|
||||
cotangent := FromValue(float32(1.0)) // upstream grad = 1
|
||||
|
||||
outputs, grads, err := VJP(fn, []*Array{x}, []*Array{cotangent})
|
||||
if err != nil {
|
||||
t.Fatalf("VJP failed: %v", err)
|
||||
}
|
||||
|
||||
Materialize(outputs[0], grads[0])
|
||||
|
||||
got := outputs[0].Float()
|
||||
if math.Abs(got-9.0) > 1e-5 {
|
||||
t.Errorf("output = %f, want 9.0", got)
|
||||
}
|
||||
|
||||
grad := grads[0].Float()
|
||||
if math.Abs(grad-6.0) > 1e-5 {
|
||||
t.Errorf("grad = %f, want 6.0", grad)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVJP_Addition(t *testing.T) {
|
||||
// f(x, y) = x + y, df/dx = 1, df/dy = 1
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
return []*Array{Add(inputs[0], inputs[1])}
|
||||
}
|
||||
|
||||
x := FromValue(float32(2.0))
|
||||
y := FromValue(float32(5.0))
|
||||
cotangent := FromValue(float32(1.0))
|
||||
|
||||
_, grads, err := VJP(fn, []*Array{x, y}, []*Array{cotangent})
|
||||
if err != nil {
|
||||
t.Fatalf("VJP failed: %v", err)
|
||||
}
|
||||
|
||||
Materialize(grads...)
|
||||
|
||||
if math.Abs(grads[0].Float()-1.0) > 1e-5 {
|
||||
t.Errorf("dx = %f, want 1.0", grads[0].Float())
|
||||
}
|
||||
if math.Abs(grads[1].Float()-1.0) > 1e-5 {
|
||||
t.Errorf("dy = %f, want 1.0", grads[1].Float())
|
||||
}
|
||||
}
|
||||
|
||||
func TestVJP_MatmulGrad(t *testing.T) {
|
||||
// f(W) = sum(W @ x) — gradient of sum(matmul) w.r.t. W
|
||||
// For W=[2,2], x=[2,1]: dL/dW = ones @ x^T
|
||||
x := FromValues([]float32{1.0, 2.0}, 2, 1)
|
||||
w := FromValues([]float32{1.0, 0.0, 0.0, 1.0}, 2, 2) // identity
|
||||
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
result := Matmul(inputs[0], x)
|
||||
return []*Array{SumAll(result)}
|
||||
}
|
||||
|
||||
cotangent := FromValue(float32(1.0))
|
||||
|
||||
outputs, grads, err := VJP(fn, []*Array{w}, []*Array{cotangent})
|
||||
if err != nil {
|
||||
t.Fatalf("VJP failed: %v", err)
|
||||
}
|
||||
|
||||
Materialize(outputs[0], grads[0])
|
||||
|
||||
// W @ x with W=I, x=[1,2]^T gives [1,2]^T, sum=3
|
||||
got := outputs[0].Float()
|
||||
if math.Abs(got-3.0) > 1e-5 {
|
||||
t.Errorf("output = %f, want 3.0", got)
|
||||
}
|
||||
|
||||
// Gradient of sum(W@x) w.r.t. W is outer product: ones @ x^T
|
||||
// = [[1,2],[1,2]]
|
||||
gradFloats := grads[0].Floats()
|
||||
expected := []float32{1.0, 2.0, 1.0, 2.0}
|
||||
for i, exp := range expected {
|
||||
if math.Abs(float64(gradFloats[i]-exp)) > 1e-5 {
|
||||
t.Errorf("grad[%d] = %f, want %f", i, gradFloats[i], exp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestJVP_SimpleSquare(t *testing.T) {
|
||||
// f(x) = x^2, JVP with tangent v: df = 2x * v
|
||||
// At x=3, v=1: df = 6
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
x := inputs[0]
|
||||
return []*Array{Mul(x, x)}
|
||||
}
|
||||
|
||||
x := FromValue(float32(3.0))
|
||||
tangent := FromValue(float32(1.0))
|
||||
|
||||
outputs, jvps, err := JVP(fn, []*Array{x}, []*Array{tangent})
|
||||
if err != nil {
|
||||
t.Fatalf("JVP failed: %v", err)
|
||||
}
|
||||
|
||||
Materialize(outputs[0], jvps[0])
|
||||
|
||||
got := outputs[0].Float()
|
||||
if math.Abs(got-9.0) > 1e-5 {
|
||||
t.Errorf("output = %f, want 9.0", got)
|
||||
}
|
||||
|
||||
jvp := jvps[0].Float()
|
||||
if math.Abs(jvp-6.0) > 1e-5 {
|
||||
t.Errorf("jvp = %f, want 6.0", jvp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueAndGrad_Quadratic(t *testing.T) {
|
||||
// f(x) = x^2 + 2x + 1 = (x+1)^2
|
||||
// f'(x) = 2x + 2
|
||||
// At x=3: f(3) = 16, f'(3) = 8
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
x := inputs[0]
|
||||
x2 := Mul(x, x)
|
||||
two_x := MulScalar(x, 2.0)
|
||||
one := FromValue(float32(1.0))
|
||||
return []*Array{Add(Add(x2, two_x), one)}
|
||||
}
|
||||
|
||||
grad := ValueAndGrad(fn, 0)
|
||||
defer grad.Free()
|
||||
|
||||
x := FromValue(float32(3.0))
|
||||
values, grads, err := grad.Apply(x)
|
||||
if err != nil {
|
||||
t.Fatalf("ValueAndGrad failed: %v", err)
|
||||
}
|
||||
|
||||
Materialize(values[0], grads[0])
|
||||
|
||||
val := values[0].Float()
|
||||
if math.Abs(val-16.0) > 1e-5 {
|
||||
t.Errorf("value = %f, want 16.0", val)
|
||||
}
|
||||
|
||||
g := grads[0].Float()
|
||||
if math.Abs(g-8.0) > 1e-5 {
|
||||
t.Errorf("grad = %f, want 8.0", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueAndGrad_MultiArg(t *testing.T) {
|
||||
// f(x, y) = x*y, df/dx = y, df/dy = x
|
||||
// At x=3, y=4: f=12, dx=4, dy=3
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
return []*Array{Mul(inputs[0], inputs[1])}
|
||||
}
|
||||
|
||||
// Differentiate w.r.t. both arguments
|
||||
grad := ValueAndGrad(fn, 0, 1)
|
||||
defer grad.Free()
|
||||
|
||||
x := FromValue(float32(3.0))
|
||||
y := FromValue(float32(4.0))
|
||||
values, grads, err := grad.Apply(x, y)
|
||||
if err != nil {
|
||||
t.Fatalf("ValueAndGrad failed: %v", err)
|
||||
}
|
||||
|
||||
Materialize(values[0], grads[0], grads[1])
|
||||
|
||||
val := values[0].Float()
|
||||
if math.Abs(val-12.0) > 1e-5 {
|
||||
t.Errorf("value = %f, want 12.0", val)
|
||||
}
|
||||
|
||||
dx := grads[0].Float()
|
||||
if math.Abs(dx-4.0) > 1e-5 {
|
||||
t.Errorf("dx = %f, want 4.0 (y)", dx)
|
||||
}
|
||||
|
||||
dy := grads[1].Float()
|
||||
if math.Abs(dy-3.0) > 1e-5 {
|
||||
t.Errorf("dy = %f, want 3.0 (x)", dy)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValueAndGrad_Reusable(t *testing.T) {
|
||||
// Verify GradFn can be called multiple times
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
x := inputs[0]
|
||||
return []*Array{Mul(x, x)} // x^2, grad = 2x
|
||||
}
|
||||
|
||||
grad := ValueAndGrad(fn)
|
||||
defer grad.Free()
|
||||
|
||||
for _, tc := range []struct {
|
||||
x float32
|
||||
want float64 // expected gradient
|
||||
}{
|
||||
{2.0, 4.0},
|
||||
{5.0, 10.0},
|
||||
{-3.0, -6.0},
|
||||
{0.0, 0.0},
|
||||
} {
|
||||
x := FromValue(tc.x)
|
||||
_, grads, err := grad.Apply(x)
|
||||
if err != nil {
|
||||
t.Fatalf("Apply failed for x=%f: %v", tc.x, err)
|
||||
}
|
||||
Materialize(grads[0])
|
||||
|
||||
g := grads[0].Float()
|
||||
if math.Abs(g-tc.want) > 1e-5 {
|
||||
t.Errorf("x=%f: grad = %f, want %f", tc.x, g, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCrossEntropyLoss(t *testing.T) {
|
||||
// Simple 3-class classification
|
||||
// logits = [1.0, 2.0, 3.0], target = 2 (class index)
|
||||
// Manual: logsumexp([1,2,3]) = 3 + log(exp(-2)+exp(-1)+1)
|
||||
// = 3 + log(0.1353 + 0.3679 + 1.0) = 3 + log(1.5032) = 3.4076
|
||||
// loss = 3.4076 - 3.0 = 0.4076
|
||||
logits := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3) // [1, 3]
|
||||
targets := FromValues([]int32{2}, 1) // [1]
|
||||
|
||||
loss := CrossEntropyLoss(logits, targets)
|
||||
Materialize(loss)
|
||||
|
||||
got := loss.Float()
|
||||
expected := 0.4076
|
||||
if math.Abs(got-expected) > 0.01 {
|
||||
t.Errorf("CrossEntropyLoss = %f, want ~%f", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSELoss(t *testing.T) {
|
||||
pred := FromValues([]float32{1.0, 2.0, 3.0}, 3)
|
||||
target := FromValues([]float32{1.5, 2.5, 3.5}, 3)
|
||||
|
||||
loss := MSELoss(pred, target)
|
||||
Materialize(loss)
|
||||
|
||||
// MSE = mean((0.5)^2, (0.5)^2, (0.5)^2) = mean(0.25, 0.25, 0.25) = 0.25
|
||||
got := loss.Float()
|
||||
if math.Abs(got-0.25) > 1e-5 {
|
||||
t.Errorf("MSELoss = %f, want 0.25", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogSumExp(t *testing.T) {
|
||||
// logsumexp([1, 2, 3]) along axis -1
|
||||
a := FromValues([]float32{1.0, 2.0, 3.0}, 1, 3)
|
||||
result := LogSumExp(a, -1, false)
|
||||
Materialize(result)
|
||||
|
||||
// = 3 + log(exp(-2) + exp(-1) + 1) = 3 + log(1.5032) ≈ 3.4076
|
||||
got := result.Float()
|
||||
expected := 3.4076
|
||||
if math.Abs(got-expected) > 0.01 {
|
||||
t.Errorf("LogSumExp = %f, want ~%f", got, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOnesLike(t *testing.T) {
|
||||
a := FromValues([]float32{1.0, 2.0, 3.0}, 3)
|
||||
ones := OnesLike(a)
|
||||
Materialize(ones)
|
||||
|
||||
floats := ones.Floats()
|
||||
for i, f := range floats {
|
||||
if f != 1.0 {
|
||||
t.Errorf("OnesLike[%d] = %f, want 1.0", i, f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckpoint(t *testing.T) {
|
||||
// Checkpoint should produce the same result as the original function
|
||||
fn := func(inputs []*Array) []*Array {
|
||||
x := inputs[0]
|
||||
return []*Array{Mul(x, x)}
|
||||
}
|
||||
|
||||
cpFn := Checkpoint(fn)
|
||||
|
||||
x := FromValue(float32(5.0))
|
||||
result := cpFn([]*Array{x})
|
||||
Materialize(result[0])
|
||||
|
||||
got := result[0].Float()
|
||||
if math.Abs(got-25.0) > 1e-5 {
|
||||
t.Errorf("Checkpoint result = %f, want 25.0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSumAll(t *testing.T) {
|
||||
a := FromValues([]float32{1.0, 2.0, 3.0, 4.0}, 2, 2)
|
||||
result := SumAll(a)
|
||||
Materialize(result)
|
||||
|
||||
got := result.Float()
|
||||
if math.Abs(got-10.0) > 1e-5 {
|
||||
t.Errorf("SumAll = %f, want 10.0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMeanAll(t *testing.T) {
|
||||
a := FromValues([]float32{2.0, 4.0, 6.0, 8.0}, 2, 2)
|
||||
result := MeanAll(a)
|
||||
Materialize(result)
|
||||
|
||||
got := result.Float()
|
||||
if math.Abs(got-5.0) > 1e-5 {
|
||||
t.Errorf("MeanAll = %f, want 5.0", got)
|
||||
}
|
||||
}
|
||||
63
mlx/io.go
63
mlx/io.go
|
|
@ -1,63 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// LoadSafetensors loads tensors from a .safetensors file, returning an iterator
|
||||
// over (name, array) pairs. Tensors are loaded lazily on the CPU stream.
|
||||
func LoadSafetensors(path string) iter.Seq2[string, *Array] {
|
||||
Init()
|
||||
return func(yield func(string, *Array) bool) {
|
||||
string2array := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(string2array)
|
||||
|
||||
string2string := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(string2string)
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
cpu := C.mlx_default_cpu_stream_new()
|
||||
defer C.mlx_stream_free(cpu)
|
||||
|
||||
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
|
||||
|
||||
it := C.mlx_map_string_to_array_iterator_new(string2array)
|
||||
defer C.mlx_map_string_to_array_iterator_free(it)
|
||||
|
||||
for {
|
||||
var key *C.char
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
name := C.GoString(key)
|
||||
arr := &Array{ctx: value, name: name}
|
||||
runtime.SetFinalizer(arr, finalizeArray)
|
||||
if !yield(name, arr) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAllSafetensors loads all tensors from a .safetensors file into a map.
|
||||
func LoadAllSafetensors(path string) map[string]*Array {
|
||||
tensors := make(map[string]*Array)
|
||||
for name, arr := range LoadSafetensors(path) {
|
||||
tensors[name] = arr
|
||||
}
|
||||
return tensors
|
||||
}
|
||||
237
mlx/lora.go
237
mlx/lora.go
|
|
@ -1,237 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"sort"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// LoRALinear wraps a frozen Linear layer with low-rank trainable adapters.
|
||||
//
|
||||
// Forward: base(x) + scale * (x @ A^T) @ B^T
|
||||
//
|
||||
// A is [rank, in_features] — initialised with Kaiming normal
|
||||
// B is [out_features, rank] — initialised to zero
|
||||
// Scale = alpha / rank
|
||||
//
|
||||
// Only A and B are trainable. The base weights stay frozen.
|
||||
type LoRALinear struct {
|
||||
Base *Linear // Frozen base weights (may be quantized)
|
||||
A *Array // [rank, in_features] — trainable
|
||||
B *Array // [out_features, rank] — trainable
|
||||
Scale float32 // alpha / rank
|
||||
Rank int
|
||||
Alpha float32
|
||||
}
|
||||
|
||||
// NewLoRALinear wraps an existing Linear layer with LoRA adapters.
|
||||
// rank: decomposition rank (typically 4, 8, or 16)
|
||||
// alpha: scaling factor (typically 2*rank)
|
||||
func NewLoRALinear(base *Linear, rank int, alpha float32) *LoRALinear {
|
||||
// Determine dimensions from the base weight.
|
||||
// Weight shape is [out_features, in_features] for standard linear,
|
||||
// or quantized shape which we handle via the base layer.
|
||||
var inFeatures, outFeatures int32
|
||||
|
||||
if base.Scales != nil {
|
||||
// Quantized: weight is packed. Compute dims from scales.
|
||||
// scales shape is [out_features, in_features / group_size]
|
||||
scaleShape := base.Scales.Shape()
|
||||
outFeatures = scaleShape[0]
|
||||
inFeatures = scaleShape[1] * int32(base.GroupSize)
|
||||
} else {
|
||||
wShape := base.Weight.Shape()
|
||||
outFeatures = wShape[0]
|
||||
inFeatures = wShape[1]
|
||||
}
|
||||
|
||||
// A: Kaiming normal initialisation — N(0, 1/sqrt(in_features))
|
||||
stddev := float32(1.0 / math.Sqrt(float64(inFeatures)))
|
||||
a := RandomNormal(0, stddev, []int32{int32(rank), inFeatures}, DTypeFloat32)
|
||||
|
||||
// B: zero initialisation — LoRA starts as identity (no change to base)
|
||||
b := Zeros([]int32{outFeatures, int32(rank)}, DTypeFloat32)
|
||||
|
||||
Materialize(a, b)
|
||||
|
||||
return &LoRALinear{
|
||||
Base: base,
|
||||
A: a,
|
||||
B: b,
|
||||
Scale: alpha / float32(rank),
|
||||
Rank: rank,
|
||||
Alpha: alpha,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward computes: base(x) + scale * (x @ A^T) @ B^T
|
||||
// Calls baseForward on the underlying Linear to avoid infinite recursion
|
||||
// when the Linear's Forward method dispatches through LoRA.
|
||||
func (l *LoRALinear) Forward(x *Array) *Array {
|
||||
baseOut := l.Base.baseForward(x)
|
||||
|
||||
// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
|
||||
loraOut := Matmul(x, Transpose(l.A))
|
||||
loraOut = Matmul(loraOut, Transpose(l.B))
|
||||
loraOut = MulScalar(loraOut, l.Scale)
|
||||
|
||||
return Add(baseOut, loraOut)
|
||||
}
|
||||
|
||||
// TrainableParams returns the LoRA A and B arrays for gradient computation.
|
||||
func (l *LoRALinear) TrainableParams() []*Array {
|
||||
return []*Array{l.A, l.B}
|
||||
}
|
||||
|
||||
// SetParams updates the LoRA A and B arrays (used by optimiser after gradient step).
|
||||
func (l *LoRALinear) SetParams(a, b *Array) {
|
||||
l.A = a
|
||||
l.B = b
|
||||
}
|
||||
|
||||
// ParamCount returns the number of trainable parameters.
|
||||
func (l *LoRALinear) ParamCount() int {
|
||||
aShape := l.A.Shape()
|
||||
bShape := l.B.Shape()
|
||||
return int(aShape[0]*aShape[1] + bShape[0]*bShape[1])
|
||||
}
|
||||
|
||||
// --- LoRA Application to Models ---
|
||||
|
||||
// LoRAConfig specifies which layers to apply LoRA to and with what parameters.
|
||||
type LoRAConfig struct {
|
||||
Rank int // Decomposition rank (default 8)
|
||||
Alpha float32 // Scaling factor (default 16)
|
||||
TargetKeys []string // Weight name suffixes to target (default: q_proj, v_proj)
|
||||
}
|
||||
|
||||
// DefaultLoRAConfig returns the standard LoRA configuration for LLM fine-tuning.
|
||||
func DefaultLoRAConfig() LoRAConfig {
|
||||
return LoRAConfig{
|
||||
Rank: 8,
|
||||
Alpha: 16,
|
||||
TargetKeys: []string{"q_proj", "v_proj"},
|
||||
}
|
||||
}
|
||||
|
||||
// LoRAAdapter holds all LoRA layers applied to a model.
|
||||
type LoRAAdapter struct {
|
||||
Layers map[string]*LoRALinear // keyed by weight path prefix
|
||||
Config LoRAConfig
|
||||
}
|
||||
|
||||
// TotalParams returns the total number of trainable parameters across all LoRA layers.
|
||||
func (a *LoRAAdapter) TotalParams() int {
|
||||
total := 0
|
||||
for _, l := range a.Layers {
|
||||
total += l.ParamCount()
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// SortedNames returns layer names in deterministic sorted order.
|
||||
func (a *LoRAAdapter) SortedNames() []string {
|
||||
names := make([]string, 0, len(a.Layers))
|
||||
for name := range a.Layers {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// AllTrainableParams returns all trainable arrays (A and B from every layer),
|
||||
// in a deterministic order sorted by layer name.
|
||||
func (a *LoRAAdapter) AllTrainableParams() []*Array {
|
||||
names := a.SortedNames()
|
||||
params := make([]*Array, 0, len(names)*2)
|
||||
for _, name := range names {
|
||||
l := a.Layers[name]
|
||||
params = append(params, l.A, l.B)
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
// SetAllParams updates all LoRA A and B arrays from a flat slice,
|
||||
// in the same deterministic order as AllTrainableParams.
|
||||
func (a *LoRAAdapter) SetAllParams(params []*Array) {
|
||||
names := a.SortedNames()
|
||||
for i, name := range names {
|
||||
l := a.Layers[name]
|
||||
l.A = params[i*2]
|
||||
l.B = params[i*2+1]
|
||||
}
|
||||
}
|
||||
|
||||
// Save writes the LoRA adapter weights to a safetensors file.
|
||||
// Only saves the A and B matrices — not the frozen base weights.
|
||||
func (a *LoRAAdapter) Save(path string) error {
|
||||
weights := make(map[string]*Array)
|
||||
for name, l := range a.Layers {
|
||||
weights[name+".lora_a"] = l.A
|
||||
weights[name+".lora_b"] = l.B
|
||||
}
|
||||
return SaveSafetensors(path, weights)
|
||||
}
|
||||
|
||||
// --- Random Normal ---
|
||||
|
||||
// RandomNormal generates normal (Gaussian) random values with given mean and stddev.
|
||||
func RandomNormal(mean, stddev float32, shape []int32, dtype DType) *Array {
|
||||
Init()
|
||||
out := New("RANDOM_NORMAL")
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
key := C.mlx_array_new()
|
||||
defer C.mlx_array_free(key)
|
||||
C.mlx_random_normal(
|
||||
&out.ctx,
|
||||
&cShape[0], C.size_t(len(cShape)),
|
||||
C.mlx_dtype(dtype),
|
||||
C.float(mean), C.float(stddev),
|
||||
key, // null key = use default RNG
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// --- SaveSafetensors ---
|
||||
|
||||
// SaveSafetensors saves a map of named arrays to a .safetensors file.
|
||||
func SaveSafetensors(path string, weights map[string]*Array) error {
|
||||
Init()
|
||||
|
||||
// Build the map
|
||||
cMap := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(cMap)
|
||||
|
||||
for name, arr := range weights {
|
||||
cName := C.CString(name)
|
||||
C.mlx_map_string_to_array_insert(cMap, cName, arr.ctx)
|
||||
C.free(unsafe.Pointer(cName))
|
||||
}
|
||||
|
||||
// Empty metadata
|
||||
cMeta := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(cMeta)
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
rc := C.mlx_save_safetensors(cPath, cMap, cMeta)
|
||||
if rc != 0 {
|
||||
checkError()
|
||||
return fmt.Errorf("mlx: save safetensors failed: %s", path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
316
mlx/lora_test.go
316
mlx/lora_test.go
|
|
@ -1,316 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewLoRALinear(t *testing.T) {
|
||||
// Create a simple base linear layer: [4, 8] weight
|
||||
w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
|
||||
Materialize(w)
|
||||
base := NewLinear(w, nil)
|
||||
|
||||
lora := NewLoRALinear(base, 4, 8.0) // rank=4, alpha=8
|
||||
|
||||
// Check dimensions
|
||||
aShape := lora.A.Shape()
|
||||
bShape := lora.B.Shape()
|
||||
|
||||
if aShape[0] != 4 || aShape[1] != 8 {
|
||||
t.Errorf("A shape = %v, want [4, 8]", aShape)
|
||||
}
|
||||
if bShape[0] != 4 || bShape[1] != 4 {
|
||||
t.Errorf("B shape = %v, want [4, 4]", bShape)
|
||||
}
|
||||
|
||||
// Scale should be alpha/rank = 8/4 = 2
|
||||
if math.Abs(float64(lora.Scale)-2.0) > 1e-5 {
|
||||
t.Errorf("Scale = %f, want 2.0", lora.Scale)
|
||||
}
|
||||
|
||||
// B should be all zeros (LoRA starts as identity)
|
||||
Materialize(lora.B)
|
||||
bFloats := lora.B.Floats()
|
||||
for i, v := range bFloats {
|
||||
if v != 0 {
|
||||
t.Errorf("B[%d] = %f, want 0", i, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoRALinear_ForwardMatchesBase(t *testing.T) {
|
||||
// With B=0, LoRA forward should equal base forward
|
||||
w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
|
||||
Materialize(w)
|
||||
base := NewLinear(w, nil)
|
||||
|
||||
lora := NewLoRALinear(base, 4, 8.0)
|
||||
|
||||
// Random input [1, 3, 8]
|
||||
x := RandomNormal(0, 1, []int32{1, 3, 8}, DTypeFloat32)
|
||||
Materialize(x)
|
||||
|
||||
baseOut := base.Forward(x)
|
||||
loraOut := lora.Forward(x)
|
||||
Materialize(baseOut, loraOut)
|
||||
|
||||
// Should be identical since B is zero
|
||||
baseFloats := baseOut.Floats()
|
||||
loraFloats := loraOut.Floats()
|
||||
|
||||
if len(baseFloats) != len(loraFloats) {
|
||||
t.Fatalf("output sizes differ: base=%d, lora=%d", len(baseFloats), len(loraFloats))
|
||||
}
|
||||
|
||||
for i := range baseFloats {
|
||||
diff := math.Abs(float64(baseFloats[i] - loraFloats[i]))
|
||||
if diff > 1e-4 {
|
||||
t.Errorf("output[%d] differs: base=%f, lora=%f", i, baseFloats[i], loraFloats[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoRALinear_ForwardWithAdapter(t *testing.T) {
|
||||
// Set A and B to known values and verify output changes
|
||||
w := Zeros([]int32{4, 8}, DTypeFloat32)
|
||||
Materialize(w)
|
||||
base := NewLinear(w, nil)
|
||||
|
||||
lora := NewLoRALinear(base, 2, 4.0) // rank=2, alpha=4, scale=2
|
||||
|
||||
// Set A to identity-like: [[1,0,0,...], [0,1,0,...]]
|
||||
a := Zeros([]int32{2, 8}, DTypeFloat32)
|
||||
// Set B to ones: [[1,1], [1,1], [1,1], [1,1]]
|
||||
b := FromValues([]float32{
|
||||
1, 1,
|
||||
1, 1,
|
||||
1, 1,
|
||||
1, 1,
|
||||
}, 4, 2)
|
||||
Materialize(a, b)
|
||||
lora.A = a
|
||||
lora.B = b
|
||||
|
||||
// With base=0, A=0, output should also be 0 (scale * x@0@B^T = 0)
|
||||
x := FromValues([]float32{1, 2, 3, 4, 5, 6, 7, 8}, 1, 1, 8)
|
||||
result := lora.Forward(x)
|
||||
Materialize(result)
|
||||
|
||||
// base(x) = 0 (zero weights), lora = scale * (x @ A^T) @ B^T
|
||||
// A is zeros, so x @ A^T = [0, 0], then @ B^T = [0,0,0,0]
|
||||
for _, v := range result.Floats() {
|
||||
if v != 0 {
|
||||
t.Errorf("expected 0 with zero A, got %f", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoRALinear_ParamCount(t *testing.T) {
|
||||
w := RandomNormal(0, 0.01, []int32{64, 128}, DTypeFloat32)
|
||||
Materialize(w)
|
||||
base := NewLinear(w, nil)
|
||||
|
||||
lora := NewLoRALinear(base, 8, 16.0) // rank=8
|
||||
// A: [8, 128] = 1024, B: [64, 8] = 512, total = 1536
|
||||
expected := 8*128 + 64*8
|
||||
if lora.ParamCount() != expected {
|
||||
t.Errorf("ParamCount = %d, want %d", lora.ParamCount(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoRALinear_TrainableParams(t *testing.T) {
|
||||
w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
|
||||
Materialize(w)
|
||||
base := NewLinear(w, nil)
|
||||
|
||||
lora := NewLoRALinear(base, 4, 8.0)
|
||||
params := lora.TrainableParams()
|
||||
|
||||
if len(params) != 2 {
|
||||
t.Fatalf("TrainableParams returned %d arrays, want 2", len(params))
|
||||
}
|
||||
|
||||
// First is A, second is B
|
||||
if params[0].Shape()[0] != 4 || params[0].Shape()[1] != 8 {
|
||||
t.Errorf("param[0] (A) shape = %v, want [4, 8]", params[0].Shape())
|
||||
}
|
||||
if params[1].Shape()[0] != 4 || params[1].Shape()[1] != 4 {
|
||||
t.Errorf("param[1] (B) shape = %v, want [4, 4]", params[1].Shape())
|
||||
}
|
||||
}
|
||||
|
||||
// TestLoRALinear_GradientFlows differentiates a scalar loss (sum of the
// LoRA output) with respect to the A and B matrices and checks that at
// least the B gradient is non-zero, i.e. autodiff reaches the adapter.
func TestLoRALinear_GradientFlows(t *testing.T) {
	// Verify that gradients flow through the LoRA path
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 4, 8.0)
	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	Materialize(x)

	// Loss function: sum of LoRA output (differentiating w.r.t. A and B).
	// The inputs slice re-binds A and B so ValueAndGrad can trace them.
	lossFn := func(inputs []*Array) []*Array {
		lora.A = inputs[0]
		lora.B = inputs[1]
		out := lora.Forward(x)
		return []*Array{SumAll(out)}
	}

	grad := ValueAndGrad(lossFn, 0, 1) // grad w.r.t. A and B
	defer grad.Free()

	values, grads, err := grad.Apply(lora.A, lora.B)
	if err != nil {
		t.Fatalf("ValueAndGrad failed: %v", err)
	}

	// Force evaluation of the loss and both gradients in one batch.
	Materialize(append(values, grads...)...)

	// Loss should be a scalar
	loss := values[0].Float()
	t.Logf("loss = %f", loss)

	// Gradients should be non-zero (A has random init, B is zero but gets grad)
	gradA := grads[0]
	gradB := grads[1]

	aGradFloats := gradA.Floats()
	bGradFloats := gradB.Floats()

	hasNonZeroA := false
	for _, v := range aGradFloats {
		if v != 0 {
			hasNonZeroA = true
			break
		}
	}

	hasNonZeroB := false
	for _, v := range bGradFloats {
		if v != 0 {
			hasNonZeroB = true
			break
		}
	}

	// A gradient might be zero if B is zero (since dL/dA depends on B)
	// But B gradient should be non-zero since A is random
	if !hasNonZeroB {
		t.Error("gradient for B is all zeros — gradients not flowing")
	}
	t.Logf("gradA has non-zero: %v, gradB has non-zero: %v", hasNonZeroA, hasNonZeroB)
}
|
||||
|
||||
func TestRandomNormal(t *testing.T) {
|
||||
arr := RandomNormal(0, 1, []int32{100}, DTypeFloat32)
|
||||
Materialize(arr)
|
||||
|
||||
floats := arr.Floats()
|
||||
if len(floats) != 100 {
|
||||
t.Fatalf("RandomNormal returned %d elements, want 100", len(floats))
|
||||
}
|
||||
|
||||
// Check rough statistics: mean should be near 0, values should have spread
|
||||
var sum float64
|
||||
for _, f := range floats {
|
||||
sum += float64(f)
|
||||
}
|
||||
mean := sum / 100
|
||||
if math.Abs(mean) > 0.5 { // generous tolerance for 100 samples
|
||||
t.Errorf("mean = %f, expected near 0", mean)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveSafetensors(t *testing.T) {
|
||||
a := FromValues([]float32{1, 2, 3, 4}, 2, 2)
|
||||
b := FromValues([]float32{5, 6, 7, 8, 9, 10}, 3, 2)
|
||||
Materialize(a, b)
|
||||
|
||||
path := t.TempDir() + "/test.safetensors"
|
||||
err := SaveSafetensors(path, map[string]*Array{
|
||||
"layer.lora_a": a,
|
||||
"layer.lora_b": b,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SaveSafetensors failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("saved file not found: %v", err)
|
||||
}
|
||||
if info.Size() == 0 {
|
||||
t.Error("saved file is empty")
|
||||
}
|
||||
|
||||
// Load it back
|
||||
loaded := LoadAllSafetensors(path)
|
||||
Materialize(loaded["layer.lora_a"], loaded["layer.lora_b"])
|
||||
|
||||
aLoaded := loaded["layer.lora_a"].Floats()
|
||||
bLoaded := loaded["layer.lora_b"].Floats()
|
||||
|
||||
expectedA := []float32{1, 2, 3, 4}
|
||||
expectedB := []float32{5, 6, 7, 8, 9, 10}
|
||||
|
||||
for i, v := range expectedA {
|
||||
if aLoaded[i] != v {
|
||||
t.Errorf("loaded A[%d] = %f, want %f", i, aLoaded[i], v)
|
||||
}
|
||||
}
|
||||
for i, v := range expectedB {
|
||||
if bLoaded[i] != v {
|
||||
t.Errorf("loaded B[%d] = %f, want %f", i, bLoaded[i], v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoRAAdapter_Save(t *testing.T) {
|
||||
w := RandomNormal(0, 0.01, []int32{4, 8}, DTypeFloat32)
|
||||
Materialize(w)
|
||||
base := NewLinear(w, nil)
|
||||
|
||||
adapter := &LoRAAdapter{
|
||||
Layers: map[string]*LoRALinear{
|
||||
"model.layers.0.self_attn.q_proj": NewLoRALinear(base, 4, 8.0),
|
||||
},
|
||||
Config: DefaultLoRAConfig(),
|
||||
}
|
||||
|
||||
path := t.TempDir() + "/adapter.safetensors"
|
||||
err := adapter.Save(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Adapter.Save failed: %v", err)
|
||||
}
|
||||
|
||||
// Load and verify
|
||||
loaded := LoadAllSafetensors(path)
|
||||
aKey := "model.layers.0.self_attn.q_proj.lora_a"
|
||||
bKey := "model.layers.0.self_attn.q_proj.lora_b"
|
||||
|
||||
if _, ok := loaded[aKey]; !ok {
|
||||
t.Errorf("missing key %s in saved adapter", aKey)
|
||||
}
|
||||
if _, ok := loaded[bKey]; !ok {
|
||||
t.Errorf("missing key %s in saved adapter", bKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultLoRAConfig(t *testing.T) {
|
||||
cfg := DefaultLoRAConfig()
|
||||
if cfg.Rank != 8 {
|
||||
t.Errorf("Rank = %d, want 8", cfg.Rank)
|
||||
}
|
||||
if cfg.Alpha != 16 {
|
||||
t.Errorf("Alpha = %f, want 16", cfg.Alpha)
|
||||
}
|
||||
if len(cfg.TargetKeys) != 2 {
|
||||
t.Errorf("TargetKeys = %v, want [q_proj, v_proj]", cfg.TargetKeys)
|
||||
}
|
||||
}
|
||||
115
mlx/mlx.go
115
mlx/mlx.go
|
|
@ -1,115 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||
//
|
||||
// Build mlx-c before use:
|
||||
//
|
||||
// cd pkg/mlx && go generate ./...
|
||||
//
|
||||
// Build (MLX is auto-enabled on darwin/arm64):
|
||||
//
|
||||
// go build -o core .
|
||||
package mlx
|
||||
|
||||
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
|
||||
//go:generate cmake --build build --parallel
|
||||
//go:generate cmake --install build
|
||||
|
||||
/*
|
||||
#cgo CXXFLAGS: -std=c++17
|
||||
#cgo CFLAGS: -mmacosx-version-min=26.0
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/dist/include
|
||||
#cgo LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx
|
||||
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||
#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
|
||||
static const char *last_mlx_error = NULL;
|
||||
|
||||
static void mlx_go_error_handler(const char *msg, void *data) {
|
||||
fprintf(stderr, "MLX ERROR: %s\n", msg);
|
||||
last_mlx_error = msg;
|
||||
}
|
||||
|
||||
static void set_error_handler() {
|
||||
mlx_set_error_handler(&mlx_go_error_handler, NULL, NULL);
|
||||
}
|
||||
|
||||
static const char* get_last_error() {
|
||||
return last_mlx_error;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// initOnce guards one-time registration of the cgo error handler.
var initOnce sync.Once

// Init sets up the MLX error handler. Called automatically on first use.
// Safe for concurrent callers: sync.Once guarantees the handler is
// registered exactly once per process.
func Init() {
	initOnce.Do(func() {
		C.set_error_handler()
		slog.Debug("mlx: initialized with Metal backend")
	})
}
|
||||
|
||||
// checkError logs the last MLX error if any occurred.
// It reads the C-side last_mlx_error pointer captured by the registered
// handler; the message is logged but neither cleared nor returned, so
// repeated calls re-log the same error.
func checkError() {
	if msg := C.get_last_error(); msg != nil {
		slog.Error("mlx", "error", C.GoString(msg))
	}
}
|
||||
|
||||
// Materialize synchronously evaluates arrays, computing their values on the GPU.
// This is the MLX equivalent of forcing lazy computation to complete.
func Materialize(outputs ...*Array) {
	doMaterialize(outputs, false)
}

// MaterializeAsync queues arrays for asynchronous GPU evaluation via
// mlx_async_eval; it does not wait for the computation to finish.
func MaterializeAsync(outputs ...*Array) {
	doMaterialize(outputs, true)
}
|
||||
|
||||
// doMaterialize gathers the valid arrays into a C-side vector and hands
// them to the evaluator: blocking mlx_eval or queued mlx_async_eval.
func doMaterialize(outputs []*Array, async bool) {
	Init()
	// Temporary mlx vector; freed when this function returns.
	vector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vector)

	// nil or invalid arrays are silently skipped so callers may pass
	// optional outputs without pre-filtering.
	for _, output := range outputs {
		if output != nil && output.Valid() {
			C.mlx_vector_array_append_value(vector, output.ctx)
		}
	}

	if async {
		C.mlx_async_eval(vector)
	} else {
		C.mlx_eval(vector)
	}
}
|
||||
|
||||
// Collect gathers all valid arrays from a variadic list for batch Materialize.
|
||||
func Collect(arrays ...*Array) []*Array {
|
||||
var out []*Array
|
||||
for _, a := range arrays {
|
||||
if a != nil && a.Valid() {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// MetalAvailable reports whether Metal GPU is available.
// It ensures the error handler is installed, then queries
// mlx_metal_is_available.
func MetalAvailable() bool {
	Init()
	var available C.bool
	C.mlx_metal_is_available(&available)
	return bool(available)
}
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
//go:build !(darwin && arm64)
|
||||
|
||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||
// This stub file is used on non-darwin/non-arm64 platforms or when the
|
||||
// mlx build tag is not set. All operations report MLX as unavailable.
|
||||
package mlx
|
||||
|
||||
// MetalAvailable reports whether Metal GPU is available.
// Always returns false on non-Apple Silicon platforms.
func MetalAvailable() bool {
	return false
}
|
||||
|
|
@ -1,448 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package model provides transformer model architectures for MLX inference.
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mlx"
|
||||
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// TextConfig holds Gemma 3 text model configuration.
// Fields map one-to-one onto config.json keys; parseConfig fills in
// model-family defaults for any that are absent.
type TextConfig struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	HeadDim               int32   `json:"head_dim"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	RopeLocalBaseFreq     float32 `json:"rope_local_base_freq"`
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	SlidingWindow         int32   `json:"sliding_window"`
	SlidingWindowPattern  int32   `json:"sliding_window_pattern"`

	Quantization *QuantizationConfig `json:"-"` // Parsed separately from top-level
	Scale        float32             `json:"-"` // Computed: 1/sqrt(head_dim)
}

// GemmaModel is the Gemma 3 text model.
type GemmaModel struct {
	EmbedTokens *mlx.Embedding
	Layers      []*DecoderLayer
	Norm        *mlx.RMSNormModule
	Output      *mlx.Linear // Tied to EmbedTokens (unless a separate lm_head exists)

	// Precomputed (1 + weight) for Gemma-style RMSNorm
	NormScaled *mlx.Array

	Tok *tokenizer.Tokenizer
	Cfg *TextConfig
}

// DecoderLayer is a single transformer block.
// Gemma 3 applies four norms per block: before/after attention and
// before/after the feed-forward network.
type DecoderLayer struct {
	InputNorm    *mlx.RMSNormModule
	Attention    *Attention
	PostAttnNorm *mlx.RMSNormModule
	PreFFNorm    *mlx.RMSNormModule
	MLP          *MLP
	PostFFNorm   *mlx.RMSNormModule

	// Precomputed scaled weights (1 + weight), set by precomputeScaledWeights.
	InputNormScaled    *mlx.Array
	PostAttnNormScaled *mlx.Array
	PreFFNormScaled    *mlx.Array
	PostFFNormScaled   *mlx.Array

	IsSliding bool  // true for sliding-window attention layers
	LayerIdx  int32 // position of this layer in the stack
}

// Attention implements Gemma 3 attention with Q/K normalization.
type Attention struct {
	QProj *mlx.Linear
	KProj *mlx.Linear
	VProj *mlx.Linear
	OProj *mlx.Linear
	QNorm *mlx.RMSNormModule
	KNorm *mlx.RMSNormModule

	// Precomputed (1 + weight) scales for the Q/K norms.
	QNormScaled *mlx.Array
	KNormScaled *mlx.Array
}

// MLP is the feed-forward network.
type MLP struct {
	GateProj *mlx.Linear
	UpProj   *mlx.Linear
	DownProj *mlx.Linear
}
|
||||
|
||||
// compiledGELU is a singleton for the compiled GELU function.
var compiledGELU *mlx.CompiledFunc

// getCompiledGELU lazily compiles the GELU activation (shapeless, so a
// single compilation serves all input shapes) and caches it for reuse.
//
// NOTE(review): this nil-check lazy init is not synchronized; concurrent
// first calls could compile twice or race on the pointer. Confirm all
// callers run on one goroutine, or guard with sync.Once.
func getCompiledGELU() *mlx.CompiledFunc {
	if compiledGELU == nil {
		compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
			return []*mlx.Array{geluApprox(inputs[0])}
		}, true)
	}
	return compiledGELU
}
|
||||
|
||||
// geluApprox computes GELU using the tanh approximation:
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
const sqrt2OverPi = 0.7978845608028654
|
||||
const coeff = 0.044715
|
||||
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
||||
t := mlx.Tanh(scaled)
|
||||
onePlusT := mlx.AddScalar(t, 1.0)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
|
||||
}
|
||||
|
||||
// parseConfig handles both flat and nested (text_config) Gemma 3 configs.
|
||||
func parseConfig(data []byte) (*TextConfig, error) {
|
||||
// Try parsing text_config from multimodal wrapper
|
||||
var wrapper struct {
|
||||
TextConfig TextConfig `json:"text_config"`
|
||||
ModelType string `json:"model_type"`
|
||||
Quantization *QuantizationConfig `json:"quantization"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &wrapper); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := wrapper.TextConfig
|
||||
|
||||
// If text_config was empty, try top-level
|
||||
if cfg.NumHiddenLayers == 0 {
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Quantization is always top-level
|
||||
cfg.Quantization = wrapper.Quantization
|
||||
|
||||
// Compute scale (head_dim may be inferred later from weights if not in config)
|
||||
if cfg.HeadDim > 0 {
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
}
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RopeLocalBaseFreq == 0 {
|
||||
cfg.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.SlidingWindowPattern == 0 {
|
||||
cfg.SlidingWindowPattern = 6
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 262208 // Gemma 3 default
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// LoadGemma3 loads a Gemma 3 text model from a directory.
// The directory must contain config.json, tokenizer.json, and one or
// more *.safetensors weight shards. Quantized checkpoints (scales/biases
// tensors present plus a top-level quantization block) are detected per
// tensor and wired through QuantizedLinear.
func LoadGemma3(modelPath string) (*GemmaModel, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: load config: %w", err)
	}

	cfg, err := parseConfig(data)
	if err != nil {
		return nil, fmt.Errorf("gemma3: parse config: %w", err)
	}

	// Load tokenizer
	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: load tokenizer: %w", err)
	}

	// Load weights from all safetensors files. Glob errors only occur on
	// malformed patterns, so the error is deliberately discarded.
	weights := make(map[string]*mlx.Array)
	matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
	for _, path := range matches {
		for name, arr := range mlx.LoadSafetensors(path) {
			weights[name] = arr
		}
	}

	// Helper to resolve weight with language_model. prefix fallback
	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }

	// Infer head_dim from q_proj weight shape when not in config.
	// Gemma 3 uses head_dim=256 which differs from hidden_size/num_heads.
	if cfg.HeadDim == 0 {
		qWeight := w("model.layers.0.self_attn.q_proj.weight")
		if qWeight != nil {
			qShape := qWeight.Shape()
			if len(qShape) > 0 {
				cfg.HeadDim = qShape[0] / cfg.NumAttentionHeads
				cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
				slog.Info("mlx: inferred head_dim from q_proj weight", "head_dim", cfg.HeadDim)
			}
		}
	}

	// Helper to create linear layer (quantized or dense)
	q := cfg.Quantization
	if q != nil {
		slog.Info("mlx: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
	}
	linear := func(prefix string) *mlx.Linear {
		weight := w(prefix + ".weight")
		scales := w(prefix + ".scales")
		biases := w(prefix + ".biases")
		if scales != nil && q != nil {
			return mlx.NewQuantizedLinear(weight, scales, biases, nil, q.GroupSize, q.Bits)
		}
		return mlx.NewLinear(weight, nil)
	}

	// Create embedding (quantized or dense)
	embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
	if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
		embed.Scales = embedScales
		embed.Biases = w("model.embed_tokens.biases")
		embed.GroupSize = q.GroupSize
		embed.Bits = q.Bits
	}

	m := &GemmaModel{
		EmbedTokens: embed,
		Layers:      make([]*DecoderLayer, cfg.NumHiddenLayers),
		Norm:        &mlx.RMSNormModule{Weight: w("model.norm.weight")},
		Tok:         tok,
		Cfg:         cfg,
	}

	// Initialize layers
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)
		m.Layers[i] = &DecoderLayer{
			InputNorm:    &mlx.RMSNormModule{Weight: w(prefix + ".input_layernorm.weight")},
			PostAttnNorm: &mlx.RMSNormModule{Weight: w(prefix + ".post_attention_layernorm.weight")},
			PreFFNorm:    &mlx.RMSNormModule{Weight: w(prefix + ".pre_feedforward_layernorm.weight")},
			PostFFNorm:   &mlx.RMSNormModule{Weight: w(prefix + ".post_feedforward_layernorm.weight")},
			Attention: &Attention{
				QProj: linear(prefix + ".self_attn.q_proj"),
				KProj: linear(prefix + ".self_attn.k_proj"),
				VProj: linear(prefix + ".self_attn.v_proj"),
				OProj: linear(prefix + ".self_attn.o_proj"),
				QNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.q_norm.weight")},
				KNorm: &mlx.RMSNormModule{Weight: w(prefix + ".self_attn.k_norm.weight")},
			},
			MLP: &MLP{
				GateProj: linear(prefix + ".mlp.gate_proj"),
				UpProj:   linear(prefix + ".mlp.up_proj"),
				DownProj: linear(prefix + ".mlp.down_proj"),
			},
			LayerIdx:  i,
			IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
		}
	}

	// Output head — check for separate lm_head first, else tie to embeddings
	lmHeadWeight := w("lm_head.weight")
	if lmHeadWeight != nil {
		lmHeadScales := w("lm_head.scales")
		if lmHeadScales != nil && q != nil {
			m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
		} else {
			m.Output = mlx.NewLinear(lmHeadWeight, nil)
		}
	} else {
		// Tied embeddings — reuse embed_tokens weights (with quantization if present)
		m.Output = m.EmbedTokens.AsLinear()
	}

	// Materialize all weights
	var allArrays []*mlx.Array
	for _, a := range weights {
		allArrays = append(allArrays, a)
	}
	mlx.Materialize(allArrays...)

	// Precompute (1 + weight) for Gemma-style RMSNorm
	precomputeScaledWeights(m)

	return m, nil
}
|
||||
|
||||
func precomputeScaledWeights(m *GemmaModel) {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
}
|
||||
|
||||
var scaled []*mlx.Array
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
for _, layer := range m.Layers {
|
||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
mlx.Materialize(scaled...)
|
||||
}
|
||||
|
||||
// isLayerSliding reports whether layer layerIdx uses sliding-window
// attention. Every pattern-th layer (1-based) is a full-attention
// layer; all others slide. A non-positive pattern disables sliding.
func isLayerSliding(layerIdx, pattern int32) bool {
	if pattern <= 0 {
		return false
	}
	isFullAttention := (layerIdx+1)%pattern == 0
	return !isFullAttention
}
|
||||
|
||||
// Forward runs the text model forward pass.
//
// tokens holds token IDs with shape [B, L]; caches must have one entry
// per layer (see NewCache). Returns the output-head logits for every
// position.
func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	shape := tokens.Shape()
	B, L := shape[0], shape[1]

	h := m.EmbedTokens.Forward(tokens)
	// Gemma scales embeddings by sqrt(hidden_size) before the first layer.
	h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))

	for i, layer := range m.Layers {
		h = layer.forward(h, caches[i], B, L, m.Cfg)
	}

	// Final RMSNorm (with the precomputed 1+weight scale), then the
	// (possibly tied) output head.
	return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
}
|
||||
|
||||
// forward applies one Gemma 3 block. Unlike plain pre-norm residuals,
// Gemma normalizes both the input AND the output of each sub-block
// (attention and MLP) before the residual add.
func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
	// Attention sub-block: norm -> attend -> norm -> residual add.
	normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
	attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg)
	attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
	h := mlx.Add(x, attnOut)

	// Feed-forward sub-block: norm -> MLP -> norm -> residual add.
	normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
	mlpOut := l.MLP.forward(normed)
	mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
	return mlx.Add(h, mlpOut)
}
|
||||
|
||||
// forward computes one attention pass over L new tokens: project Q/K/V,
// view per-head, normalize Q/K, apply RoPE, update the KV cache, expand
// K/V heads for GQA, attend, and project back through o_proj.
func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reshape to [B, num_heads, L, head_dim].
	// NOTE(review): AsStrided performs the reshape and transpose in one
	// strided view over the [B, L, heads*head_dim] projection output —
	// verify the strides if the projection layout ever changes.
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)

	// Q/K normalization (Gemma 3 normalizes queries and keys before RoPE).
	q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
	k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)

	// RoPE with appropriate theta: sliding layers use the local base
	// frequency, full-attention layers the global rope_theta. The cache
	// offset positions the new tokens after everything already cached.
	ropeTheta := cfg.RopeTheta
	if isSliding {
		ropeTheta = cfg.RopeLocalBaseFreq
	}
	q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())

	// Update cache and attend over the keys/values it returns.
	k, v = c.Update(k, v, int(L))

	// GQA: repeat K/V heads so each query head has a matching K/V head.
	repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
	if repeatFactor > 1 {
		k = mlx.RepeatKV(k, repeatFactor)
		v = mlx.RepeatKV(v, repeatFactor)
	}

	// Scaled dot-product attention; the mask is only needed when
	// prefilling (L > 1) — single-token decode attends to all cached keys.
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
|
||||
|
||||
func (m *MLP) forward(x *mlx.Array) *mlx.Array {
|
||||
gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// NewCache creates per-layer caches for generation.
|
||||
func (m *GemmaModel) NewCache() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
if m.Layers[i].IsSliding {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// NumLayers returns the number of transformer layers.
func (m *GemmaModel) NumLayers() int { return len(m.Layers) }

// Tokenizer returns the model's tokenizer.
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }

// ModelType returns the architecture identifier.
// Always "gemma3", even when the checkpoint's config declared "gemma2"
// (LoadModel routes both through this implementation).
func (m *GemmaModel) ModelType() string { return "gemma3" }
|
||||
|
||||
// ApplyLoRA wraps target projection layers with LoRA adapters.
|
||||
func (m *GemmaModel) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
|
||||
adapter := &mlx.LoRAAdapter{
|
||||
Layers: make(map[string]*mlx.LoRALinear),
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
|
||||
for _, target := range cfg.TargetKeys {
|
||||
var proj *mlx.Linear
|
||||
switch target {
|
||||
case "q_proj":
|
||||
proj = layer.Attention.QProj
|
||||
case "k_proj":
|
||||
proj = layer.Attention.KProj
|
||||
case "v_proj":
|
||||
proj = layer.Attention.VProj
|
||||
case "o_proj":
|
||||
proj = layer.Attention.OProj
|
||||
}
|
||||
if proj != nil {
|
||||
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
proj.LoRA = lora
|
||||
adapter.Layers[prefix+"."+target] = lora
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return adapter
|
||||
}
|
||||
|
|
@ -1,78 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package model provides transformer model architectures for MLX inference.
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mlx"
|
||||
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// Model is the common interface for all transformer model architectures.
// LoadModel returns one of its concrete implementations based on the
// "model_type" field of config.json.
type Model interface {
	// Forward runs the model forward pass on token IDs with KV caches.
	Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array

	// NewCache creates per-layer KV caches for generation.
	NewCache() []cache.Cache

	// NumLayers returns the number of transformer layers.
	NumLayers() int

	// Tokenizer returns the model's tokenizer.
	Tokenizer() *tokenizer.Tokenizer

	// ModelType returns the architecture identifier (e.g. "gemma3", "qwen3").
	ModelType() string

	// ApplyLoRA wraps target projection layers with LoRA adapters for training.
	// Returns the adapter which holds references to all LoRA layers.
	ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter
}
|
||||
|
||||
// QuantizationConfig holds quantization parameters from config.json.
// Both values are passed straight through to mlx.NewQuantizedLinear;
// presumably GroupSize is the number of weights per scale/bias group
// and Bits the per-weight width — confirm against the mlx quantization docs.
type QuantizationConfig struct {
	GroupSize int `json:"group_size"`
	Bits      int `json:"bits"`
}
|
||||
|
||||
// resolveWeight looks up a weight with optional "language_model." prefix.
|
||||
func resolveWeight(weights map[string]*mlx.Array, name string) *mlx.Array {
|
||||
if w, ok := weights[name]; ok {
|
||||
return w
|
||||
}
|
||||
if w, ok := weights["language_model."+name]; ok {
|
||||
return w
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadModel auto-detects the model architecture from config.json and loads it.
|
||||
func LoadModel(modelPath string) (Model, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model: load config: %w", err)
|
||||
}
|
||||
|
||||
var probe struct {
|
||||
ModelType string `json:"model_type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, fmt.Errorf("model: parse model_type: %w", err)
|
||||
}
|
||||
|
||||
switch probe.ModelType {
|
||||
case "qwen3":
|
||||
return LoadQwen3(modelPath)
|
||||
case "gemma3", "gemma2":
|
||||
return LoadGemma3(modelPath)
|
||||
default:
|
||||
return nil, fmt.Errorf("model: unsupported architecture %q", probe.ModelType)
|
||||
}
|
||||
}
|
||||
|
|
@ -1,337 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mlx"
|
||||
"forge.lthn.ai/core/go-ai/mlx/cache"
|
||||
"forge.lthn.ai/core/go-ai/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// Qwen3Config holds Qwen 3 model configuration.
// Fields map onto config.json keys; parseQwen3Config fills defaults for
// any that are absent.
type Qwen3Config struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	HeadDim               int32   `json:"head_dim"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`

	Quantization *QuantizationConfig `json:"-"` // parsed from the top-level "quantization" key
	Scale        float32             `json:"-"` // 1/sqrt(head_dim)
}

// Qwen3Model is the Qwen 3 text model.
type Qwen3Model struct {
	EmbedTokens *mlx.Embedding
	Layers      []*Qwen3DecoderLayer
	Norm        *mlx.RMSNormModule
	Output      *mlx.Linear

	Tok *tokenizer.Tokenizer
	Cfg *Qwen3Config
}

// Qwen3DecoderLayer is a single transformer block.
// Qwen 3 uses standard pre-norm residual: norm→attn→add, norm→mlp→add.
type Qwen3DecoderLayer struct {
	InputNorm    *mlx.RMSNormModule // Pre-attention norm
	PostAttnNorm *mlx.RMSNormModule // Pre-MLP norm (confusingly named post_attention_layernorm)
	Attention    *Qwen3Attention
	MLP          *Qwen3MLP
}

// Qwen3Attention implements Qwen 3 GQA with Q/K RMS normalization.
type Qwen3Attention struct {
	QProj *mlx.Linear
	KProj *mlx.Linear
	VProj *mlx.Linear
	OProj *mlx.Linear
	QNorm *mlx.RMSNormModule
	KNorm *mlx.RMSNormModule
}

// Qwen3MLP is the SwiGLU feed-forward network: down(silu(gate(x)) * up(x)).
type Qwen3MLP struct {
	GateProj *mlx.Linear
	UpProj   *mlx.Linear
	DownProj *mlx.Linear
}
|
||||
|
||||
func parseQwen3Config(data []byte) (*Qwen3Config, error) {
|
||||
var cfg Qwen3Config
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Top-level quantization
|
||||
var wrapper struct {
|
||||
Quantization *QuantizationConfig `json:"quantization"`
|
||||
}
|
||||
json.Unmarshal(data, &wrapper)
|
||||
cfg.Quantization = wrapper.Quantization
|
||||
|
||||
// Compute scale
|
||||
if cfg.HeadDim == 0 {
|
||||
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
|
||||
}
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
|
||||
// Defaults
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.VocabSize == 0 {
|
||||
cfg.VocabSize = 151936
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// LoadQwen3 loads a Qwen 3 model from a safetensors directory.
//
// modelPath must contain config.json, tokenizer.json, and one or more
// *.safetensors weight shards. All weights are loaded, wired into the
// module tree, and materialised before the model is returned.
func LoadQwen3(modelPath string) (*Qwen3Model, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("qwen3: load config: %w", err)
	}

	cfg, err := parseQwen3Config(data)
	if err != nil {
		return nil, fmt.Errorf("qwen3: parse config: %w", err)
	}

	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("qwen3: load tokenizer: %w", err)
	}

	// Load weights from all safetensors files
	// NOTE(review): Glob's error only fires on a malformed pattern; the
	// pattern here is constant, so ignoring it is safe. An empty directory
	// silently yields no weights, however — resolveWeight lookups then
	// return nil for every name.
	weights := make(map[string]*mlx.Array)
	matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
	for _, path := range matches {
		for name, arr := range mlx.LoadSafetensors(path) {
			weights[name] = arr
		}
	}

	// w resolves a tensor by its checkpoint name (nil if absent).
	w := func(name string) *mlx.Array { return resolveWeight(weights, name) }

	// Quantization setup
	q := cfg.Quantization
	if q != nil {
		slog.Info("qwen3: using quantized inference", "bits", q.Bits, "group_size", q.GroupSize)
	}
	// linear builds a (possibly quantized) projection from the tensors found
	// under the given checkpoint prefix. Quantized weights carry .scales.
	linear := func(prefix string) *mlx.Linear {
		weight := w(prefix + ".weight")
		scales := w(prefix + ".scales")
		biases := w(prefix + ".biases")
		bias := w(prefix + ".bias")
		if scales != nil && q != nil {
			return mlx.NewQuantizedLinear(weight, scales, biases, bias, q.GroupSize, q.Bits)
		}
		return mlx.NewLinear(weight, bias)
	}

	// Embedding
	embed := &mlx.Embedding{Weight: w("model.embed_tokens.weight")}
	if embedScales := w("model.embed_tokens.scales"); embedScales != nil && q != nil {
		embed.Scales = embedScales
		embed.Biases = w("model.embed_tokens.biases")
		embed.GroupSize = q.GroupSize
		embed.Bits = q.Bits
	}

	m := &Qwen3Model{
		EmbedTokens: embed,
		Layers:      make([]*Qwen3DecoderLayer, cfg.NumHiddenLayers),
		Norm:        &mlx.RMSNormModule{Weight: w("model.norm.weight")},
		Tok:         tok,
		Cfg:         cfg,
	}

	// Wire up each decoder layer from its "model.layers.<i>" prefix.
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		p := fmt.Sprintf("model.layers.%d", i)
		m.Layers[i] = &Qwen3DecoderLayer{
			InputNorm:    &mlx.RMSNormModule{Weight: w(p + ".input_layernorm.weight")},
			PostAttnNorm: &mlx.RMSNormModule{Weight: w(p + ".post_attention_layernorm.weight")},
			Attention: &Qwen3Attention{
				QProj: linear(p + ".self_attn.q_proj"),
				KProj: linear(p + ".self_attn.k_proj"),
				VProj: linear(p + ".self_attn.v_proj"),
				OProj: linear(p + ".self_attn.o_proj"),
				QNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.q_norm.weight")},
				KNorm: &mlx.RMSNormModule{Weight: w(p + ".self_attn.k_norm.weight")},
			},
			MLP: &Qwen3MLP{
				GateProj: linear(p + ".mlp.gate_proj"),
				UpProj:   linear(p + ".mlp.up_proj"),
				DownProj: linear(p + ".mlp.down_proj"),
			},
		}
	}

	// Output head — Qwen 3 has tie_word_embeddings=false, so lm_head is separate
	lmHeadWeight := w("lm_head.weight")
	if lmHeadWeight != nil {
		lmHeadScales := w("lm_head.scales")
		if lmHeadScales != nil && q != nil {
			m.Output = mlx.NewQuantizedLinear(lmHeadWeight, lmHeadScales, w("lm_head.biases"), nil, q.GroupSize, q.Bits)
		} else {
			m.Output = mlx.NewLinear(lmHeadWeight, nil)
		}
	} else {
		// No lm_head in the checkpoint: fall back to tied embeddings.
		m.Output = m.EmbedTokens.AsLinear()
	}

	// Materialise all weights onto Metal
	var allArrays []*mlx.Array
	for _, a := range weights {
		allArrays = append(allArrays, a)
	}
	mlx.Materialize(allArrays...)

	slog.Info("qwen3: model loaded",
		"layers", cfg.NumHiddenLayers,
		"hidden", cfg.HiddenSize,
		"heads", cfg.NumAttentionHeads,
		"kv_heads", cfg.NumKeyValueHeads,
		"head_dim", cfg.HeadDim,
		"vocab", cfg.VocabSize,
	)

	return m, nil
}
|
||||
|
||||
// Forward runs the Qwen 3 forward pass.
// Unlike Gemma, Qwen does NOT scale embeddings by sqrt(hidden_size).
//
// tokens is a [B, L] array of token ids; caches holds one KV cache per
// layer (see NewCache). Returns the logits after the final norm and
// output head.
func (m *Qwen3Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	shape := tokens.Shape()
	B, L := shape[0], shape[1]

	h := m.EmbedTokens.Forward(tokens)

	// Each layer consumes its own cache entry.
	for i, layer := range m.Layers {
		h = layer.forward(h, caches[i], B, L, m.Cfg)
	}

	// Final RMS norm, then project to vocabulary logits.
	return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
}
|
||||
|
||||
func (l *Qwen3DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
|
||||
// Pre-attention norm → attention → residual add
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, cfg)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
// Pre-MLP norm → MLP → residual add
|
||||
normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed)
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
// forward computes grouped-query attention for one layer: project Q/K/V,
// apply Qwen 3's per-head Q/K RMS norm, apply RoPE at the cache offset,
// update the KV cache, and run scaled dot-product attention.
func (a *Qwen3Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Qwen3Config) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)

	// Reshape to [B, num_heads, L, head_dim]
	// The strides view the contiguous [B, L, heads*head_dim] projection as
	// [B, heads, L, head_dim] without copying (a strided transpose).
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)

	// Q/K RMS normalization (Qwen 3 has this)
	q = a.QNorm.Forward(q, cfg.RMSNormEps)
	k = a.KNorm.Forward(k, cfg.RMSNormEps)

	// RoPE — single theta for all layers (no sliding window)
	// c.Offset() positions the rotation at the current sequence offset.
	q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())

	// Update KV cache
	k, v = c.Update(k, v, int(L))

	// GQA: repeat K/V heads to match Q heads
	repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
	if repeatFactor > 1 {
		k = mlx.RepeatKV(k, repeatFactor)
		v = mlx.RepeatKV(v, repeatFactor)
	}

	// Scaled dot-product attention
	// Causal masking only when prefilling (L > 1); single-token decode
	// attends to everything already in the cache.
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
|
||||
|
||||
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
|
||||
func (m *Qwen3MLP) forward(x *mlx.Array) *mlx.Array {
|
||||
gate := mlx.SiLU(m.GateProj.Forward(x))
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.
|
||||
func (m *Qwen3Model) NewCache() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// NumLayers returns the number of transformer layers.
func (m *Qwen3Model) NumLayers() int { return len(m.Layers) }

// Tokenizer returns the model's tokenizer.
func (m *Qwen3Model) Tokenizer() *tokenizer.Tokenizer { return m.Tok }

// ModelType returns the architecture identifier ("qwen3").
func (m *Qwen3Model) ModelType() string { return "qwen3" }
|
||||
|
||||
// ApplyLoRA wraps target projection layers with LoRA adapters.
|
||||
func (m *Qwen3Model) ApplyLoRA(cfg mlx.LoRAConfig) *mlx.LoRAAdapter {
|
||||
adapter := &mlx.LoRAAdapter{
|
||||
Layers: make(map[string]*mlx.LoRALinear),
|
||||
Config: cfg,
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
prefix := fmt.Sprintf("model.layers.%d.self_attn", i)
|
||||
for _, target := range cfg.TargetKeys {
|
||||
var proj *mlx.Linear
|
||||
switch target {
|
||||
case "q_proj":
|
||||
proj = layer.Attention.QProj
|
||||
case "k_proj":
|
||||
proj = layer.Attention.KProj
|
||||
case "v_proj":
|
||||
proj = layer.Attention.VProj
|
||||
case "o_proj":
|
||||
proj = layer.Attention.OProj
|
||||
}
|
||||
if proj != nil {
|
||||
lora := mlx.NewLoRALinear(proj, cfg.Rank, cfg.Alpha)
|
||||
proj.LoRA = lora
|
||||
adapter.Layers[prefix+"."+target] = lora
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return adapter
|
||||
}
|
||||
115
mlx/nn.go
115
mlx/nn.go
|
|
@ -1,115 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
// Linear is a fully-connected layer: y = x @ W.T + bias.
// For quantized models, set Scales/Biases/GroupSize/Bits to use QuantizedMatmul.
// Set LoRA to inject a low-rank adapter (training only).
type Linear struct {
	Weight *Array `weight:"weight"` // [out, in] weight matrix (packed when quantized)
	Scales *Array `weight:"scales"` // per-group quantization scales (nil when dense)
	Biases *Array `weight:"biases"` // per-group quantization biases (nil when dense)
	Bias   *Array `weight:"bias"`   // optional additive bias (nil when absent)
	GroupSize int // quantization group size; meaningful only when Scales != nil
	Bits      int // quantization bit width; meaningful only when Scales != nil

	LoRA *LoRALinear // Optional LoRA adapter — if set, Forward routes through it
}
|
||||
|
||||
// NewLinear creates a dense Linear layer with optional bias.
|
||||
func NewLinear(weight, bias *Array) *Linear {
|
||||
return &Linear{Weight: weight, Bias: bias}
|
||||
}
|
||||
|
||||
// NewQuantizedLinear creates a quantized Linear layer.
|
||||
func NewQuantizedLinear(weight, scales, biases, bias *Array, groupSize, bits int) *Linear {
|
||||
return &Linear{
|
||||
Weight: weight,
|
||||
Scales: scales,
|
||||
Biases: biases,
|
||||
Bias: bias,
|
||||
GroupSize: groupSize,
|
||||
Bits: bits,
|
||||
}
|
||||
}
|
||||
|
||||
// Forward computes the linear transformation.
|
||||
// If a LoRA adapter is attached, routes through it instead (base + low-rank delta).
|
||||
// Uses QuantizedMatmul when quantization parameters are present.
|
||||
func (l *Linear) Forward(x *Array) *Array {
|
||||
if l.LoRA != nil {
|
||||
return l.LoRA.Forward(x)
|
||||
}
|
||||
return l.baseForward(x)
|
||||
}
|
||||
|
||||
// baseForward is the raw linear transformation without LoRA.
|
||||
// Used internally by LoRALinear to avoid infinite recursion.
|
||||
func (l *Linear) baseForward(x *Array) *Array {
|
||||
var out *Array
|
||||
if l.Scales != nil {
|
||||
out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
|
||||
} else {
|
||||
out = Matmul(x, Transpose(l.Weight))
|
||||
}
|
||||
if l.Bias != nil && l.Bias.Valid() {
|
||||
out = Add(out, l.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Embedding is a lookup table for token embeddings.
// For quantized models, set Scales/Biases/GroupSize/Bits to dequantize before lookup.
type Embedding struct {
	Weight *Array `weight:"weight"` // [vocab, dim] embedding table (packed when quantized)
	Scales *Array `weight:"scales"` // per-group quantization scales (nil when dense)
	Biases *Array `weight:"biases"` // per-group quantization biases (nil when dense)
	GroupSize int // quantization group size; meaningful only when Scales != nil
	Bits      int // quantization bit width; meaningful only when Scales != nil
}
|
||||
|
||||
// Forward looks up embeddings for the given token indices.
|
||||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
if e.Scales != nil {
|
||||
w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
|
||||
return Take(w, indices, 0)
|
||||
}
|
||||
return Take(e.Weight, indices, 0)
|
||||
}
|
||||
|
||||
// AsLinear returns a Linear layer using the embedding weights (for tied output).
|
||||
func (e *Embedding) AsLinear() *Linear {
|
||||
return &Linear{
|
||||
Weight: e.Weight,
|
||||
Scales: e.Scales,
|
||||
Biases: e.Biases,
|
||||
GroupSize: e.GroupSize,
|
||||
Bits: e.Bits,
|
||||
}
|
||||
}
|
||||
|
||||
// RMSNormModule is an RMS normalization layer wrapping the fused kernel.
type RMSNormModule struct {
	Weight *Array `weight:"weight"` // learned per-channel scale
}

// Forward applies RMS normalization with the given epsilon via the
// fused RMSNorm kernel.
func (r *RMSNormModule) Forward(x *Array, eps float32) *Array {
	return RMSNorm(x, r.Weight, eps)
}
|
||||
|
||||
// RepeatKV repeats key/value heads for grouped-query attention.
|
||||
// Input shape: [B, num_kv_heads, L, D]
|
||||
// Output shape: [B, num_kv_heads * factor, L, D]
|
||||
func RepeatKV(x *Array, factor int32) *Array {
|
||||
if factor <= 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
B, H, L, D := shape[0], shape[1], shape[2], shape[3]
|
||||
|
||||
// Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D]
|
||||
expanded := ExpandDims(x, 2)
|
||||
expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D})
|
||||
return Reshape(expanded, B, H*factor, L, D)
|
||||
}
|
||||
353
mlx/ops.go
353
mlx/ops.go
|
|
@ -1,353 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// --- Element-wise arithmetic ---
//
// Each wrapper allocates a result node via New(...) and invokes the
// corresponding mlx_* C kernel on the default stream; results are lazy
// until materialised.

// Add returns element-wise a + b.
func Add(a, b *Array) *Array {
	out := New("ADD", a, b)
	C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// AddScalar returns a + scalar (broadcast).
// The scalar is lifted to an Array via FromValue and broadcast by Add.
func AddScalar(a *Array, s float32) *Array {
	scalar := FromValue(s)
	return Add(a, scalar)
}

// Mul returns element-wise a * b.
func Mul(a, b *Array) *Array {
	out := New("MUL", a, b)
	C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// MulScalar returns a * scalar (broadcast).
func MulScalar(a *Array, s float32) *Array {
	scalar := FromValue(s)
	return Mul(a, scalar)
}

// Divide returns element-wise a / b.
func Divide(a, b *Array) *Array {
	out := New("DIV", a, b)
	C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// Subtract returns element-wise a - b.
func Subtract(a, b *Array) *Array {
	out := New("SUB", a, b)
	C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// Negative returns element-wise -a.
func Negative(a *Array) *Array {
	out := New("NEG", a)
	C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// --- Math functions ---

// Exp returns element-wise exp(a).
func Exp(a *Array) *Array {
	out := New("EXP", a)
	C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// Sigmoid returns element-wise 1/(1+exp(-a)).
func Sigmoid(a *Array) *Array {
	out := New("SIGMOID", a)
	C.mlx_sigmoid(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// SiLU returns element-wise x * sigmoid(x) (Swish activation).
// Composed from Mul and Sigmoid rather than a fused kernel.
func SiLU(a *Array) *Array {
	return Mul(a, Sigmoid(a))
}

// Tanh returns element-wise tanh(a).
func Tanh(a *Array) *Array {
	out := New("TANH", a)
	C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// Sqrt returns element-wise sqrt(a).
func Sqrt(a *Array) *Array {
	out := New("SQRT", a)
	C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// Rsqrt returns element-wise 1/sqrt(a).
func Rsqrt(a *Array) *Array {
	out := New("RSQRT", a)
	C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// Reciprocal returns element-wise 1/a.
func Reciprocal(a *Array) *Array {
	out := New("RECIPROCAL", a)
	C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// Square returns element-wise a^2.
func Square(a *Array) *Array {
	out := New("SQUARE", a)
	C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx)
	return out
}

// Power returns element-wise a^b.
func Power(a, b *Array) *Array {
	out := New("POWER", a, b)
	C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// Maximum returns element-wise max(a, b).
func Maximum(a, b *Array) *Array {
	out := New("MAX", a, b)
	C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// Minimum returns element-wise min(a, b).
func Minimum(a, b *Array) *Array {
	out := New("MIN", a, b)
	C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}
|
||||
|
||||
// --- Matrix operations ---

// Matmul returns the matrix product of a and b.
func Matmul(a, b *Array) *Array {
	out := New("MATMUL", a, b)
	C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// QuantizedMatmul performs quantized matrix multiplication in "affine" mode.
// w is the packed quantized weight with its per-group scales/biases;
// transpose selects x @ w.T when true.
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
	out := New("QMATMUL", x, w, scales, biases)
	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
	mode := C.CString("affine")
	defer C.free(unsafe.Pointer(mode))
	C.mlx_quantized_matmul(
		&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
		C._Bool(transpose), gs, b, mode,
		DefaultStream().ctx,
	)
	return out
}

// --- Reductions ---

// Softmax returns softmax along the last axis.
func Softmax(a *Array) *Array {
	out := New("SOFTMAX", a)
	axis := []C.int{C.int(-1)}
	C.mlx_softmax_axes(&out.ctx, a.ctx, &axis[0], C.size_t(1), C._Bool(false), DefaultStream().ctx)
	return out
}

// Argmax returns the index of the maximum value along an axis.
func Argmax(a *Array, axis int, keepDims bool) *Array {
	out := New("ARGMAX", a)
	C.mlx_argmax_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
	return out
}

// TopK returns the top k values along the last axis.
func TopK(a *Array, k int) *Array {
	out := New("TOPK", a)
	C.mlx_topk_axis(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
	return out
}

// Sum reduces by summation along the given axis.
func Sum(a *Array, axis int, keepDims bool) *Array {
	out := New("SUM", a)
	axes := []C.int{C.int(axis)}
	C.mlx_sum_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
	return out
}

// Mean reduces by averaging along the given axis.
func Mean(a *Array, axis int, keepDims bool) *Array {
	out := New("MEAN", a)
	axes := []C.int{C.int(axis)}
	C.mlx_mean_axes(&out.ctx, a.ctx, &axes[0], C.size_t(1), C._Bool(keepDims), DefaultStream().ctx)
	return out
}
|
||||
|
||||
// --- Shape operations ---

// Reshape changes the shape of an array.
// NOTE(review): calling with no shape arguments panics on the empty-slice
// index below; callers must always pass at least one dimension.
func Reshape(a *Array, shape ...int32) *Array {
	out := New("RESHAPE", a)
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
	return out
}

// Transpose permutes dimensions. If no axes given, reverses all dims.
func Transpose(a *Array, axes ...int) *Array {
	out := New("TRANSPOSE", a)
	if len(axes) == 0 {
		C.mlx_transpose(&out.ctx, a.ctx, DefaultStream().ctx)
	} else {
		cAxes := make([]C.int, len(axes))
		for i, ax := range axes {
			cAxes[i] = C.int(ax)
		}
		C.mlx_transpose_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
	}
	return out
}

// ExpandDims inserts a new axis of size 1 at the given position.
func ExpandDims(a *Array, axis int) *Array {
	out := New("EXPAND_DIMS", a)
	C.mlx_expand_dims(&out.ctx, a.ctx, C.int(axis), DefaultStream().ctx)
	return out
}
|
||||
|
||||
// Squeeze removes dimensions of size 1.
|
||||
func Squeeze(a *Array, axes ...int) *Array {
|
||||
out := New("SQUEEZE", a)
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i, ax := range axes {
|
||||
cAxes[i] = C.int(ax)
|
||||
}
|
||||
C.mlx_squeeze_axes(&out.ctx, a.ctx, &cAxes[0], C.size_t(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Concatenate joins arrays along the given axis.
func Concatenate(arrays []*Array, axis int) *Array {
	// Build the C vector of input handles; freed once the op is recorded.
	vector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vector)

	// Keep Go-side references so the inputs stay live in the op graph.
	inputs := make([]*Array, len(arrays))
	for i, a := range arrays {
		C.mlx_vector_array_append_value(vector, a.ctx)
		inputs[i] = a
	}

	out := New("CONCAT", inputs...)
	C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
	return out
}

// BroadcastTo broadcasts an array to the given shape.
// NOTE(review): an empty shape panics on the empty-slice index below.
func BroadcastTo(a *Array, shape []int32) *Array {
	out := New("BROADCAST", a)
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), DefaultStream().ctx)
	return out
}

// AsType casts an array to a different dtype.
func AsType(a *Array, dtype DType) *Array {
	out := New("ASTYPE", a)
	C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
	return out
}
|
||||
|
||||
// AsStrided creates a view with custom strides.
// shape and strides must have the same length; offset is the element
// offset into a's buffer where the view starts.
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
	out := New("AS_STRIDED", a)
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	cStrides := make([]C.int64_t, len(strides))
	for i, s := range strides {
		cStrides[i] = C.int64_t(s)
	}
	C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.size_t(len(cShape)), &cStrides[0], C.size_t(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
	return out
}

// Take gathers elements from a along axis using indices.
func Take(a, indices *Array, axis int) *Array {
	out := New("TAKE", a, indices)
	C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// Where selects elements from a or b based on condition.
func Where(condition, a, b *Array) *Array {
	out := New("WHERE", condition, a, b)
	C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return out
}

// Argpartition partially sorts and returns indices for top-k selection.
func Argpartition(a *Array, kth, axis int) *Array {
	out := New("ARGPARTITION", a)
	C.mlx_argpartition_axis(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
	return out
}

// Dequantize restores a quantized array to full precision ("affine" mode).
// The output dtype is left to the kernel's default (no explicit dtype).
func Dequantize(w, scales, biases *Array, groupSize, bits int) *Array {
	out := New("DEQUANTIZE", w, scales, biases)
	gs := C.mlx_optional_int{value: C.int(groupSize), has_value: C._Bool(true)}
	b := C.mlx_optional_int{value: C.int(bits), has_value: C._Bool(true)}
	mode := C.CString("affine")
	defer C.free(unsafe.Pointer(mode))
	noDtype := C.mlx_optional_dtype{has_value: C._Bool(false)}
	C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, biases.ctx, gs, b, mode, noDtype, DefaultStream().ctx)
	return out
}

// PutAlongAxis places values into array at indices along axis.
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
	out := New("PUT_ALONG_AXIS", a, indices, values)
	// Use scatter approach: src[indices] = values
	C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// TakeAlongAxis gathers elements from a along axis using indices.
// Unlike Take, this uses the same number of dimensions for indices and input.
func TakeAlongAxis(a, indices *Array, axis int) *Array {
	out := New("TAKE_ALONG_AXIS", a, indices)
	C.mlx_take_along_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return out
}

// LogSumExp computes log(sum(exp(a))) along the given axis.
// Numerically stable reduction for cross-entropy loss.
func LogSumExp(a *Array, axis int, keepDims bool) *Array {
	out := New("LOGSUMEXP", a)
	C.mlx_logsumexp_axis(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
	return out
}
|
||||
106
mlx/optim.go
106
mlx/optim.go
|
|
@ -1,106 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
import "math"
|
||||
|
||||
// AdamW implements the AdamW optimiser (Adam with decoupled weight decay).
//
// Update rule per parameter:
//
//	m = beta1 * m + (1 - beta1) * grad
//	v = beta2 * v + (1 - beta2) * grad^2
//	m_hat = m / (1 - beta1^t)
//	v_hat = v / (1 - beta2^t)
//	param = param * (1 - lr * weight_decay) - lr * m_hat / (sqrt(v_hat) + eps)
type AdamW struct {
	LR          float64 // Learning rate (default 1e-5)
	Beta1       float64 // First moment decay (default 0.9)
	Beta2       float64 // Second moment decay (default 0.999)
	Eps         float64 // Numerical stability (default 1e-8)
	WeightDecay float64 // Decoupled weight decay (default 0.01)

	// State is positional: callers must pass params to Step in a
	// consistent order across calls.
	step int      // Number of updates performed
	m    []*Array // First moment estimates (positional, parallel to params)
	v    []*Array // Second moment estimates (positional, parallel to params)
}
|
||||
|
||||
// NewAdamW creates an AdamW optimiser with default hyperparameters.
|
||||
func NewAdamW(lr float64) *AdamW {
|
||||
return &AdamW{
|
||||
LR: lr,
|
||||
Beta1: 0.9,
|
||||
Beta2: 0.999,
|
||||
Eps: 1e-8,
|
||||
WeightDecay: 0.01,
|
||||
}
|
||||
}
|
||||
|
||||
// Step performs one optimisation step: updates params using gradients.
|
||||
// params and grads must be parallel slices of the same length.
|
||||
// Returns the updated parameter arrays (params are replaced in-place).
|
||||
func (o *AdamW) Step(params []*Array, grads []*Array) []*Array {
|
||||
o.step++
|
||||
|
||||
// Bias correction factors
|
||||
bc1 := 1.0 - math.Pow(o.Beta1, float64(o.step))
|
||||
bc2 := 1.0 - math.Pow(o.Beta2, float64(o.step))
|
||||
|
||||
updated := make([]*Array, len(params))
|
||||
|
||||
// Grow moment slices if needed (first call or param count increased)
|
||||
for len(o.m) < len(params) {
|
||||
o.m = append(o.m, nil)
|
||||
o.v = append(o.v, nil)
|
||||
}
|
||||
|
||||
for i, param := range params {
|
||||
grad := grads[i]
|
||||
|
||||
// Initialise moments on first use
|
||||
if o.m[i] == nil {
|
||||
shape := param.Shape()
|
||||
o.m[i] = Zeros(shape, param.Dtype())
|
||||
o.v[i] = Zeros(shape, param.Dtype())
|
||||
}
|
||||
|
||||
// m = beta1 * m + (1 - beta1) * grad
|
||||
m := Add(
|
||||
MulScalar(o.m[i], float32(o.Beta1)),
|
||||
MulScalar(grad, float32(1.0-o.Beta1)),
|
||||
)
|
||||
|
||||
// v = beta2 * v + (1 - beta2) * grad^2
|
||||
v := Add(
|
||||
MulScalar(o.v[i], float32(o.Beta2)),
|
||||
MulScalar(Square(grad), float32(1.0-o.Beta2)),
|
||||
)
|
||||
|
||||
// Bias-corrected estimates
|
||||
mHat := MulScalar(m, float32(1.0/bc1))
|
||||
vHat := MulScalar(v, float32(1.0/bc2))
|
||||
|
||||
// Weight decay: param = param * (1 - lr * weight_decay)
|
||||
decayed := MulScalar(param, float32(1.0-o.LR*o.WeightDecay))
|
||||
|
||||
// Update: param = decayed - lr * m_hat / (sqrt(v_hat) + eps)
|
||||
denom := AddScalar(Sqrt(vHat), float32(o.Eps))
|
||||
step := MulScalar(Divide(mHat, denom), float32(o.LR))
|
||||
newParam := Subtract(decayed, step)
|
||||
|
||||
// Store updated moments
|
||||
o.m[i] = m
|
||||
o.v[i] = v
|
||||
|
||||
updated[i] = newParam
|
||||
}
|
||||
|
||||
return updated
|
||||
}
|
||||
|
||||
// Reset clears the optimiser state (moments and step counter).
|
||||
func (o *AdamW) Reset() {
|
||||
o.step = 0
|
||||
o.m = nil
|
||||
o.v = nil
|
||||
}
|
||||
|
|
@ -1,175 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestAdamW_BasicStep checks that AdamW drives a scalar toward the
// minimum of f(x) = x^2 over 300 steps.
func TestAdamW_BasicStep(t *testing.T) {
	// Simple test: minimise f(x) = x^2, starting at x=10
	x := FromValue(float32(10.0))
	Materialize(x)

	opt := NewAdamW(0.1)

	for i := 0; i < 300; i++ {
		// Gradient of x^2 is 2x
		lossFn := func(inputs []*Array) []*Array {
			p := inputs[0]
			return []*Array{Mul(p, p)}
		}

		grad := ValueAndGrad(lossFn)
		_, grads, err := grad.Apply(x)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d: grad failed: %v", i, err)
		}

		updated := opt.Step([]*Array{x}, grads)
		x = updated[0]
		Materialize(x)
	}

	final := x.Float()
	if math.Abs(final) > 0.5 {
		t.Errorf("after 300 steps, x = %f, want near 0", final)
	}
	t.Logf("final x = %f (started at 10.0)", final)
}

// TestAdamW_MultiParam checks simultaneous updates of two parameters.
func TestAdamW_MultiParam(t *testing.T) {
	// Minimise f(x, y) = x^2 + y^2
	x := FromValue(float32(5.0))
	y := FromValue(float32(-3.0))
	Materialize(x, y)

	opt := NewAdamW(0.1)

	for i := 0; i < 100; i++ {
		lossFn := func(inputs []*Array) []*Array {
			return []*Array{Add(Mul(inputs[0], inputs[0]), Mul(inputs[1], inputs[1]))}
		}

		grad := ValueAndGrad(lossFn, 0, 1)
		_, grads, err := grad.Apply(x, y)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", i, err)
		}

		updated := opt.Step([]*Array{x, y}, grads)
		x = updated[0]
		y = updated[1]
		Materialize(x, y)
	}

	xFinal := x.Float()
	yFinal := y.Float()
	if math.Abs(xFinal) > 0.1 || math.Abs(yFinal) > 0.1 {
		t.Errorf("x=%f, y=%f, want both near 0", xFinal, yFinal)
	}
	t.Logf("final x=%f, y=%f", xFinal, yFinal)
}

// TestAdamW_WeightDecay checks that decoupled weight decay shrinks a
// parameter even when its gradient is zero.
func TestAdamW_WeightDecay(t *testing.T) {
	// With large weight decay and zero gradient, param should decay toward 0
	x := FromValue(float32(10.0))
	Materialize(x)

	opt := NewAdamW(0.01)
	opt.WeightDecay = 0.5 // aggressive decay

	zeroGrad := FromValue(float32(0.0))
	Materialize(zeroGrad)

	for i := 0; i < 10; i++ {
		updated := opt.Step([]*Array{x}, []*Array{zeroGrad})
		x = updated[0]
		Materialize(x)
	}

	final := x.Float()
	if final >= 10.0 {
		t.Errorf("x = %f, should have decayed from 10.0", final)
	}
	if final <= 0 {
		t.Errorf("x = %f, decayed too much", final)
	}
	t.Logf("after 10 steps with weight_decay=0.5: x = %f (started at 10.0)", final)
}

// TestAdamW_Reset checks that Reset clears the step counter and moments.
func TestAdamW_Reset(t *testing.T) {
	opt := NewAdamW(0.01)

	x := FromValue(float32(5.0))
	grad := FromValue(float32(1.0))
	Materialize(x, grad)

	opt.Step([]*Array{x}, []*Array{grad})
	if opt.step != 1 {
		t.Errorf("step = %d, want 1", opt.step)
	}

	opt.Reset()
	if opt.step != 0 {
		t.Errorf("after reset, step = %d, want 0", opt.step)
	}
	if opt.m != nil {
		t.Error("after reset, moments should be nil")
	}
}

// TestAdamW_WithLoRA trains a LoRA adapter end-to-end with AdamW and
// checks that the MSE loss decreases over 50 steps.
func TestAdamW_WithLoRA(t *testing.T) {
	// End-to-end: create LoRA layer, compute gradients, update with AdamW
	w := RandomNormal(0, 0.1, []int32{4, 8}, DTypeFloat32)
	Materialize(w)
	base := NewLinear(w, nil)

	lora := NewLoRALinear(base, 4, 8.0)
	opt := NewAdamW(0.001)

	x := RandomNormal(0, 1, []int32{1, 2, 8}, DTypeFloat32)
	target := RandomNormal(0, 1, []int32{1, 2, 4}, DTypeFloat32)
	Materialize(x, target)

	var initialLoss, finalLoss float64

	for step := 0; step < 50; step++ {
		lossFn := func(inputs []*Array) []*Array {
			lora.A = inputs[0]
			lora.B = inputs[1]
			pred := lora.Forward(x)
			return []*Array{MSELoss(pred, target)}
		}

		grad := ValueAndGrad(lossFn, 0, 1)
		values, grads, err := grad.Apply(lora.A, lora.B)
		grad.Free()
		if err != nil {
			t.Fatalf("step %d failed: %v", step, err)
		}

		Materialize(append(values, grads...)...)

		loss := values[0].Float()
		if step == 0 {
			initialLoss = loss
		}
		if step == 49 {
			finalLoss = loss
		}

		updated := opt.Step([]*Array{lora.A, lora.B}, grads)
		lora.A = updated[0]
		lora.B = updated[1]
		Materialize(lora.A, lora.B)
	}

	t.Logf("loss: %.6f -> %.6f", initialLoss, finalLoss)
	if finalLoss >= initialLoss {
		t.Errorf("loss did not decrease: %f -> %f", initialLoss, finalLoss)
	}
}
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// RandomCategorical samples from a categorical distribution defined by logprobs.
// Returns indices sampled according to the log-probability distribution along the last axis.
func RandomCategorical(logprobs *Array) *Array {
	out := New("RANDOM_CATEGORICAL", logprobs)
	// A freshly created (empty) key is passed so mlx-c falls back to its
	// default global RNG; it is freed as soon as the call returns.
	key := C.mlx_array_new()
	defer C.mlx_array_free(key)
	C.mlx_random_categorical(
		&out.ctx,
		logprobs.ctx,
		C.int(-1), // axis
		key,       // null key = use default RNG
		DefaultStream().ctx,
	)
	return out
}
|
||||
|
||||
// RandomUniform generates uniform random values in [low, high).
//
// shape gives the output dimensions and dtype the element type.
// NOTE(review): an empty shape makes &cShape[0] panic — callers appear to
// always pass a non-empty shape; confirm before generalizing to scalars.
func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
	out := New("RANDOM_UNIFORM")
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	// The bounds are passed to mlx-c as scalar arrays.
	lo := FromValue(low)
	hi := FromValue(high)
	// Empty key = use the default global RNG.
	key := C.mlx_array_new()
	defer C.mlx_array_free(key)
	C.mlx_random_uniform(
		&out.ctx,
		lo.ctx, hi.ctx,
		&cShape[0], C.size_t(len(cShape)),
		C.mlx_dtype(dtype),
		key,
		DefaultStream().ctx,
	)
	return out
}
|
||||
|
|
@ -1,90 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package sample provides composable token sampling strategies.
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"forge.lthn.ai/core/go-ai/mlx"
|
||||
)
|
||||
|
||||
// Sampler transforms logits into a sampled token index.
|
||||
type Sampler interface {
|
||||
Sample(logits *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// New creates a composable sampler chain from the given parameters.
|
||||
// Order: TopP -> MinP -> TopK -> Temperature -> categorical sample.
|
||||
func New(temp, topP, minP float32, topK int) Sampler {
|
||||
if temp == 0 {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
var samplers []Sampler
|
||||
if topP > 0 && topP < 1 {
|
||||
samplers = append(samplers, TopP(topP))
|
||||
}
|
||||
if minP > 0 {
|
||||
samplers = append(samplers, MinPSampler(minP))
|
||||
}
|
||||
if topK > 0 {
|
||||
samplers = append(samplers, TopKSampler(topK))
|
||||
}
|
||||
samplers = append(samplers, Temperature(temp))
|
||||
return chain(samplers)
|
||||
}
|
||||
|
||||
// chain applies a sequence of samplers, then samples from the result.
|
||||
type chain []Sampler
|
||||
|
||||
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
|
||||
for _, s := range c {
|
||||
logits = s.Sample(logits)
|
||||
}
|
||||
// Final categorical sample from log-probabilities
|
||||
return mlx.RandomCategorical(logits)
|
||||
}
|
||||
|
||||
// greedy returns the argmax token.
type greedy struct{}

// Sample deterministically picks the index of the largest logit along the
// last axis; used when temperature == 0.
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
	return mlx.Argmax(logits, -1, false)
}
|
||||
|
||||
// Temperature scales logits by 1/temp.
|
||||
type Temperature float32
|
||||
|
||||
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
|
||||
return mlx.MulScalar(logits, 1.0/float32(t))
|
||||
}
|
||||
|
||||
// TopKSampler masks all but the top-k logits.
type TopKSampler int

// Sample sets every logit outside the top k to -Inf so the final categorical
// draw can only select one of the k most likely tokens.
func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
	// Argpartition over the negated logits groups the k largest logits'
	// indices at the front — presumably mlx places elements <= the kth
	// pivot first; verify against mlx.Argpartition semantics.
	neg := mlx.Negative(logits)
	mask := mlx.Argpartition(neg, int(k)-1, -1)
	// Slice the indices beyond top-k
	mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
	// Scatter -Inf into every non-top-k position.
	return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
|
||||
|
||||
// TopP implements nucleus sampling (cumulative probability threshold).
type TopP float32

// Sample is currently a deliberate pass-through: no nucleus filtering is
// applied yet.
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
	// TODO: full nucleus sampling requires cumsum which mlx-c doesn't expose directly.
	// For now, pass through. TopK + Temperature covers most use cases.
	return logits
}
|
||||
|
||||
// MinPSampler masks tokens below min_p * max_prob.
type MinPSampler float32

// Sample is currently a deliberate pass-through: no min-p filtering is
// applied yet.
func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array {
	// For now, pass through — MinP is an optimization over TopP.
	// Full implementation requires finding max prob and masking below threshold.
	return logits
}
|
||||
63
mlx/slice.go
63
mlx/slice.go
|
|
@ -1,63 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// Slice extracts a sub-array using start and end indices for each dimension.
|
||||
// starts and ends must have the same length as the array's dimensions.
|
||||
func Slice(a *Array, starts, ends []int32) *Array {
|
||||
out := New("SLICE", a)
|
||||
cStarts := make([]C.int, len(starts))
|
||||
cEnds := make([]C.int, len(ends))
|
||||
for i := range starts {
|
||||
cStarts[i] = C.int(starts[i])
|
||||
cEnds[i] = C.int(ends[i])
|
||||
}
|
||||
strides := make([]C.int, len(starts))
|
||||
for i := range strides {
|
||||
strides[i] = 1
|
||||
}
|
||||
C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// SliceAxis extracts a sub-array along a single axis.
|
||||
func SliceAxis(a *Array, axis int, start, end int32) *Array {
|
||||
// Build full slice parameters
|
||||
ndim := a.NumDims()
|
||||
starts := make([]int32, ndim)
|
||||
ends := make([]int32, ndim)
|
||||
for i := 0; i < ndim; i++ {
|
||||
starts[i] = 0
|
||||
ends[i] = int32(a.Dim(i))
|
||||
}
|
||||
ax := axis
|
||||
if ax < 0 {
|
||||
ax = ndim + ax
|
||||
}
|
||||
starts[ax] = start
|
||||
ends[ax] = end
|
||||
return Slice(a, starts, ends)
|
||||
}
|
||||
|
||||
// SliceUpdateInplace updates a slice of the array in-place.
|
||||
// This is critical for KV cache updates.
|
||||
func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
|
||||
out := New("SLICE_UPDATE", a, update)
|
||||
cStarts := make([]C.int, len(starts))
|
||||
cEnds := make([]C.int, len(ends))
|
||||
for i := range starts {
|
||||
cStarts[i] = C.int(starts[i])
|
||||
cEnds[i] = C.int(ends[i])
|
||||
}
|
||||
strides := make([]C.int, len(starts))
|
||||
for i := range strides {
|
||||
strides[i] = 1
|
||||
}
|
||||
C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.size_t(len(cStarts)), &cEnds[0], C.size_t(len(cEnds)), &strides[0], C.size_t(len(strides)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
|
@ -1,79 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "sync"
|
||||
|
||||
// Stream wraps an mlx_stream handle for dispatching operations.
type Stream struct {
	ctx C.mlx_stream // underlying mlx-c stream handle
}

// Process-wide default stream, created lazily by DefaultStream.
var (
	defaultStream *Stream
	defaultStreamOnce sync.Once
)

// DefaultStream returns the default GPU stream, creating it on first use.
// Safe for concurrent callers: creation is guarded by sync.Once.
func DefaultStream() *Stream {
	defaultStreamOnce.Do(func() {
		Init()
		defaultStream = &Stream{ctx: C.mlx_default_gpu_stream_new()}
	})
	return defaultStream
}
|
||||
|
||||
// DefaultGPUStream returns a new GPU stream.
// Init is called first so the mlx runtime is ready before stream creation.
func DefaultGPUStream() *Stream {
	Init()
	return &Stream{ctx: C.mlx_default_gpu_stream_new()}
}

// DefaultCPUStream returns a new CPU stream.
func DefaultCPUStream() *Stream {
	Init()
	return &Stream{ctx: C.mlx_default_cpu_stream_new()}
}

// Synchronize waits for all operations on the stream to complete.
func Synchronize(s *Stream) {
	C.mlx_synchronize(s.ctx)
}
|
||||
|
||||
// SetMemoryLimit sets the Metal memory limit. Returns the previous limit.
// limit is in bytes; the old limit comes back through an out-parameter.
func SetMemoryLimit(limit uint64) uint64 {
	var prev C.size_t
	C.mlx_set_memory_limit(&prev, C.size_t(limit))
	return uint64(prev)
}

// SetCacheLimit sets the Metal cache limit. Returns the previous limit.
// limit is in bytes; the old limit comes back through an out-parameter.
func SetCacheLimit(limit uint64) uint64 {
	var prev C.size_t
	C.mlx_set_cache_limit(&prev, C.size_t(limit))
	return uint64(prev)
}
|
||||
|
||||
// GetActiveMemory returns the current Metal memory usage in bytes.
func GetActiveMemory() uint64 {
	var mem C.size_t
	C.mlx_get_active_memory(&mem)
	return uint64(mem)
}

// GetPeakMemory returns the peak Metal memory usage in bytes.
func GetPeakMemory() uint64 {
	var mem C.size_t
	C.mlx_get_peak_memory(&mem)
	return uint64(mem)
}

// ClearCache releases Metal memory held in the MLX allocator cache.
func ClearCache() {
	C.mlx_clear_cache()
}
|
||||
|
|
@ -1,324 +0,0 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
// Package tokenizer provides BPE tokenization for transformer models.
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Tokenizer handles text-to-token and token-to-text conversion.
type Tokenizer struct {
	vocab    map[string]int32 // token string -> id (includes added tokens)
	invVocab map[int32]string // id -> token string (reverse of vocab)
	merges   []mergePair      // BPE merge rules, rank = position in file
	special  map[string]int32 // special token string -> id (subset of vocab)

	bosToken int32 // beginning-of-sequence id, prepended by Encode
	eosToken int32 // generation stop id, chosen in Load by model family
	
	// GPT-2 byte-level BPE support (used by Qwen, GPT, Llama, etc.)
	isGPT2BPE bool
	gpt2Decoder map[rune]byte // Unicode char → original byte
	gpt2Encoder map[byte]rune // original byte → Unicode char
}
|
||||
|
||||
// mergePair is a single BPE merge rule: the token strings a and b form one
// pair; rank is the rule's position in the tokenizer's merges list.
type mergePair struct {
	a, b string
	rank int
}
|
||||
|
||||
// tokenizerJSON is the HuggingFace tokenizer.json format.
// Vocab and Merges stay as json.RawMessage because their JSON shape varies
// between tokenizer versions; Load decodes them with fallbacks.
type tokenizerJSON struct {
	Model struct {
		Type string `json:"type"`
		Vocab json.RawMessage `json:"vocab"`
		Merges json.RawMessage `json:"merges"`
		ByteFallback bool `json:"byte_fallback"`
	} `json:"model"`
	AddedTokens []struct {
		ID int32 `json:"id"`
		Content string `json:"content"`
		Special bool `json:"special"`
	} `json:"added_tokens"`
}
|
||||
|
||||
// Load reads a tokenizer.json file and creates a Tokenizer.
//
// It parses the vocab, the BPE merge rules (both known serializations),
// and the added/special tokens, then detects GPT-2 byte-level BPE and the
// model family's BOS/EOS tokens. Returns an error only for unreadable or
// structurally invalid JSON; missing optional sections are tolerated.
func Load(path string) (*Tokenizer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
	}

	var tj tokenizerJSON
	if err := json.Unmarshal(data, &tj); err != nil {
		return nil, fmt.Errorf("tokenizer: parse: %w", err)
	}

	t := &Tokenizer{
		vocab: make(map[string]int32),
		invVocab: make(map[int32]string),
		special: make(map[string]int32),
	}

	// Parse vocab
	var vocab map[string]int32
	if err := json.Unmarshal(tj.Model.Vocab, &vocab); err != nil {
		return nil, fmt.Errorf("tokenizer: parse vocab: %w", err)
	}
	t.vocab = vocab
	for k, v := range vocab {
		t.invVocab[v] = k
	}

	// Parse merges — supports both ["a b", ...] and [["a","b"], ...] formats
	if len(tj.Model.Merges) > 0 {
		var stringMerges []string
		if err := json.Unmarshal(tj.Model.Merges, &stringMerges); err == nil {
			for rank, merge := range stringMerges {
				// SplitN(..., 2) keeps any further space inside the second token.
				parts := strings.SplitN(merge, " ", 2)
				if len(parts) == 2 {
					t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
				}
			}
		} else {
			// Fall back to the newer array-of-pairs serialization.
			var arrayMerges [][]string
			if err := json.Unmarshal(tj.Model.Merges, &arrayMerges); err == nil {
				for rank, pair := range arrayMerges {
					if len(pair) == 2 {
						t.merges = append(t.merges, mergePair{a: pair[0], b: pair[1], rank: rank})
					}
				}
			}
		}
	}

	// Parse special tokens. All added tokens also enter the vocab maps so
	// Encode/Decode can see them; only Special ones enter t.special.
	for _, tok := range tj.AddedTokens {
		if tok.Special {
			t.special[tok.Content] = tok.ID
		}
		t.vocab[tok.Content] = tok.ID
		t.invVocab[tok.ID] = tok.Content
	}

	// Detect GPT-2 byte-level BPE (Qwen, GPT, Llama use Ġ for space)
	if _, ok := t.vocab["Ġ"]; ok {
		t.isGPT2BPE = true
		t.gpt2Decoder, t.gpt2Encoder = buildGPT2ByteMaps()
	}

	// Set BOS/EOS — detect model family from special tokens.
	// Order matters: later, more specific matches override earlier generic ones.
	if id, ok := t.special["<bos>"]; ok {
		t.bosToken = id
	}
	if id, ok := t.special["<eos>"]; ok {
		t.eosToken = id
	}
	// Gemma: <end_of_turn> is the generation stop token
	if id, ok := t.special["<end_of_turn>"]; ok {
		t.eosToken = id
	}
	// Qwen3: <|im_end|> is the generation stop token
	if id, ok := t.special["<|im_end|>"]; ok {
		t.eosToken = id
	}
	// Qwen3 BOS: <|im_start|>
	if id, ok := t.special["<|im_start|>"]; ok {
		t.bosToken = id
	}

	return t, nil
}
|
||||
|
||||
// buildGPT2ByteMaps creates the GPT-2 byte-level BPE encoding/decoding maps.
// GPT-2 maps all 256 bytes to printable Unicode characters to avoid control
// chars in the vocabulary. Printable ASCII + Latin-1 Supplement map to
// themselves; everything else (0-32, 127-160, 173) maps to U+0100 onwards.
func buildGPT2ByteMaps() (decoder map[rune]byte, encoder map[byte]rune) {
	encoder = make(map[byte]rune, 256)
	decoder = make(map[rune]byte, 256)

	// Self-mapping spans: printable ASCII and the two printable Latin-1
	// Supplement runs. Int loop variables avoid byte overflow at 255.
	for _, span := range [][2]int{
		{33, 126},  // ! through ~
		{161, 172}, // ¡ through ¬
		{174, 255}, // ® through ÿ
	} {
		for b := span[0]; b <= span[1]; b++ {
			encoder[byte(b)] = rune(b)
			decoder[rune(b)] = byte(b)
		}
	}

	// Every byte not yet mapped (control chars, space, DEL, gaps) gets the
	// next consecutive code point starting at U+0100.
	next := rune(256)
	for b := 0; b < 256; b++ {
		if _, mapped := encoder[byte(b)]; mapped {
			continue
		}
		encoder[byte(b)] = next
		decoder[next] = byte(b)
		next++
	}
	return decoder, encoder
}
|
||||
|
||||
// Encode converts text to token IDs. Prepends BOS token.
|
||||
func (t *Tokenizer) Encode(text string) []int32 {
|
||||
tokens := []int32{t.bosToken}
|
||||
|
||||
if t.isGPT2BPE {
|
||||
return t.encodeGPT2(text)
|
||||
}
|
||||
|
||||
// SentencePiece style encoding
|
||||
remaining := text
|
||||
for remaining != "" {
|
||||
found := false
|
||||
for tok, id := range t.special {
|
||||
if strings.HasPrefix(remaining, tok) {
|
||||
tokens = append(tokens, id)
|
||||
remaining = remaining[len(tok):]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
r := []rune(remaining)
|
||||
ch := "▁" + string(r[0])
|
||||
if id, ok := t.vocab[ch]; ok {
|
||||
tokens = append(tokens, id)
|
||||
} else if id, ok := t.vocab[string(r[0])]; ok {
|
||||
tokens = append(tokens, id)
|
||||
}
|
||||
remaining = string(r[1:])
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// encodeGPT2 encodes text using GPT-2 byte-level BPE.
|
||||
func (t *Tokenizer) encodeGPT2(text string) []int32 {
|
||||
tokens := []int32{t.bosToken}
|
||||
|
||||
// Convert text bytes to GPT-2 Unicode representation
|
||||
var encoded strings.Builder
|
||||
for _, b := range []byte(text) {
|
||||
if r, ok := t.gpt2Encoder[b]; ok {
|
||||
encoded.WriteRune(r)
|
||||
}
|
||||
}
|
||||
gpt2Text := encoded.String()
|
||||
|
||||
// Scan for special tokens and regular text
|
||||
remaining := gpt2Text
|
||||
for remaining != "" {
|
||||
// Check special tokens (these are stored as-is, not byte-encoded)
|
||||
found := false
|
||||
for tok, id := range t.special {
|
||||
// Special tokens in GPT-2 tokenizers are stored in their original form
|
||||
// Convert the special token to GPT-2 encoding for matching
|
||||
var encTok strings.Builder
|
||||
for _, b := range []byte(tok) {
|
||||
if r, ok := t.gpt2Encoder[b]; ok {
|
||||
encTok.WriteRune(r)
|
||||
}
|
||||
}
|
||||
encStr := encTok.String()
|
||||
if strings.HasPrefix(remaining, encStr) {
|
||||
tokens = append(tokens, id)
|
||||
remaining = remaining[len(encStr):]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// Character-by-character lookup (simplified BPE)
|
||||
r := []rune(remaining)
|
||||
ch := string(r[0])
|
||||
if id, ok := t.vocab[ch]; ok {
|
||||
tokens = append(tokens, id)
|
||||
}
|
||||
remaining = string(r[1:])
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to text.
|
||||
// For full-sequence decoding, the SentencePiece leading space is stripped.
|
||||
func (t *Tokenizer) Decode(tokens []int32) string {
|
||||
var sb strings.Builder
|
||||
for _, id := range tokens {
|
||||
if text, ok := t.invVocab[id]; ok {
|
||||
// Skip special tokens in decode output
|
||||
if _, isSpecial := t.special[text]; isSpecial {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(text)
|
||||
}
|
||||
}
|
||||
raw := sb.String()
|
||||
|
||||
if t.isGPT2BPE {
|
||||
return t.decodeGPT2Bytes(raw)
|
||||
}
|
||||
|
||||
// SentencePiece style
|
||||
result := strings.ReplaceAll(raw, "▁", " ")
|
||||
if strings.HasPrefix(result, " ") {
|
||||
result = result[1:]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// DecodeToken converts a single token ID to text for streaming.
|
||||
// Unlike Decode, it preserves the leading space (word boundary) so that
|
||||
// token-by-token output maintains correct spacing between words.
|
||||
func (t *Tokenizer) DecodeToken(id int32) string {
|
||||
text, ok := t.invVocab[id]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if _, isSpecial := t.special[text]; isSpecial {
|
||||
return ""
|
||||
}
|
||||
|
||||
if t.isGPT2BPE {
|
||||
return t.decodeGPT2Bytes(text)
|
||||
}
|
||||
|
||||
// SentencePiece: replace ▁ with space but keep it (it's the word boundary)
|
||||
return strings.ReplaceAll(text, "▁", " ")
|
||||
}
|
||||
|
||||
// decodeGPT2Bytes converts GPT-2 byte-level BPE Unicode back to real bytes.
|
||||
func (t *Tokenizer) decodeGPT2Bytes(s string) string {
|
||||
var buf []byte
|
||||
for _, r := range s {
|
||||
if b, ok := t.gpt2Decoder[r]; ok {
|
||||
buf = append(buf, b)
|
||||
} else {
|
||||
// Non-mapped runes pass through as UTF-8
|
||||
buf = append(buf, []byte(string(r))...)
|
||||
}
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// BOSToken returns the beginning-of-sequence token ID.
// This is the ID Encode prepends to every sequence.
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }

// EOSToken returns the end-of-sequence token ID.
// For chat models this is the per-family stop token chosen in Load.
func (t *Tokenizer) EOSToken() int32 { return t.eosToken }
|
||||
|
||||
// FormatGemmaPrompt applies the Gemma 3 chat template: the user prompt is
// wrapped in a user turn and followed by the opening of the model turn.
func FormatGemmaPrompt(prompt string) string {
	return "<start_of_turn>user\n" + prompt + "<end_of_turn>\n<start_of_turn>model\n"
}
|
||||
Loading…
Add table
Reference in a new issue