feat: add native MLX backend for Apple Silicon inference (pkg/mlx)

CGo wrapper for mlx-c providing zero-Python Metal GPU inference.
Includes Gemma 3 model architecture, BPE tokenizer, KV cache,
composable sampling, and OpenAI-compatible serve command.

Build-tagged (darwin && arm64 && mlx) with stubs for cross-platform.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-16 01:19:04 +00:00 committed by Snider
parent 548256312d
commit bc28aad526
20 changed files with 2398 additions and 0 deletions

View file

@ -10,6 +10,7 @@
// - core ml convert: Convert MLX LoRA adapter to PEFT format
// - core ml agent: Run the scoring agent daemon
// - core ml worker: Run a distributed worker node
// - core ml serve: Start OpenAI-compatible inference server
package ml
import (
@ -38,6 +39,7 @@ func AddMLCommands(root *cli.Command) {
mlCmd.AddCommand(convertCmd)
mlCmd.AddCommand(agentCmd)
mlCmd.AddCommand(workerCmd)
mlCmd.AddCommand(serveCmd)
root.AddCommand(mlCmd)
}

View file

@ -0,0 +1,174 @@
package ml
import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"time"
"forge.lthn.ai/core/cli/pkg/cli"
"forge.lthn.ai/core/cli/pkg/ml"
)
// serveCmd wires "core ml serve" to runServe: an OpenAI-compatible
// HTTP inference server backed by the configured ML backend.
var serveCmd = &cli.Command{
	Use:   "serve",
	Short: "Start OpenAI-compatible inference server",
	Long:  "Starts an HTTP server serving /v1/completions and /v1/chat/completions using the configured ML backend.",
	RunE:  runServe,
}

// Flag storage for the serve command.
var (
	serveBind      string // listen address, set by --bind (default 0.0.0.0:8090)
	serveModelPath string // model directory, set by --model-path; NOTE(review): not read by runServe in this file — confirm wiring
)

// init registers the serve command's flags.
func init() {
	serveCmd.Flags().StringVar(&serveBind, "bind", "0.0.0.0:8090", "Address to bind")
	serveCmd.Flags().StringVar(&serveModelPath, "model-path", "", "Path to model directory (for mlx backend)")
}
// completionRequest is the subset of the OpenAI /v1/completions request
// payload this server understands.
type completionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	MaxTokens   int     `json:"max_tokens"`
	Temperature float64 `json:"temperature"`
}

// completionResponse mirrors the OpenAI text-completion response shape.
type completionResponse struct {
	ID      string             `json:"id"`
	Object  string             `json:"object"`
	Created int64              `json:"created"`
	Model   string             `json:"model"`
	Choices []completionChoice `json:"choices"`
	Usage   usageInfo          `json:"usage"`
}

// completionChoice is one generated completion alternative.
type completionChoice struct {
	Text         string `json:"text"`
	Index        int    `json:"index"`
	FinishReason string `json:"finish_reason"`
}

// chatRequest is the subset of the OpenAI /v1/chat/completions request
// payload this server understands.
type chatRequest struct {
	Model       string       `json:"model"`
	Messages    []ml.Message `json:"messages"`
	MaxTokens   int          `json:"max_tokens"`
	Temperature float64      `json:"temperature"`
}

// chatResponse mirrors the OpenAI chat-completion response shape.
type chatResponse struct {
	ID      string       `json:"id"`
	Object  string       `json:"object"`
	Created int64        `json:"created"`
	Model   string       `json:"model"`
	Choices []chatChoice `json:"choices"`
}

// chatChoice is one generated chat message alternative.
type chatChoice struct {
	Message      ml.Message `json:"message"`
	Index        int        `json:"index"`
	FinishReason string     `json:"finish_reason"`
}

// usageInfo reports token accounting for a completion.
// NOTE(review): runServe never populates this — counts are always zero
// in responses; confirm whether that is intended.
type usageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// runServe starts an OpenAI-compatible HTTP server exposing
// /v1/completions, /v1/chat/completions, and /v1/models, forwarding
// generation requests to the configured ML backend.
//
// NOTE(review): the serveModelPath flag is declared but not consumed
// here — confirm whether backend selection from it is wired elsewhere.
func runServe(cmd *cli.Command, args []string) error {
	// Create a backend — use HTTP backend pointing to configured API URL.
	// On macOS with MLX build tag, this will use the native MLX backend instead.
	backend := ml.NewHTTPBackend(apiURL, modelName)

	mux := http.NewServeMux()

	mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
		// Read errors were previously discarded; surface them as 400s.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		var req completionRequest
		if err := json.Unmarshal(body, &req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		opts := ml.GenOpts{
			Temperature: req.Temperature,
			MaxTokens:   req.MaxTokens,
			Model:       req.Model,
		}
		text, err := backend.Generate(r.Context(), req.Prompt, opts)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, completionResponse{
			ID:      fmt.Sprintf("cmpl-%d", time.Now().UnixNano()),
			Object:  "text_completion",
			Created: time.Now().Unix(),
			Model:   backend.Name(),
			Choices: []completionChoice{{Text: text, FinishReason: "stop"}},
		})
	})

	mux.HandleFunc("POST /v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		var req chatRequest
		if err := json.Unmarshal(body, &req); err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		opts := ml.GenOpts{
			Temperature: req.Temperature,
			MaxTokens:   req.MaxTokens,
			Model:       req.Model,
		}
		text, err := backend.Chat(r.Context(), req.Messages, opts)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		writeJSON(w, chatResponse{
			ID:      fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
			Object:  "chat.completion",
			Created: time.Now().Unix(),
			Model:   backend.Name(),
			Choices: []chatChoice{{
				Message:      ml.Message{Role: "assistant", Content: text},
				FinishReason: "stop",
			}},
		})
	})

	mux.HandleFunc("GET /v1/models", func(w http.ResponseWriter, r *http.Request) {
		resp := struct {
			Object string `json:"object"`
			Data   []struct {
				ID string `json:"id"`
			} `json:"data"`
		}{
			Object: "list",
			Data: []struct {
				ID string `json:"id"`
			}{{ID: backend.Name()}},
		}
		writeJSON(w, resp)
	})

	slog.Info("ml serve: starting", "bind", serveBind, "backend", backend.Name())
	fmt.Printf("Serving on http://%s\n", serveBind)

	// Explicit server so header reads are bounded; no overall write timeout
	// because generation can legitimately be slow.
	srv := &http.Server{
		Addr:              serveBind,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}
	return srv.ListenAndServe()
}

// writeJSON serializes v as the JSON response body, logging (rather than
// silently dropping) encode failures.
func writeJSON(w http.ResponseWriter, v any) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(v); err != nil {
		slog.Error("ml serve: encode response", "error", err)
	}
}

169
pkg/ml/backend_mlx.go Normal file
View file

@ -0,0 +1,169 @@
//go:build darwin && arm64 && mlx
package ml
import (
"context"
"fmt"
"log/slog"
"sync"
"forge.lthn.ai/core/cli/pkg/mlx"
"forge.lthn.ai/core/cli/pkg/mlx/cache"
"forge.lthn.ai/core/cli/pkg/mlx/model"
"forge.lthn.ai/core/cli/pkg/mlx/sample"
"forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
)
// MLXBackend implements Backend for native Metal inference via mlx-c.
// The model and its KV caches are stateful, so mu serializes generations:
// only one Generate/Chat call may run at a time.
type MLXBackend struct {
	model  *model.GemmaModel
	tok    *tokenizer.Tokenizer
	caches []cache.Cache
	// sampler is a default low-temperature sampler.
	// NOTE(review): Generate/Chat construct their own sampler per call in
	// this file; confirm this field is actually consumed anywhere.
	sampler sample.Sampler
	mu      sync.Mutex
}

// NewMLXBackend loads a model from a safetensors directory and creates
// a native Metal inference backend. It fails fast if no Metal GPU is
// available or the model cannot be loaded; on success it logs the layer
// count and resident GPU memory.
func NewMLXBackend(modelPath string) (*MLXBackend, error) {
	if !mlx.MetalAvailable() {
		return nil, fmt.Errorf("mlx: Metal GPU not available")
	}
	slog.Info("mlx: loading model", "path", modelPath)
	m, err := model.LoadGemma3(modelPath)
	if err != nil {
		return nil, fmt.Errorf("mlx: load model: %w", err)
	}
	slog.Info("mlx: model loaded",
		"layers", m.NumLayers(),
		"memory_mb", mlx.GetActiveMemory()/1024/1024,
	)
	return &MLXBackend{
		model:   m,
		tok:     m.Tokenizer(),
		caches:  m.NewCache(),
		sampler: sample.New(0.1, 0, 0, 0), // default low temp
	}, nil
}
// generateLocked runs the token-generation loop for an already-formatted
// prompt: it serializes access to the shared model/caches, resets the KV
// caches for a fresh session, and samples up to opts.MaxTokens tokens
// (default 2048), stopping at EOS or context cancellation. On cancellation
// it returns the partial decode alongside ctx.Err(). This is the single
// implementation shared by Generate and Chat, which previously duplicated it.
func (b *MLXBackend) generateLocked(ctx context.Context, formatted string, opts GenOpts) (string, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Reset caches for a new generation.
	for _, c := range b.caches {
		c.Reset()
	}

	// Per-call sampler; a zero temperature means "use the default 0.1".
	temp := float32(opts.Temperature)
	if temp == 0 {
		temp = 0.1
	}
	sampler := sample.New(temp, 0, 0, 0)

	// Tokenize the formatted prompt into a [1, len] input tensor.
	tokens := b.tok.Encode(formatted)
	input := mlx.FromValues(tokens, 1, len(tokens))

	maxTokens := opts.MaxTokens
	if maxTokens == 0 {
		maxTokens = 2048
	}

	// Autoregressive generation loop: forward, sample, feed token back.
	var output []int32
	for i := 0; i < maxTokens; i++ {
		select {
		case <-ctx.Done():
			return b.tok.Decode(output), ctx.Err()
		default:
		}
		logits := b.model.Forward(input, b.caches)
		next := sampler.Sample(logits)
		mlx.Materialize(next)
		nextToken := int32(next.Int())
		if nextToken == b.tok.EOSToken() {
			break
		}
		output = append(output, nextToken)
		input = mlx.FromValues([]int32{nextToken}, 1, 1)
	}
	return b.tok.Decode(output), nil
}

// Generate produces text from a prompt using native Metal inference.
// The prompt is wrapped in the Gemma chat template before tokenization.
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
	return b.generateLocked(ctx, tokenizer.FormatGemmaPrompt(prompt), opts)
}

// Chat formats messages using the Gemma turn markers and generates a
// response. System messages are folded into a user turn as "[System: ...]".
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
	var prompt string
	for _, msg := range messages {
		switch msg.Role {
		case "user":
			prompt += fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n", msg.Content)
		case "assistant":
			prompt += fmt.Sprintf("<start_of_turn>model\n%s<end_of_turn>\n", msg.Content)
		case "system":
			prompt += fmt.Sprintf("<start_of_turn>user\n[System: %s]<end_of_turn>\n", msg.Content)
		}
	}
	// Open the model's turn so generation continues from here.
	prompt += "<start_of_turn>model\n"
	return b.generateLocked(ctx, prompt, opts)
}
// Name returns the backend identifier ("mlx").
func (b *MLXBackend) Name() string { return "mlx" }

// Available reports whether the Metal GPU is ready for inference.
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }

26
pkg/mlx/CMakeLists.txt Normal file
View file

@ -0,0 +1,26 @@
# Build mlx-c (and its bundled mlx dependency) into ./dist for the cgo bindings.
# FetchContent_MakeAvailable requires CMake >= 3.14; the previous floor of 3.5
# was below that and is deprecated by current CMake releases.
cmake_minimum_required(VERSION 3.14)
project(mlx)

# Default the install prefix to pkg/mlx/dist so the cgo CPPFLAGS/LDFLAGS
# in mlx.go resolve headers and libraries without configuration.
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
  set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
endif()

set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
set(CMAKE_INSTALL_RPATH "@loader_path")

include(FetchContent)
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG ${MLX_C_GIT_TAG}
)
FetchContent_MakeAvailable(mlx-c)

273
pkg/mlx/array.go Normal file
View file

@ -0,0 +1,273 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import (
"encoding/binary"
"reflect"
"strings"
"unsafe"
)
// tensorDesc carries the bookkeeping used by the reference-counted
// cleanup in Free: a debug name, the arrays this one was computed from,
// and the number of consumers registered via New.
type tensorDesc struct {
	name    string
	inputs  []*Array
	numRefs int
}

// Array wraps an mlx_array handle with reference-counted memory management.
// The zero Array is invalid (Valid reports false) until ctx is assigned.
type Array struct {
	ctx  C.mlx_array
	desc tensorDesc
}

// New creates a named Array tracking its input dependencies for cleanup.
// Each non-nil input's reference count is incremented; Free decrements it
// again and releases inputs transitively once unreferenced.
func New(name string, inputs ...*Array) *Array {
	t := &Array{
		desc: tensorDesc{
			name:   name,
			inputs: inputs,
		},
	}
	for _, input := range inputs {
		if input != nil {
			input.desc.numRefs++
		}
	}
	return t
}
// scalarTypes enumerates the Go types FromValue can convert to an MLX
// scalar. Note: int64/uint variants are not included.
type scalarTypes interface {
	~bool | ~int | ~float32 | ~float64 | ~complex64
}

// FromValue creates a scalar Array from a Go value.
// Panics if the dynamic type is not one of the supported scalars (can
// happen for named types whose underlying type matches but whose dynamic
// type is not the builtin itself).
func FromValue[T scalarTypes](t T) *Array {
	Init()
	tt := New("")
	switch v := any(t).(type) {
	case bool:
		tt.ctx = C.mlx_array_new_bool(C.bool(v))
	case int:
		tt.ctx = C.mlx_array_new_int(C.int(v))
	case float32:
		tt.ctx = C.mlx_array_new_float32(C.float(v))
	case float64:
		tt.ctx = C.mlx_array_new_float64(C.double(v))
	case complex64:
		tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
	default:
		panic("mlx: unsupported scalar type")
	}
	return tt
}
// arrayTypes enumerates element types FromValues can upload into an MLX
// array.
type arrayTypes interface {
	~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
		~int8 | ~int16 | ~int32 | ~int64 |
		~float32 | ~float64 |
		~complex64
}

// FromValues creates an Array from a Go slice with the given shape.
// The data is serialized little-endian and copied into MLX. Panics if no
// shape is given or the element type is unsupported.
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
	Init()
	if len(shape) == 0 {
		panic("mlx: shape required for non-scalar tensors")
	}
	cShape := make([]C.int, len(shape))
	for i := range shape {
		cShape[i] = C.int(shape[i])
	}
	// Map the Go element kind onto the matching MLX dtype.
	var dtype DType
	switch reflect.TypeOf(s).Elem().Kind() {
	case reflect.Bool:
		dtype = DTypeBool
	case reflect.Uint8:
		dtype = DTypeUint8
	case reflect.Uint16:
		dtype = DTypeUint16
	case reflect.Uint32:
		dtype = DTypeUint32
	case reflect.Uint64:
		dtype = DTypeUint64
	case reflect.Int8:
		dtype = DTypeInt8
	case reflect.Int16:
		dtype = DTypeInt16
	case reflect.Int32:
		dtype = DTypeInt32
	case reflect.Int64:
		dtype = DTypeInt64
	case reflect.Float32:
		dtype = DTypeFloat32
	case reflect.Float64:
		dtype = DTypeFloat64
	case reflect.Complex64:
		dtype = DTypeComplex64
	default:
		panic("mlx: unsupported element type")
	}
	bts := make([]byte, binary.Size(s))
	if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
		panic(err)
	}
	tt := New("")
	// unsafe.SliceData tolerates an empty slice, unlike the previous
	// &bts[0], which panicked with index-out-of-range for len(s) == 0.
	tt.ctx = C.mlx_array_new_data(unsafe.Pointer(unsafe.SliceData(bts)), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
	return tt
}
// Zeros creates a zero-filled Array with the given shape and dtype on the
// default stream.
func Zeros(shape []int32, dtype DType) *Array {
	Init()
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	tt := New("ZEROS")
	C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
	return tt
}

// Set replaces this array's value with another, updating ref tracking:
// the previous inputs are released (via Free) and `other` becomes the
// sole tracked input with its refcount bumped.
func (t *Array) Set(other *Array) {
	Free(t.desc.inputs...)
	other.desc.numRefs++
	t.desc.inputs = []*Array{other}
	C.mlx_array_set(&t.ctx, other.ctx)
}

// Clone creates a copy of this array sharing the same data and the same
// tracked inputs (their refcounts are incremented by New).
func (t *Array) Clone() *Array {
	tt := New(t.desc.name, t.desc.inputs...)
	C.mlx_array_set(&tt.ctx, t.ctx)
	return tt
}

// Valid reports whether this Array has a non-nil mlx handle (i.e. it has
// been assigned a value and not yet freed).
func (t *Array) Valid() bool {
	return t.ctx.ctx != nil
}
// String returns a human-readable representation of the array
// (whitespace-trimmed output of mlx_array_tostring).
func (t *Array) String() string {
	str := C.mlx_string_new()
	defer C.mlx_string_free(str)
	C.mlx_array_tostring(&str, t.ctx)
	return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
}

// Shape returns the dimensions as int32 slice.
func (t *Array) Shape() []int32 {
	dims := make([]int32, t.NumDims())
	for i := range dims {
		dims[i] = int32(t.Dim(i))
	}
	return dims
}

// Size returns the total number of elements.
func (t Array) Size() int { return int(C.mlx_array_size(t.ctx)) }

// NumBytes returns the total byte size of the array's data.
func (t Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) }

// NumDims returns the number of dimensions.
func (t Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) }

// Dim returns the size of dimension i.
func (t Array) Dim(i int) int { return int(C.mlx_array_dim(t.ctx, C.int(i))) }

// Dims returns all dimensions as int slice.
func (t Array) Dims() []int {
	dims := make([]int, t.NumDims())
	for i := range dims {
		dims[i] = t.Dim(i)
	}
	return dims
}

// Dtype returns the array's data type.
func (t Array) Dtype() DType { return DType(C.mlx_array_dtype(t.ctx)) }

// Int extracts a scalar int64 value, truncated to int.
// NOTE(review): the mlx-c status return is discarded, so failures (e.g.
// non-scalar input) are not surfaced to the caller.
func (t Array) Int() int {
	var item C.int64_t
	C.mlx_array_item_int64(&item, t.ctx)
	return int(item)
}

// Float extracts a scalar float64 value. Status is discarded, as in Int.
func (t Array) Float() float64 {
	var item C.double
	C.mlx_array_item_float64(&item, t.ctx)
	return float64(item)
}
// Ints extracts all elements as an int slice. The underlying data must be
// int32; this delegates to DataInt32 so the unsafe C-pointer access lives
// in exactly one place.
func (t Array) Ints() []int {
	data := t.DataInt32()
	ints := make([]int, len(data))
	for i, v := range data {
		ints[i] = int(v)
	}
	return ints
}

// DataInt32 extracts all elements as an int32 slice by copying from the
// array's int32 data pointer.
func (t Array) DataInt32() []int32 {
	data := make([]int32, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) {
		data[i] = int32(f)
	}
	return data
}

// Floats extracts all elements as a float32 slice by copying from the
// array's float32 data pointer.
func (t Array) Floats() []float32 {
	floats := make([]float32, t.Size())
	for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
		floats[i] = float32(f)
	}
	return floats
}
// Free releases arrays using reference-counted cleanup.
// Arrays with remaining references are not freed.
//
// Each array's numRefs counts the consumers that registered it as an
// input via New. Free decrements the count and releases the C handle only
// once it reaches zero or below; freed arrays push their own inputs onto
// an explicit worklist, so the whole dependency graph is released
// iteratively without recursion. Returns the total bytes released.
func Free(s ...*Array) int {
	var n int
	// Worklist of input arrays discovered while freeing.
	free := make([]*Array, 0, 64)
	fn := func(t *Array) {
		if t != nil && t.Valid() {
			t.desc.numRefs--
			if t.desc.numRefs <= 0 {
				// Unreferenced: queue its inputs, account its size,
				// release the handle, and mark it invalid.
				free = append(free, t.desc.inputs...)
				n += t.NumBytes()
				C.mlx_array_free(t.ctx)
				t.ctx.ctx = nil
			}
		}
	}
	for _, t := range s {
		fn(t)
	}
	// Drain the worklist (LIFO) until the graph is fully processed.
	for len(free) > 0 {
		tail := free[len(free)-1]
		free = free[:len(free)-1]
		fn(tail)
	}
	return n
}

178
pkg/mlx/cache/cache.go vendored Normal file
View file

@ -0,0 +1,178 @@
//go:build darwin && arm64 && mlx
// Package cache provides KV cache implementations for transformer inference.
package cache
import "forge.lthn.ai/core/cli/pkg/mlx"
// Cache manages key-value pairs for transformer attention layers.
// One Cache instance serves one attention layer; implementations are not
// safe for concurrent use.
type Cache interface {
	// Update adds new key/value tensors and returns the full cached K/V.
	Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
	// Offset returns the total number of tokens processed.
	Offset() int
	// Len returns the number of cached tokens (may differ from Offset for rotating caches).
	Len() int
	// State returns the cached K/V arrays, or nil if empty.
	State() []*mlx.Array
	// Reset clears the cache for a new generation session.
	Reset()
}
// KVCache implements an unbounded cache that grows as needed.
// Pre-allocates in chunks of `step` tokens to reduce allocations.
type KVCache struct {
	keys, values *mlx.Array // [B, H, capacity, D] buffers; nil until first Update
	offset       int        // tokens written so far
	step         int        // growth chunk size in tokens
}

// NewKVCache creates a new unbounded KV cache with 256-token chunks.
func NewKVCache() *KVCache {
	return &KVCache{step: 256}
}

// Update appends k/v (shape [B, H, seqLen, D]) at the current offset,
// growing the backing buffers in step-sized chunks when needed, and
// returns views over the valid prefix [0, offset+seqLen).
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
	prev := c.offset
	shape := k.Shape()
	B, H, Dk := shape[0], shape[1], shape[3]
	Dv := v.Shape()[3]
	// Grow buffer if needed.
	if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
		nSteps := (c.step + seqLen - 1) / c.step
		newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
		newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
		if c.keys != nil {
			// Trim any partially-filled tail before concatenating so the
			// new chunk starts exactly at `prev`.
			if prev%c.step != 0 {
				c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
				c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
			}
			c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
			c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
		} else {
			c.keys, c.values = newK, newV
		}
	}
	c.offset += seqLen
	// Write the new tokens into [prev, offset) along the sequence axis.
	c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
	c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
	return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
		mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
// State returns the backing K/V buffers, or nil before the first Update.
func (c *KVCache) State() []*mlx.Array {
	if c.keys != nil {
		return []*mlx.Array{c.keys, c.values}
	}
	return nil
}

// Offset returns the total number of tokens written so far.
func (c *KVCache) Offset() int { return c.offset }

// Len equals Offset for an unbounded cache: nothing is ever evicted.
func (c *KVCache) Len() int { return c.offset }

// Reset drops the buffers and rewinds the cache for a new session.
func (c *KVCache) Reset() {
	c.keys, c.values = nil, nil
	c.offset = 0
}
// RotatingKVCache implements a bounded sliding window cache: once maxSize
// tokens are stored, new single tokens overwrite the oldest slot.
type RotatingKVCache struct {
	keys, values *mlx.Array // [B, H, capacity, D] ring buffers
	offset       int        // total tokens ever processed
	maxSize      int        // window bound in tokens
	step         int        // growth chunk size until maxSize is reached
	idx          int        // next write position within the ring
}

// NewRotatingKVCache creates a cache bounded to maxSize tokens.
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
	return &RotatingKVCache{maxSize: maxSize, step: 256}
}
// Update dispatches on the incoming sequence length: single-token decode
// steps are written in place into the ring buffer, while multi-token
// prefills are concatenated and trimmed to the window.
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
	if seqLen <= 1 {
		return c.updateInPlace(k, v)
	}
	return c.updateConcat(k, v, seqLen)
}
// updateInPlace writes a single-token k/v into the ring-buffer slot at
// c.idx, growing the buffers in step-sized chunks until maxSize capacity
// is reached, then wrapping the write index. Returns views over the valid
// prefix (min(offset, maxSize) tokens).
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
	shape := k.Shape()
	B, H, Dk := shape[0], shape[1], shape[3]
	Dv := v.Shape()[3]
	// Grow while below maxSize and the write index is past current capacity.
	if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
		var capacity int // renamed from `cap`, which shadowed the builtin
		if c.keys != nil {
			capacity = int(c.keys.Shape()[2])
		}
		newSize := min(c.step, c.maxSize-capacity)
		newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
		newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
		if c.keys != nil {
			c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
			c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
		} else {
			c.keys, c.values = newK, newV
		}
	}
	// Wrap the ring index once the window is full.
	if c.idx >= c.maxSize {
		c.idx = 0
	}
	c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
	c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
	c.offset++
	c.idx++
	validLen := int32(min(c.offset, c.maxSize))
	return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
		mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
// updateConcat appends a multi-token k/v (prefill path) to the cache and
// trims the oldest tokens so at most maxSize remain. The write index is
// left at the end of the (trimmed) buffer.
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
	shape := k.Shape()
	B, H, Dk := shape[0], shape[1], shape[3]
	Dv := v.Shape()[3]
	if c.keys == nil {
		c.keys, c.values = k, v
	} else {
		c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
		c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
	}
	c.offset += seqLen
	// Drop the oldest tokens beyond the window.
	capacity := int(c.keys.Shape()[2]) // renamed from `cap`, which shadowed the builtin
	if trim := capacity - c.maxSize; trim > 0 {
		c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(capacity), Dk})
		c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(capacity), Dv})
	}
	c.idx = int(c.keys.Shape()[2])
	return c.keys, c.values
}
// State returns the ring buffers, or nil before the first Update.
func (c *RotatingKVCache) State() []*mlx.Array {
	if c.keys != nil {
		return []*mlx.Array{c.keys, c.values}
	}
	return nil
}

// Offset returns the total number of tokens ever processed.
func (c *RotatingKVCache) Offset() int { return c.offset }

// Len returns how many tokens are currently held, bounded by the window.
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }

// Reset drops the buffers and rewinds all counters for a new session.
func (c *RotatingKVCache) Reset() {
	c.keys, c.values = nil, nil
	c.offset, c.idx = 0, 0
}

85
pkg/mlx/compile.go Normal file
View file

@ -0,0 +1,85 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
// Callback for compiled functions.
extern void goCompiledFunc(mlx_vector_array inputs, mlx_vector_array outputs, void *payload);
static mlx_closure new_closure(void *payload) {
return mlx_closure_new_func_payload(&goCompiledFunc, payload);
}
*/
import "C"
import (
"sync"
"unsafe"
)
// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
type CompiledFunc struct {
	fn      func([]*Array) []*Array
	closure C.mlx_closure // NOTE(review): built but never invoked — Call uses fn directly
	mu      sync.Mutex
}

// compiledFuncs maps a payload id to the registered Go function so the C
// callback can recover it.
// NOTE(review): entries are never deleted, so every CompileShapeless call
// leaks one map entry for the process lifetime — confirm acceptable.
var compiledFuncs sync.Map

// goCompiledFunc is the C-side trampoline: it looks up the Go function by
// payload id, converts the input vector to []*Array, runs the function,
// and appends the results to the output vector. Unknown ids are ignored.
//
//export goCompiledFunc
func goCompiledFunc(inputs C.mlx_vector_array, outputs C.mlx_vector_array, payload unsafe.Pointer) {
	id := uintptr(payload)
	fnI, ok := compiledFuncs.Load(id)
	if !ok {
		return
	}
	fn := fnI.(func([]*Array) []*Array)
	// Convert inputs
	nInputs := int(C.mlx_vector_array_size(inputs))
	goInputs := make([]*Array, nInputs)
	for i := 0; i < nInputs; i++ {
		a := New("INPUT")
		C.mlx_vector_array_get(&a.ctx, inputs, C.int(i))
		goInputs[i] = a
	}
	// Call user function
	goOutputs := fn(goInputs)
	// Set outputs
	for _, out := range goOutputs {
		C.mlx_vector_array_append_value(outputs, out.ctx)
	}
}

// nextID generates payload ids for registered functions; guarded by nextIDMu.
var nextID uintptr
var nextIDMu sync.Mutex

// CompileShapeless compiles a function for efficient repeated execution.
// The function must accept and return arrays of consistent shapes.
//
// NOTE(review): unsafe.Pointer(id) converts a non-pointer uintptr to an
// unsafe.Pointer, which violates the unsafe.Pointer conversion rules and
// is flagged by go vet; runtime/cgo.Handle is the supported mechanism for
// passing Go values through C payloads — consider migrating (requires a
// matching change in goCompiledFunc).
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
	nextIDMu.Lock()
	nextID++
	id := nextID
	nextIDMu.Unlock()
	compiledFuncs.Store(id, fn)
	cf := &CompiledFunc{fn: fn}
	cf.closure = C.new_closure(unsafe.Pointer(id))
	return cf
}
// Call executes the compiled function with the given inputs, serializing
// concurrent callers.
//
// It intentionally invokes the Go function directly rather than the C
// closure: compilation is an optimization, and the direct path still
// benefits from MLX's lazy evaluation.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
	cf.mu.Lock()
	outputs := cf.fn(inputs)
	cf.mu.Unlock()
	return outputs
}

83
pkg/mlx/dtype.go Normal file
View file

@ -0,0 +1,83 @@
//go:build darwin && arm64 && mlx
package mlx
// #include "mlx/c/mlx.h"
import "C"
import "encoding/json"
// DType represents an MLX array data type (mirrors C.mlx_dtype).
type DType C.mlx_dtype

// One constant per mlx-c dtype; values are the C enum values.
const (
	DTypeBool      DType = C.MLX_BOOL
	DTypeUint8     DType = C.MLX_UINT8
	DTypeUint16    DType = C.MLX_UINT16
	DTypeUint32    DType = C.MLX_UINT32
	DTypeUint64    DType = C.MLX_UINT64
	DTypeInt8      DType = C.MLX_INT8
	DTypeInt16     DType = C.MLX_INT16
	DTypeInt32     DType = C.MLX_INT32
	DTypeInt64     DType = C.MLX_INT64
	DTypeFloat16   DType = C.MLX_FLOAT16
	DTypeFloat32   DType = C.MLX_FLOAT32
	DTypeFloat64   DType = C.MLX_FLOAT64
	DTypeBFloat16  DType = C.MLX_BFLOAT16
	DTypeComplex64 DType = C.MLX_COMPLEX64
)

// dtypeNames maps each DType to its canonical lowercase name for String.
var dtypeNames = map[DType]string{
	DTypeBool:      "bool",
	DTypeUint8:     "uint8",
	DTypeUint16:    "uint16",
	DTypeUint32:    "uint32",
	DTypeUint64:    "uint64",
	DTypeInt8:      "int8",
	DTypeInt16:     "int16",
	DTypeInt32:     "int32",
	DTypeInt64:     "int64",
	DTypeFloat16:   "float16",
	DTypeFloat32:   "float32",
	DTypeFloat64:   "float64",
	DTypeBFloat16:  "bfloat16",
	DTypeComplex64: "complex64",
}
// String returns the canonical lowercase name of the dtype, or "unknown"
// for values outside the known set.
func (d DType) String() string {
	name, known := dtypeNames[d]
	if !known {
		return "unknown"
	}
	return name
}
// dtypeFromString accepts both lowercase names and the uppercase
// abbreviations used by the safetensors format (F32, BF16, ...).
var dtypeFromString = map[string]DType{
	"bool": DTypeBool, "BOOL": DTypeBool,
	"uint8": DTypeUint8, "U8": DTypeUint8,
	"uint16": DTypeUint16, "U16": DTypeUint16,
	"uint32": DTypeUint32, "U32": DTypeUint32,
	"uint64": DTypeUint64, "U64": DTypeUint64,
	"int8": DTypeInt8, "I8": DTypeInt8,
	"int16": DTypeInt16, "I16": DTypeInt16,
	"int32": DTypeInt32, "I32": DTypeInt32,
	"int64": DTypeInt64, "I64": DTypeInt64,
	"float16": DTypeFloat16, "F16": DTypeFloat16,
	"float32": DTypeFloat32, "F32": DTypeFloat32,
	"float64": DTypeFloat64, "F64": DTypeFloat64,
	"bfloat16": DTypeBFloat16, "BF16": DTypeBFloat16,
	"complex64": DTypeComplex64,
}

// UnmarshalJSON parses a DType from JSON strings like "F32", "BF16", etc.
// Unknown names silently default to float32 rather than erroring.
// NOTE(review): confirm the silent default is intended — a typo in a model
// config would be masked as float32.
func (d *DType) UnmarshalJSON(b []byte) error {
	var s string
	if err := json.Unmarshal(b, &s); err != nil {
		return err
	}
	if dt, ok := dtypeFromString[s]; ok {
		*d = dt
		return nil
	}
	*d = DTypeFloat32 // default
	return nil
}

81
pkg/mlx/fast.go Normal file
View file

@ -0,0 +1,81 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import "unsafe"
// RMSNorm applies Root Mean Square normalization using a fused Metal kernel.
// The result tracks x as an input for reference-counted cleanup.
func RMSNorm(x, weight *Array, eps float32) *Array {
	out := New("FAST_RMSNORM", x)
	C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
	return out
}

// LayerNorm applies Layer normalization using a fused Metal kernel.
func LayerNorm(x, weight, bias *Array, eps float32) *Array {
	out := New("FAST_LAYERNORM", x)
	C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx)
	return out
}
// RoPE applies Rotary Position Embeddings using a fused Metal kernel.
// A base of 0 is treated as "unset" (has_value=false), letting MLX use
// its default. freqs is passed as an empty array handle, i.e. no explicit
// frequency tensor — presumably mlx-c treats an invalid handle as absent;
// TODO(review) confirm against the mlx-c API.
func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
	freqs := New("")
	out := New("FAST_ROPE", x, freqs)
	C.mlx_fast_rope(
		&out.ctx,
		x.ctx,
		C.int(dims),
		C._Bool(traditional),
		C.mlx_optional_float{
			value:     C.float(base),
			has_value: C._Bool(base != 0),
		},
		C.float(scale),
		C.int(offset),
		freqs.ctx,
		DefaultStream().ctx,
	)
	return out
}
// ScaledDotProductAttention computes attention using a fused Metal kernel.
// When causal is true the kernel applies its automatic causal mask mode;
// otherwise no masking is requested. The mask/sink arguments are passed as
// empty array handles in both cases.
func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
	// The previous if/else created identical empty mask/sink handles in
	// both branches; the dead conditional is collapsed here.
	mask := New("")
	sinks := New("")
	mode := "none"
	if causal {
		mode = "causal"
	}
	cMode := C.CString(mode)
	defer C.free(unsafe.Pointer(cMode))
	out := New("FAST_SDPA", query, key, value, mask, sinks)
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}
// ScaledDotProductAttentionWithMask computes attention with an explicit
// additive mask tensor; the mask mode is "none" since masking comes from
// the tensor itself.
func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
	sinks := New("")
	cMode := C.CString("none")
	defer C.free(unsafe.Pointer(cMode))
	out := New("FAST_SDPA", query, key, value, mask, sinks)
	C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
	return out
}

60
pkg/mlx/io.go Normal file
View file

@ -0,0 +1,60 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
import (
"iter"
"unsafe"
)
// LoadSafetensors loads tensors from a .safetensors file, returning an iterator
// over (name, array) pairs. Tensors are loaded lazily on the CPU stream.
//
// Yielded arrays are created with numRefs=1000, effectively pinning them
// so Free never releases model weights. Load errors from mlx-c are not
// surfaced here; a failed load yields no pairs.
func LoadSafetensors(path string) iter.Seq2[string, *Array] {
	Init()
	return func(yield func(string, *Array) bool) {
		string2array := C.mlx_map_string_to_array_new()
		defer C.mlx_map_string_to_array_free(string2array)
		// Metadata map (required by the C API; contents unused here).
		string2string := C.mlx_map_string_to_string_new()
		defer C.mlx_map_string_to_string_free(string2string)
		cPath := C.CString(path)
		defer C.free(unsafe.Pointer(cPath))
		// Load on a CPU stream: weights are uploaded to the GPU lazily.
		cpu := C.mlx_default_cpu_stream_new()
		defer C.mlx_stream_free(cpu)
		C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
		it := C.mlx_map_string_to_array_iterator_new(string2array)
		defer C.mlx_map_string_to_array_iterator_free(it)
		for {
			var key *C.char
			value := C.mlx_array_new()
			// Non-zero return means the iterator is exhausted.
			if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
				break
			}
			name := C.GoString(key)
			if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
				break
			}
		}
	}
}
// LoadAllSafetensors eagerly collects every tensor in a .safetensors file
// into a name-keyed map.
func LoadAllSafetensors(path string) map[string]*Array {
	out := make(map[string]*Array)
	for name, arr := range LoadSafetensors(path) {
		out[name] = arr
	}
	return out
}

103
pkg/mlx/mlx.go Normal file
View file

@ -0,0 +1,103 @@
//go:build darwin && arm64 && mlx
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
//
// Build mlx-c before use:
//
// cd pkg/mlx && go generate ./...
//
// Build with MLX enabled:
//
// go build -tags mlx -o core .
package mlx
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
//go:generate cmake --build build --parallel
//go:generate cmake --install build
/*
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/dist/include
#cgo LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx -lstdc++
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib
#include <stdlib.h>
#include "mlx/c/mlx.h"
extern void goMLXErrorHandler(const char *msg, void *data);
static void set_error_handler() {
mlx_set_error_handler(&goMLXErrorHandler, NULL, NULL);
}
*/
import "C"
import (
"log/slog"
"sync"
"unsafe"
)
// initOnce guards one-time MLX initialization.
var initOnce sync.Once

// Init sets up the MLX error handler. Called automatically on first use.
// Safe to call from multiple goroutines; only the first call does work.
func Init() {
	initOnce.Do(func() {
		C.set_error_handler()
		slog.Debug("mlx: initialized with Metal backend")
	})
}

// goMLXErrorHandler receives error messages from the mlx-c runtime and
// forwards them to the structured logger. It does not abort execution.
//
//export goMLXErrorHandler
func goMLXErrorHandler(msg *C.char, data unsafe.Pointer) {
	slog.Error("mlx", "error", C.GoString(msg))
}
// Materialize synchronously evaluates arrays, computing their values on the GPU.
// This is the MLX equivalent of forcing lazy computation to complete.
func Materialize(outputs ...*Array) {
	doMaterialize(outputs, false)
}

// MaterializeAsync queues arrays for asynchronous GPU evaluation without
// waiting for completion.
func MaterializeAsync(outputs ...*Array) {
	doMaterialize(outputs, true)
}

// doMaterialize gathers the valid arrays into an mlx vector and triggers
// either a blocking (mlx_eval) or asynchronous (mlx_async_eval) pass.
// Nil or already-freed arrays are skipped.
func doMaterialize(outputs []*Array, async bool) {
	Init()
	vector := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vector)
	for _, output := range outputs {
		if output != nil && output.Valid() {
			C.mlx_vector_array_append_value(vector, output.ctx)
		}
	}
	if async {
		C.mlx_async_eval(vector)
	} else {
		C.mlx_eval(vector)
	}
}
// Collect filters a variadic list down to its non-nil, still-valid arrays,
// suitable for a batched Materialize call. Returns nil when nothing valid
// was supplied.
func Collect(arrays ...*Array) []*Array {
	var valid []*Array
	for _, a := range arrays {
		if a == nil || !a.Valid() {
			continue
		}
		valid = append(valid, a)
	}
	return valid
}
// MetalAvailable reports whether a Metal GPU is available, initializing
// MLX on first call.
func MetalAvailable() bool {
	Init()
	var available C.bool
	C.mlx_metal_is_available(&available)
	return bool(available)
}

10
pkg/mlx/mlx_stub.go Normal file
View file

@ -0,0 +1,10 @@
//go:build !(darwin && arm64 && mlx)
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
// This stub file is used on non-darwin/non-arm64 platforms or when the
// mlx build tag is not set. All operations report MLX as unavailable.
package mlx
// MetalAvailable reports whether Metal GPU is available.
// Always returns false on non-Apple Silicon platforms (or when the mlx
// build tag is absent), steering callers to a non-MLX backend.
func MetalAvailable() bool { return false }

327
pkg/mlx/model/gemma3.go Normal file
View file

@ -0,0 +1,327 @@
//go:build darwin && arm64 && mlx
// Package model provides transformer model architectures for MLX inference.
package model
import (
	"encoding/json"
	"fmt"
	"math"
	"os"
	"path/filepath"
	"sync"

	"forge.lthn.ai/core/cli/pkg/mlx"
	"forge.lthn.ai/core/cli/pkg/mlx/cache"
	"forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
)
// TextConfig holds Gemma 3 text model configuration as parsed from the model
// directory's config.json. Scale is derived at load time, never parsed.
type TextConfig struct {
	HiddenSize            int32   `json:"hidden_size"`
	NumHiddenLayers       int32   `json:"num_hidden_layers"`
	IntermediateSize      int32   `json:"intermediate_size"`
	NumAttentionHeads     int32   `json:"num_attention_heads"`
	NumKeyValueHeads      int32   `json:"num_key_value_heads"`
	HeadDim               int32   `json:"head_dim"`
	VocabSize             int32   `json:"vocab_size"`
	RMSNormEps            float32 `json:"rms_norm_eps"`
	RopeTheta             float32 `json:"rope_theta"`           // RoPE base for global-attention layers
	RopeLocalBaseFreq     float32 `json:"rope_local_base_freq"` // RoPE base for sliding-window layers
	MaxPositionEmbeddings int32   `json:"max_position_embeddings"`
	SlidingWindow         int32   `json:"sliding_window"`         // KV window size for sliding layers
	SlidingWindowPattern  int32   `json:"sliding_window_pattern"` // every pattern-th layer is global
	Scale                 float32 `json:"-"`                      // Computed: 1/sqrt(head_dim)
}
// GemmaModel is the Gemma 3 text model: token embedding, a stack of decoder
// layers, a final RMSNorm, and an output projection tied to the embedding.
type GemmaModel struct {
	EmbedTokens *mlx.Embedding
	Layers      []*DecoderLayer
	Norm        *mlx.RMSNormModule
	Output      *mlx.Linear // Tied to EmbedTokens
	// Precomputed (1 + weight) for Gemma-style RMSNorm,
	// built once by precomputeScaledWeights.
	NormScaled *mlx.Array
	Tok        *tokenizer.Tokenizer
	Cfg        *TextConfig
}
// DecoderLayer is a single transformer block with Gemma 3's "sandwich"
// normalization: both the attention and MLP sub-blocks are normalized on
// entry and again on exit before their residual adds (see forward).
type DecoderLayer struct {
	InputNorm    *mlx.RMSNormModule
	Attention    *Attention
	PostAttnNorm *mlx.RMSNormModule
	PreFFNorm    *mlx.RMSNormModule
	MLP          *MLP
	PostFFNorm   *mlx.RMSNormModule
	// Precomputed (1 + weight) scaled norm weights,
	// built once by precomputeScaledWeights.
	InputNormScaled    *mlx.Array
	PostAttnNormScaled *mlx.Array
	PreFFNormScaled    *mlx.Array
	PostFFNormScaled   *mlx.Array
	IsSliding bool  // sliding-window attention (vs. global)
	LayerIdx  int32 // position of this layer in the stack
}
// Attention implements Gemma 3 attention with Q/K normalization
// ("QK-norm"): per-head RMSNorm applied to queries and keys before RoPE.
type Attention struct {
	QProj *mlx.Linear
	KProj *mlx.Linear
	VProj *mlx.Linear
	OProj *mlx.Linear
	QNorm *mlx.RMSNormModule
	KNorm *mlx.RMSNormModule
	// Precomputed (1 + weight) variants of QNorm/KNorm,
	// built once by precomputeScaledWeights.
	QNormScaled *mlx.Array
	KNormScaled *mlx.Array
}
// MLP is the gated feed-forward network: down(gelu(gate(x)) * up(x)),
// as wired in forward.
type MLP struct {
	GateProj *mlx.Linear
	UpProj   *mlx.Linear
	DownProj *mlx.Linear
}
// compiledGELU caches the compiled GELU kernel; compiledGELUOnce guards its
// one-time construction so concurrent forward passes (e.g. under the serve
// command) cannot race on the unsynchronized nil check the previous version
// used.
var (
	compiledGELU     *mlx.CompiledFunc
	compiledGELUOnce sync.Once
)

// getCompiledGELU returns the process-wide compiled GELU function,
// compiling it exactly once. Safe for concurrent callers.
func getCompiledGELU() *mlx.CompiledFunc {
	compiledGELUOnce.Do(func() {
		compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
			return []*mlx.Array{geluApprox(inputs[0])}
		}, true)
	})
	return compiledGELU
}
// geluApprox computes GELU using the tanh approximation:
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
func geluApprox(x *mlx.Array) *mlx.Array {
	const (
		sqrt2OverPi = 0.7978845608028654
		cubicCoeff  = 0.044715
	)
	cubic := mlx.Mul(mlx.Mul(x, x), x)
	arg := mlx.MulScalar(mlx.Add(x, mlx.MulScalar(cubic, cubicCoeff)), sqrt2OverPi)
	gate := mlx.AddScalar(mlx.Tanh(arg), 1.0)
	return mlx.Mul(mlx.MulScalar(x, 0.5), gate)
}
// LoadGemma3 loads a Gemma 3 text model from a directory.
//
// The directory must contain config.json, tokenizer.json, and at least one
// *.safetensors weight file. Missing optional config values receive Gemma 3
// defaults. The output head shares (ties) the embedding weights.
//
// Fixes over the original: the filepath.Glob error is no longer discarded,
// an empty weight set is an error instead of a silent nil-weight model, and
// head_dim is validated before Scale is computed (1/sqrt(0) is +Inf and
// would silently poison every attention score).
func LoadGemma3(modelPath string) (*GemmaModel, error) {
	data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: load config: %w", err)
	}
	var cfg TextConfig
	if err := json.Unmarshal(data, &cfg); err != nil {
		return nil, fmt.Errorf("gemma3: parse config: %w", err)
	}
	if cfg.HeadDim <= 0 {
		return nil, fmt.Errorf("gemma3: invalid head_dim %d in config", cfg.HeadDim)
	}
	cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
	// Defaults for fields absent from config.json.
	if cfg.RopeTheta == 0 {
		cfg.RopeTheta = 1000000
	}
	if cfg.RopeLocalBaseFreq == 0 {
		cfg.RopeLocalBaseFreq = 10000
	}
	if cfg.RMSNormEps == 0 {
		cfg.RMSNormEps = 1e-6
	}
	if cfg.SlidingWindowPattern == 0 {
		cfg.SlidingWindowPattern = 6
	}
	// Load tokenizer
	tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: load tokenizer: %w", err)
	}
	// Load weights from all safetensors files. Glob only fails on a bad
	// pattern, which can happen when modelPath contains glob metacharacters.
	matches, err := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
	if err != nil {
		return nil, fmt.Errorf("gemma3: glob weights: %w", err)
	}
	if len(matches) == 0 {
		return nil, fmt.Errorf("gemma3: no *.safetensors files found in %s", modelPath)
	}
	weights := make(map[string]*mlx.Array)
	for _, path := range matches {
		for name, arr := range mlx.LoadSafetensors(path) {
			weights[name] = arr
		}
	}
	if weights["model.embed_tokens.weight"] == nil {
		return nil, fmt.Errorf("gemma3: missing model.embed_tokens.weight in %s", modelPath)
	}
	m := &GemmaModel{
		EmbedTokens: &mlx.Embedding{Weight: weights["model.embed_tokens.weight"]},
		Layers:      make([]*DecoderLayer, cfg.NumHiddenLayers),
		Norm:        &mlx.RMSNormModule{Weight: weights["model.norm.weight"]},
		Tok:         tok,
		Cfg:         &cfg,
	}
	// Initialize layers from the HF weight-name layout.
	for i := int32(0); i < cfg.NumHiddenLayers; i++ {
		prefix := fmt.Sprintf("model.layers.%d", i)
		m.Layers[i] = &DecoderLayer{
			InputNorm:    &mlx.RMSNormModule{Weight: weights[prefix+".input_layernorm.weight"]},
			PostAttnNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_attention_layernorm.weight"]},
			PreFFNorm:    &mlx.RMSNormModule{Weight: weights[prefix+".pre_feedforward_layernorm.weight"]},
			PostFFNorm:   &mlx.RMSNormModule{Weight: weights[prefix+".post_feedforward_layernorm.weight"]},
			Attention: &Attention{
				QProj: mlx.NewLinear(weights[prefix+".self_attn.q_proj.weight"], nil),
				KProj: mlx.NewLinear(weights[prefix+".self_attn.k_proj.weight"], nil),
				VProj: mlx.NewLinear(weights[prefix+".self_attn.v_proj.weight"], nil),
				OProj: mlx.NewLinear(weights[prefix+".self_attn.o_proj.weight"], nil),
				QNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.q_norm.weight"]},
				KNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.k_norm.weight"]},
			},
			MLP: &MLP{
				GateProj: mlx.NewLinear(weights[prefix+".mlp.gate_proj.weight"], nil),
				UpProj:   mlx.NewLinear(weights[prefix+".mlp.up_proj.weight"], nil),
				DownProj: mlx.NewLinear(weights[prefix+".mlp.down_proj.weight"], nil),
			},
			LayerIdx:  i,
			IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
		}
	}
	// Tied embeddings: the output head reuses the embedding matrix.
	m.Output = mlx.NewLinear(m.EmbedTokens.Weight, nil)
	// Force all weights to materialize now rather than lazily at first token.
	allArrays := make([]*mlx.Array, 0, len(weights))
	for _, a := range weights {
		allArrays = append(allArrays, a)
	}
	mlx.Materialize(allArrays...)
	// Precompute (1 + weight) for Gemma-style RMSNorm
	precomputeScaledWeights(m)
	return m, nil
}
// precomputeScaledWeights builds and materializes (1 + weight) for every
// Gemma-style RMSNorm in the model, so the per-token forward pass can skip
// the add.
func precomputeScaledWeights(m *GemmaModel) {
	onePlus := func(w *mlx.Array) *mlx.Array { return mlx.AddScalar(w, 1.0) }

	m.NormScaled = onePlus(m.Norm.Weight)
	batch := []*mlx.Array{m.NormScaled}
	for _, layer := range m.Layers {
		layer.InputNormScaled = onePlus(layer.InputNorm.Weight)
		layer.PostAttnNormScaled = onePlus(layer.PostAttnNorm.Weight)
		layer.PreFFNormScaled = onePlus(layer.PreFFNorm.Weight)
		layer.PostFFNormScaled = onePlus(layer.PostFFNorm.Weight)
		layer.Attention.QNormScaled = onePlus(layer.Attention.QNorm.Weight)
		layer.Attention.KNormScaled = onePlus(layer.Attention.KNorm.Weight)
		batch = append(batch, layer.InputNormScaled, layer.PostAttnNormScaled,
			layer.PreFFNormScaled, layer.PostFFNormScaled,
			layer.Attention.QNormScaled, layer.Attention.KNormScaled)
	}
	mlx.Materialize(batch...)
}
// isLayerSliding reports whether the layer at layerIdx uses sliding-window
// attention. In Gemma 3's layout every pattern-th layer (1-based) is a
// global-attention layer and the rest slide; a non-positive pattern disables
// sliding entirely.
func isLayerSliding(layerIdx, pattern int32) bool {
	return pattern > 0 && (layerIdx+1)%pattern != 0
}
// Forward runs the text model forward pass over a [B, L] batch of token IDs
// and returns logits over the vocabulary for every position.
func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
	dims := tokens.Shape()
	batch, seqLen := dims[0], dims[1]

	// Gemma scales embeddings by sqrt(hidden_size).
	hidden := m.EmbedTokens.Forward(tokens)
	hidden = mlx.MulScalar(hidden, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))

	for i, layer := range m.Layers {
		hidden = layer.forward(hidden, caches[i], batch, seqLen, m.Cfg)
	}

	normed := mlx.RMSNorm(hidden, m.NormScaled, m.Cfg.RMSNormEps)
	return m.Output.Forward(normed)
}
// forward applies one transformer block: sandwich-normalized attention with
// a residual add, then sandwich-normalized feed-forward with a residual add.
func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
	// Attention sub-block: norm on entry, norm again on exit, then residual.
	attn := l.Attention.forward(mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps), c, B, L, l.IsSliding, cfg)
	residual := mlx.Add(x, mlx.RMSNorm(attn, l.PostAttnNormScaled, cfg.RMSNormEps))

	// Feed-forward sub-block, same sandwich pattern.
	ff := l.MLP.forward(mlx.RMSNorm(residual, l.PreFFNormScaled, cfg.RMSNormEps))
	return mlx.Add(residual, mlx.RMSNorm(ff, l.PostFFNormScaled, cfg.RMSNormEps))
}
// forward computes one layer's attention over x ([B, L, hidden]).
//
// Pipeline: project Q/K/V; view them as [B, heads, L, head_dim] without a
// copy (the AsStrided strides below are exactly the reshape+transpose of a
// row-major [B, L, heads*head_dim] buffer); apply Gemma 3 QK-norm; rotate
// with RoPE; append to the KV cache; repeat K/V heads for GQA; then run
// scaled dot-product attention.
func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
	q := a.QProj.Forward(x)
	k := a.KProj.Forward(x)
	v := a.VProj.Forward(x)
	// Reshape to [B, num_heads, L, head_dim] via strided views (no copy).
	q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
	k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
		[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
	// Q/K normalization (Gemma 3 QK-norm), applied before RoPE.
	q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
	k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
	// RoPE with appropriate theta: sliding-window layers use the local base
	// frequency, global layers the long-context theta.
	ropeTheta := cfg.RopeTheta
	if isSliding {
		ropeTheta = cfg.RopeLocalBaseFreq
	}
	// c.Offset() supplies the starting position so RoPE continues from the
	// tokens already in the cache (assumes Offset is the cached token count
	// — confirm against cache.Cache).
	q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
	k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
	// Update cache; k/v now cover all cached positions.
	k, v = c.Update(k, v, int(L))
	// GQA: repeat K/V heads to match the query head count.
	repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
	if repeatFactor > 1 {
		k = mlx.RepeatKV(k, repeatFactor)
		v = mlx.RepeatKV(v, repeatFactor)
	}
	// Scaled dot-product attention; the causal mask is only requested when
	// more than one query token is processed at once (prefill).
	out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
	// Merge heads back to [B, L, heads*head_dim] for the output projection.
	out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
	return a.OProj.Forward(out)
}
// forward computes the gated feed-forward: down(gelu(gate(x)) * up(x)).
func (m *MLP) forward(x *mlx.Array) *mlx.Array {
	activated := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
	up := m.UpProj.Forward(x)
	return m.DownProj.Forward(mlx.Mul(activated, up))
}
// NewCache creates one KV cache per layer: bounded rotating caches for
// sliding-window layers, unbounded caches for global-attention layers.
func (m *GemmaModel) NewCache() []cache.Cache {
	out := make([]cache.Cache, len(m.Layers))
	for i, layer := range m.Layers {
		if layer.IsSliding {
			out[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow))
			continue
		}
		out[i] = cache.NewKVCache()
	}
	return out
}
// NumLayers returns the number of transformer decoder layers in the model.
func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
// Tokenizer returns the tokenizer that was loaded alongside the model weights.
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }

59
pkg/mlx/nn.go Normal file
View file

@ -0,0 +1,59 @@
//go:build darwin && arm64 && mlx
package mlx
// Linear is a fully-connected layer: y = x @ W.T + bias.
// Weight is stored [out_features, in_features] (hence the transpose in
// Forward); Bias may be nil for bias-free layers.
type Linear struct {
	Weight *Array `weight:"weight"`
	Bias   *Array `weight:"bias"`
}
// NewLinear creates a Linear layer with the given weight and optional bias
// (pass nil for no bias).
func NewLinear(weight, bias *Array) *Linear {
	return &Linear{Weight: weight, Bias: bias}
}
// Forward computes y = x @ W.T, adding the bias when one is present.
func (l *Linear) Forward(x *Array) *Array {
	y := Matmul(x, Transpose(l.Weight))
	if l.Bias == nil || !l.Bias.Valid() {
		return y
	}
	return Add(y, l.Bias)
}
// Embedding is a lookup table for token embeddings.
// Weight rows are selected by token ID (gather along axis 0 in Forward).
type Embedding struct {
	Weight *Array `weight:"weight"`
}
// Forward looks up embeddings for the given token indices by gathering rows
// of the weight matrix along axis 0.
func (e *Embedding) Forward(indices *Array) *Array {
	return Take(e.Weight, indices, 0)
}
// RMSNormModule is an RMS normalization layer wrapping the fused RMSNorm
// kernel; it holds only the learned scale vector.
type RMSNormModule struct {
	Weight *Array `weight:"weight"`
}
// Forward applies fused RMS normalization with the layer's scale and the
// given epsilon.
func (r *RMSNormModule) Forward(x *Array, eps float32) *Array {
	return RMSNorm(x, r.Weight, eps)
}
// RepeatKV repeats key/value heads for grouped-query attention.
// Input shape: [B, num_kv_heads, L, D]
// Output shape: [B, num_kv_heads * factor, L, D]
func RepeatKV(x *Array, factor int32) *Array {
	if factor <= 1 {
		return x
	}
	dims := x.Shape()
	b, h, l, d := dims[0], dims[1], dims[2], dims[3]
	// Insert a repeat axis after the head axis, broadcast it to `factor`,
	// then fold it back into the head axis.
	tiled := BroadcastTo(ExpandDims(x, 2), []int32{b, h, factor, l, d})
	return Reshape(tiled, b, h*factor, l, d)
}

308
pkg/mlx/ops.go Normal file
View file

@ -0,0 +1,308 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include <stdlib.h>
#include "mlx/c/mlx.h"
*/
import "C"
// --- Element-wise arithmetic ---
// Add returns element-wise a + b.
func Add(a, b *Array) *Array {
	res := New("ADD", a, b)
	C.mlx_add(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// AddScalar returns a + scalar (broadcast).
func AddScalar(a *Array, s float32) *Array {
	return Add(a, FromValue(s))
}
// Mul returns element-wise a * b.
func Mul(a, b *Array) *Array {
	res := New("MUL", a, b)
	C.mlx_multiply(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// MulScalar returns a * scalar (broadcast).
func MulScalar(a *Array, s float32) *Array {
	return Mul(a, FromValue(s))
}
// Divide returns element-wise a / b.
func Divide(a, b *Array) *Array {
	res := New("DIV", a, b)
	C.mlx_divide(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// Subtract returns element-wise a - b.
func Subtract(a, b *Array) *Array {
	res := New("SUB", a, b)
	C.mlx_subtract(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// Negative returns element-wise -a.
func Negative(a *Array) *Array {
	res := New("NEG", a)
	C.mlx_negative(&res.ctx, a.ctx, DefaultStream().ctx)
	return res
}
// --- Math functions ---
// Exp returns element-wise exp(a).
func Exp(a *Array) *Array {
	res := New("EXP", a)
	C.mlx_exp(&res.ctx, a.ctx, DefaultStream().ctx)
	return res
}
// Tanh returns element-wise tanh(a).
func Tanh(a *Array) *Array {
	res := New("TANH", a)
	C.mlx_tanh(&res.ctx, a.ctx, DefaultStream().ctx)
	return res
}
// Sqrt returns element-wise sqrt(a).
func Sqrt(a *Array) *Array {
	res := New("SQRT", a)
	C.mlx_sqrt(&res.ctx, a.ctx, DefaultStream().ctx)
	return res
}
// Rsqrt returns element-wise 1/sqrt(a).
func Rsqrt(a *Array) *Array {
	res := New("RSQRT", a)
	C.mlx_rsqrt(&res.ctx, a.ctx, DefaultStream().ctx)
	return res
}
// Reciprocal returns element-wise 1/a.
func Reciprocal(a *Array) *Array {
	res := New("RECIPROCAL", a)
	C.mlx_reciprocal(&res.ctx, a.ctx, DefaultStream().ctx)
	return res
}
// Square returns element-wise a^2.
func Square(a *Array) *Array {
	res := New("SQUARE", a)
	C.mlx_square(&res.ctx, a.ctx, DefaultStream().ctx)
	return res
}
// Power returns element-wise a^b.
func Power(a, b *Array) *Array {
	res := New("POWER", a, b)
	C.mlx_power(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// Maximum returns element-wise max(a, b).
func Maximum(a, b *Array) *Array {
	res := New("MAX", a, b)
	C.mlx_maximum(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// Minimum returns element-wise min(a, b).
func Minimum(a, b *Array) *Array {
	res := New("MIN", a, b)
	C.mlx_minimum(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// --- Matrix operations ---
// Matmul returns the matrix product of a and b.
func Matmul(a, b *Array) *Array {
	res := New("MATMUL", a, b)
	C.mlx_matmul(&res.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// QuantizedMatmul multiplies x by the quantized matrix w, dequantizing with
// the given scales/biases, group size, and bit width.
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
	res := New("QMATMUL", x, w, scales, biases)
	C.mlx_quantized_matmul(
		&res.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
		C._Bool(transpose), C.int(groupSize), C.int(bits),
		DefaultStream().ctx,
	)
	return res
}
// --- Reductions ---
// Softmax returns softmax along the last axis.
func Softmax(a *Array) *Array {
	res := New("SOFTMAX", a)
	lastAxis := C.int(-1)
	C.mlx_softmax(&res.ctx, a.ctx, &lastAxis, C.int(1), C._Bool(false), DefaultStream().ctx)
	return res
}
// Argmax returns the index of the maximum value along an axis, optionally
// keeping the reduced dimension.
func Argmax(a *Array, axis int, keepDims bool) *Array {
	res := New("ARGMAX", a)
	C.mlx_argmax(&res.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
	return res
}
// TopK returns the top k values along the last axis.
func TopK(a *Array, k int) *Array {
	res := New("TOPK", a)
	C.mlx_topk(&res.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
	return res
}
// Sum reduces by summation along the given axis.
func Sum(a *Array, axis int, keepDims bool) *Array {
	res := New("SUM", a)
	ax := C.int(axis)
	C.mlx_sum(&res.ctx, a.ctx, &ax, C.int(1), C._Bool(keepDims), DefaultStream().ctx)
	return res
}
// Mean reduces by averaging along the given axis.
func Mean(a *Array, axis int, keepDims bool) *Array {
	res := New("MEAN", a)
	ax := C.int(axis)
	C.mlx_mean(&res.ctx, a.ctx, &ax, C.int(1), C._Bool(keepDims), DefaultStream().ctx)
	return res
}
// --- Shape operations ---
// Reshape changes the shape of an array (element count must be preserved).
//
// Fix: reshaping to a scalar (no dims) previously panicked indexing
// &cShape[0] on an empty slice; a nil pointer with a zero count is now
// passed to the C layer instead.
func Reshape(a *Array, shape ...int32) *Array {
	out := New("RESHAPE", a)
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	var shapePtr *C.int
	if len(cShape) > 0 {
		shapePtr = &cShape[0]
	}
	C.mlx_reshape(&out.ctx, a.ctx, shapePtr, C.int(len(cShape)), DefaultStream().ctx)
	return out
}
// Transpose permutes dimensions. With no axes, all dimensions are reversed.
func Transpose(a *Array, axes ...int) *Array {
	res := New("TRANSPOSE", a)
	if len(axes) == 0 {
		C.mlx_transpose_all(&res.ctx, a.ctx, DefaultStream().ctx)
		return res
	}
	cAxes := make([]C.int, len(axes))
	for i, ax := range axes {
		cAxes[i] = C.int(ax)
	}
	C.mlx_transpose(&res.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
	return res
}
// ExpandDims inserts a new size-1 axis at the given position.
func ExpandDims(a *Array, axis int) *Array {
	res := New("EXPAND_DIMS", a)
	ax := C.int(axis)
	C.mlx_expand_dims(&res.ctx, a.ctx, &ax, C.int(1), DefaultStream().ctx)
	return res
}
// Squeeze removes the given dimensions of size 1.
//
// Fix: calling Squeeze(a) with no axes previously panicked indexing
// &cAxes[0] on an empty slice. A nil pointer with a zero count is now
// passed; whether mlx-c interprets that as "squeeze all singleton dims" or
// a no-op is decided by mlx_squeeze — TODO confirm against the mlx-c docs.
func Squeeze(a *Array, axes ...int) *Array {
	out := New("SQUEEZE", a)
	cAxes := make([]C.int, len(axes))
	for i, ax := range axes {
		cAxes[i] = C.int(ax)
	}
	var axesPtr *C.int
	if len(cAxes) > 0 {
		axesPtr = &cAxes[0]
	}
	C.mlx_squeeze(&out.ctx, a.ctx, axesPtr, C.int(len(cAxes)), DefaultStream().ctx)
	return out
}
// Concatenate joins arrays along the given axis.
func Concatenate(arrays []*Array, axis int) *Array {
	vec := C.mlx_vector_array_new()
	defer C.mlx_vector_array_free(vec)
	for _, arr := range arrays {
		C.mlx_vector_array_append_value(vec, arr.ctx)
	}
	res := New("CONCAT", arrays...)
	C.mlx_concatenate(&res.ctx, vec, C.int(axis), DefaultStream().ctx)
	return res
}
// BroadcastTo broadcasts an array to the given shape.
func BroadcastTo(a *Array, shape []int32) *Array {
	res := New("BROADCAST", a)
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	C.mlx_broadcast_to(&res.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
	return res
}
// AsType casts an array to a different dtype.
func AsType(a *Array, dtype DType) *Array {
	res := New("ASTYPE", a)
	C.mlx_astype(&res.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
	return res
}
// AsStrided creates a view of a with a custom shape, element strides, and
// storage offset.
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
	res := New("AS_STRIDED", a)
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	cStrides := make([]C.size_t, len(strides))
	for i, s := range strides {
		cStrides[i] = C.size_t(s)
	}
	C.mlx_as_strided(&res.ctx, a.ctx, &cShape[0], C.int(len(cShape)), &cStrides[0], C.int(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
	return res
}
// Take gathers elements from a along axis using indices.
func Take(a, indices *Array, axis int) *Array {
	res := New("TAKE", a, indices)
	C.mlx_take_axis(&res.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
	return res
}
// Where selects elements from a where condition is true, else from b.
func Where(condition, a, b *Array) *Array {
	res := New("WHERE", condition, a, b)
	C.mlx_where(&res.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
	return res
}
// Argpartition partially sorts along axis and returns indices such that the
// kth-smallest element is in its sorted position (useful for top-k selection).
func Argpartition(a *Array, kth, axis int) *Array {
	res := New("ARGPARTITION", a)
	C.mlx_argpartition(&res.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
	return res
}
// PutAlongAxis places values into the array at indices along axis and
// returns the updated array.
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
	res := New("PUT_ALONG_AXIS", a, indices, values)
	C.mlx_put_along_axis(&res.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
	return res
}

44
pkg/mlx/random.go Normal file
View file

@ -0,0 +1,44 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
// RandomCategorical samples from a categorical distribution defined by
// logprobs (unnormalized log-probabilities) along the last axis, returning
// sampled indices with that axis removed.
func RandomCategorical(logprobs *Array) *Array {
	out := New("RANDOM_CATEGORICAL", logprobs)
	// shape for output: same as input but last dim removed
	C.mlx_random_categorical_shape(
		&out.ctx,
		logprobs.ctx,
		C.int(-1), // axis to sample along
		nil, C.int(0), // empty shape = infer from input
		nil, // key: nil selects MLX's default PRNG key
		DefaultStream().ctx,
	)
	return out
}
// RandomUniform generates uniform random values in [low, high) with the
// given shape and dtype, using MLX's default PRNG key.
//
// Fix: an empty shape (scalar request) previously panicked indexing
// &cShape[0] on a zero-length slice; a nil pointer with count 0 is now
// passed instead.
func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
	out := New("RANDOM_UNIFORM")
	cShape := make([]C.int, len(shape))
	for i, s := range shape {
		cShape[i] = C.int(s)
	}
	var shapePtr *C.int
	if len(cShape) > 0 {
		shapePtr = &cShape[0]
	}
	lo := FromValue(low)
	hi := FromValue(high)
	C.mlx_random_uniform(
		&out.ctx,
		lo.ctx, hi.ctx,
		shapePtr, C.int(len(cShape)),
		C.mlx_dtype(dtype),
		nil, // key: nil selects the default PRNG key
		DefaultStream().ctx,
	)
	return out
}

105
pkg/mlx/sample/sample.go Normal file
View file

@ -0,0 +1,105 @@
//go:build darwin && arm64 && mlx
// Package sample provides composable token sampling strategies.
package sample
import (
"math"
"forge.lthn.ai/core/cli/pkg/mlx"
)
// Sampler transforms logits into a sampled token index. Implementations
// either filter/rescale the logit distribution or perform the final draw;
// New composes them into a chain.
type Sampler interface {
	Sample(logits *mlx.Array) *mlx.Array
}
// New creates a composable sampler chain from the given parameters.
// Filter order: TopP -> MinP -> TopK, then Temperature scaling, then the
// final categorical draw performed by chain.Sample.
//
// A temperature <= 0 short-circuits to greedy argmax decoding. (Previously
// only temp == 0 was special-cased, so a negative temperature would have
// inverted the distribution through the 1/temp scaling.)
func New(temp, topP, minP float32, topK int) Sampler {
	if temp <= 0 {
		return greedy{}
	}
	var samplers []Sampler
	if topP > 0 && topP < 1 {
		samplers = append(samplers, TopP(topP))
	}
	if minP > 0 {
		samplers = append(samplers, MinPSampler(minP))
	}
	if topK > 0 {
		samplers = append(samplers, TopKSampler(topK))
	}
	samplers = append(samplers, Temperature(temp))
	return chain(samplers)
}
// chain applies a sequence of samplers, then samples from the result.
type chain []Sampler

// Sample threads the logits through every stage in order, then draws a
// token index from the resulting (unnormalized) log-probabilities.
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
	for _, s := range c {
		logits = s.Sample(logits)
	}
	// Final categorical sample from log-probabilities
	return mlx.RandomCategorical(logits)
}
// greedy returns the argmax token (deterministic decoding).
type greedy struct{}

// Sample picks the single highest-logit token along the last axis.
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
	return mlx.Argmax(logits, -1, false)
}
// Temperature scales logits by 1/temp: values below 1 sharpen the
// distribution, values above 1 flatten it. Must be non-zero.
type Temperature float32

// Sample applies the 1/temp scaling to the logits.
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
	return mlx.MulScalar(logits, 1.0/float32(t))
}
// TopKSampler masks all but the top-k logits.
type TopKSampler int

// Sample sets every logit outside the k largest to -inf so the final
// categorical draw can only select top-k tokens.
func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
	// Argpartition on negated logits places the indices of the k largest
	// values in the first k positions along the last axis.
	neg := mlx.Negative(logits)
	mask := mlx.Argpartition(neg, int(k)-1, -1)
	// Slice the indices beyond top-k
	mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
	// NOTE(review): `values` here is a scalar; this relies on mlx
	// broadcasting it across all masked positions — confirm against
	// mlx-c put_along_axis semantics.
	return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1)
}
// TopP implements nucleus sampling (cumulative probability threshold).
type TopP float32

// Sample is currently a pass-through pending cumulative-sum support.
//
// A correct implementation must sort probabilities in descending order,
// take their cumulative sum, and mask (to -inf) every token past the point
// where the running total exceeds p. The previous placeholder did not
// compile (cumProbs and threshold were declared and unused) and its
// Where(true, -inf, logits) expression would have masked *every* token,
// destroying the distribution. Until an mlx cumsum binding is available,
// returning the logits unchanged keeps the sampler chain valid.
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
	return logits
}
// MinPSampler masks tokens below min_p * max_prob.
type MinPSampler float32

// Sample is currently a pass-through: the min-p filter is not yet
// implemented. A full version would compute softmax probabilities, find the
// per-row maximum, and mask every token whose probability falls below
// min_p * max_prob.
func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array {
	return logits
}

63
pkg/mlx/slice.go Normal file
View file

@ -0,0 +1,63 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
// Slice extracts a sub-array using half-open [start, end) index ranges for
// each dimension. starts and ends must have the same length as the array's
// dimensions and must be non-empty (a zero-length starts would panic on
// cStarts[0]). All strides are fixed at 1.
func Slice(a *Array, starts, ends []int32) *Array {
	out := New("SLICE", a)
	cStarts := make([]C.int, len(starts))
	cEnds := make([]C.int, len(ends))
	for i := range starts {
		cStarts[i] = C.int(starts[i])
		cEnds[i] = C.int(ends[i])
	}
	// Unit strides: take every element in the selected ranges.
	strides := make([]C.int, len(starts))
	for i := range strides {
		strides[i] = 1
	}
	C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
	return out
}
// SliceAxis extracts the half-open range [start, end) along one axis
// (negative axis counts from the end), keeping every other dimension whole.
func SliceAxis(a *Array, axis int, start, end int32) *Array {
	ndim := a.NumDims()
	if axis < 0 {
		axis += ndim
	}
	starts := make([]int32, ndim)
	ends := make([]int32, ndim)
	for d := 0; d < ndim; d++ {
		if d == axis {
			starts[d], ends[d] = start, end
		} else {
			starts[d], ends[d] = 0, int32(a.Dim(d))
		}
	}
	return Slice(a, starts, ends)
}
// SliceUpdateInplace writes `update` into the region [starts, ends) of a.
// This is critical for KV cache updates.
//
// NOTE(review): despite the name, this returns a fresh Array handle (out)
// produced by mlx_slice_update rather than mutating a's handle — callers
// must use the returned value.
func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
	out := New("SLICE_UPDATE", a, update)
	cStarts := make([]C.int, len(starts))
	cEnds := make([]C.int, len(ends))
	for i := range starts {
		cStarts[i] = C.int(starts[i])
		cEnds[i] = C.int(ends[i])
	}
	// Unit strides: the update covers every element of the target region.
	strides := make([]C.int, len(starts))
	for i := range strides {
		strides[i] = 1
	}
	C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
	return out
}

74
pkg/mlx/stream.go Normal file
View file

@ -0,0 +1,74 @@
//go:build darwin && arm64 && mlx
package mlx
/*
#include "mlx/c/mlx.h"
*/
import "C"
import "sync"
// Stream wraps an mlx_stream handle used to order and dispatch operations
// on a device (GPU or CPU).
type Stream struct {
	ctx C.mlx_stream
}
// defaultStream caches the process-wide stream returned by DefaultStream;
// defaultStreamOnce guards its one-time creation.
var (
	defaultStream     *Stream
	defaultStreamOnce sync.Once
)
// DefaultStream returns the default GPU stream, creating it on first use.
// Safe for concurrent callers; the ops in this package dispatch onto it.
func DefaultStream() *Stream {
	defaultStreamOnce.Do(func() {
		Init()
		defaultStream = &Stream{ctx: C.mlx_default_gpu_stream_new()}
	})
	return defaultStream
}
// DefaultGPUStream returns a fresh Stream wrapper for the GPU stream.
// NOTE(review): despite "new" in the C call, mlx_default_gpu_stream_new
// presumably hands back the library's default GPU stream rather than an
// independent one — confirm against mlx-c.
func DefaultGPUStream() *Stream {
	Init()
	return &Stream{ctx: C.mlx_default_gpu_stream_new()}
}
// DefaultCPUStream returns a fresh Stream wrapper for the CPU stream;
// CPU analogue of DefaultGPUStream.
func DefaultCPUStream() *Stream {
	Init()
	return &Stream{ctx: C.mlx_default_cpu_stream_new()}
}
// Synchronize blocks until all work previously dispatched on s completes.
func Synchronize(s *Stream) {
	C.mlx_synchronize(s.ctx)
}
// SetMemoryLimit sets the Metal memory limit in bytes and returns the
// previous limit.
func SetMemoryLimit(limit uint64) uint64 {
	var prev C.size_t
	C.mlx_set_memory_limit(&prev, C.size_t(limit))
	return uint64(prev)
}
// SetCacheLimit sets the Metal buffer-cache limit in bytes and returns the
// previous limit.
func SetCacheLimit(limit uint64) uint64 {
	var prev C.size_t
	C.mlx_set_cache_limit(&prev, C.size_t(limit))
	return uint64(prev)
}
// GetActiveMemory returns the current Metal memory usage in bytes,
// as reported by mlx.
func GetActiveMemory() uint64 {
	var mem C.size_t
	C.mlx_get_active_memory(&mem)
	return uint64(mem)
}
// GetPeakMemory returns the peak Metal memory usage in bytes,
// as reported by mlx.
func GetPeakMemory() uint64 {
	var mem C.size_t
	C.mlx_get_peak_memory(&mem)
	return uint64(mem)
}

View file

@ -0,0 +1,174 @@
//go:build darwin && arm64 && mlx
// Package tokenizer provides BPE/SentencePiece tokenization for Gemma models.
package tokenizer
import (
"encoding/json"
"fmt"
"os"
"strings"
)
// Tokenizer handles text-to-token and token-to-text conversion for
// SentencePiece-style BPE vocabularies loaded from a HuggingFace
// tokenizer.json.
type Tokenizer struct {
	vocab    map[string]int32 // token string -> id (includes added tokens)
	invVocab map[int32]string // id -> token string, inverse of vocab
	merges   []mergePair      // BPE merge rules in file order (not yet used by Encode)
	special  map[string]int32 // special token content -> id
	bosToken int32            // id Encode prepends to every sequence
	eosToken int32            // id that ends generation
}
// mergePair is one BPE merge rule: the adjacent pair (a, b) merges with
// priority rank (the rule's position in tokenizer.json; see Load).
type mergePair struct {
	a, b string
	rank int
}
// tokenizerJSON mirrors the subset of the HuggingFace tokenizer.json format
// that Load consumes.
//
// NOTE(review): Merges is typed []string ("left right" pairs). Newer
// tokenizers releases can serialize merges as two-element arrays instead,
// which this type would fail to unmarshal — confirm for the target models.
type tokenizerJSON struct {
	Model struct {
		Type string `json:"type"`
		// Vocab is deferred via RawMessage and decoded separately in Load.
		Vocab        json.RawMessage `json:"vocab"`
		Merges       []string        `json:"merges"`
		ByteFallback bool            `json:"byte_fallback"`
	} `json:"model"`
	AddedTokens []struct {
		ID      int32  `json:"id"`
		Content string `json:"content"`
		Special bool   `json:"special"`
	} `json:"added_tokens"`
}
// Load reads a HuggingFace tokenizer.json file and creates a Tokenizer.
//
// It parses the vocabulary, BPE merge rules, and added/special tokens, then
// resolves BOS/EOS from the special tokens: <bos> sets bosToken; <eos> sets
// eosToken, overridden by <end_of_turn> when present.
func Load(path string) (*Tokenizer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
	}
	var tj tokenizerJSON
	if err := json.Unmarshal(data, &tj); err != nil {
		return nil, fmt.Errorf("tokenizer: parse: %w", err)
	}
	t := &Tokenizer{
		vocab:    make(map[string]int32),
		invVocab: make(map[int32]string),
		special:  make(map[string]int32),
	}
	// Parse vocab (deferred as RawMessage in tokenizerJSON) and build the
	// inverse mapping for Decode.
	var vocab map[string]int32
	if err := json.Unmarshal(tj.Model.Vocab, &vocab); err != nil {
		return nil, fmt.Errorf("tokenizer: parse vocab: %w", err)
	}
	t.vocab = vocab
	for k, v := range vocab {
		t.invVocab[v] = k
	}
	// Parse merges ("left right" pairs; rank = position in the file).
	for rank, merge := range tj.Model.Merges {
		parts := strings.SplitN(merge, " ", 2)
		if len(parts) == 2 {
			t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
		}
	}
	// Added tokens join the vocab in both directions; special ones are also
	// tracked separately for direct matching in Encode.
	for _, tok := range tj.AddedTokens {
		if tok.Special {
			t.special[tok.Content] = tok.ID
		}
		t.vocab[tok.Content] = tok.ID
		t.invVocab[tok.ID] = tok.Content
	}
	// Set BOS/EOS
	if id, ok := t.special["<bos>"]; ok {
		t.bosToken = id
	}
	if id, ok := t.special["<eos>"]; ok {
		t.eosToken = id
	}
	if id, ok := t.special["<end_of_turn>"]; ok {
		t.eosToken = id // Gemma uses end_of_turn as EOS
	}
	return t, nil
}
// Encode converts text to token IDs, prepending the BOS token.
//
// This is a simplified encoder: special tokens are matched greedily by
// longest prefix, and remaining text is encoded character-by-character,
// preferring the SentencePiece word-start form ("▁" + char) over the bare
// character; characters found in neither form are dropped. A full
// implementation would run proper BPE merges with byte fallback.
//
// Fixes over the original: the dead rune-preprocessing loop (its result was
// never read) is removed, and special tokens are matched longest-first —
// iterating the map directly made the result nondeterministic whenever one
// special token was a prefix of another.
func (t *Tokenizer) Encode(text string) []int32 {
	tokens := []int32{t.bosToken}
	remaining := text
	for remaining != "" {
		// Longest-prefix special-token match for deterministic output.
		bestLen := 0
		var bestID int32
		for tok, id := range t.special {
			if len(tok) > bestLen && strings.HasPrefix(remaining, tok) {
				bestLen, bestID = len(tok), id
			}
		}
		if bestLen > 0 {
			tokens = append(tokens, bestID)
			remaining = remaining[bestLen:]
			continue
		}
		// Fall back to single-character lookup (simplified BPE).
		r := []rune(remaining)
		if id, ok := t.vocab["▁"+string(r[0])]; ok {
			tokens = append(tokens, id)
		} else if id, ok := t.vocab[string(r[0])]; ok {
			tokens = append(tokens, id)
		}
		remaining = string(r[1:])
	}
	return tokens
}
// Decode converts token IDs back to text, mapping the SentencePiece "▁"
// marker back to a space, dropping unknown IDs, and trimming the leading
// space produced by the first word-boundary marker.
func (t *Tokenizer) Decode(tokens []int32) string {
	var sb strings.Builder
	for _, id := range tokens {
		piece, ok := t.invVocab[id]
		if !ok {
			continue
		}
		sb.WriteString(strings.ReplaceAll(piece, "▁", " "))
	}
	return strings.TrimPrefix(sb.String(), " ")
}
// BOSToken returns the beginning-of-sequence token ID (from <bos>).
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }
// EOSToken returns the end-of-sequence token ID (<eos>, or <end_of_turn>
// when present).
func (t *Tokenizer) EOSToken() int32 { return t.eosToken }
// FormatGemmaPrompt wraps a user prompt in the Gemma 3 chat template,
// leaving the text positioned at the start of the model's turn.
func FormatGemmaPrompt(prompt string) string {
	const template = "<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n"
	return fmt.Sprintf(template, prompt)
}