feat: add native MLX backend for Apple Silicon inference (pkg/mlx)
CGo wrapper for mlx-c providing zero-Python Metal GPU inference. Includes Gemma 3 model architecture, BPE tokenizer, KV cache, composable sampling, and OpenAI-compatible serve command. Build-tagged (darwin && arm64 && mlx) with stubs for cross-platform. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ca8c155d85
commit
9d664c055a
20 changed files with 2398 additions and 0 deletions
|
|
@ -10,6 +10,7 @@
|
|||
// - core ml convert: Convert MLX LoRA adapter to PEFT format
|
||||
// - core ml agent: Run the scoring agent daemon
|
||||
// - core ml worker: Run a distributed worker node
|
||||
// - core ml serve: Start OpenAI-compatible inference server
|
||||
package ml
|
||||
|
||||
import (
|
||||
|
|
@ -38,6 +39,7 @@ func AddMLCommands(root *cli.Command) {
|
|||
mlCmd.AddCommand(convertCmd)
|
||||
mlCmd.AddCommand(agentCmd)
|
||||
mlCmd.AddCommand(workerCmd)
|
||||
mlCmd.AddCommand(serveCmd)
|
||||
root.AddCommand(mlCmd)
|
||||
}
|
||||
|
||||
|
|
|
|||
174
internal/cmd/ml/cmd_serve.go
Normal file
174
internal/cmd/ml/cmd_serve.go
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/cli"
|
||||
"forge.lthn.ai/core/cli/pkg/ml"
|
||||
)
|
||||
|
||||
var serveCmd = &cli.Command{
|
||||
Use: "serve",
|
||||
Short: "Start OpenAI-compatible inference server",
|
||||
Long: "Starts an HTTP server serving /v1/completions and /v1/chat/completions using the configured ML backend.",
|
||||
RunE: runServe,
|
||||
}
|
||||
|
||||
var (
|
||||
serveBind string
|
||||
serveModelPath string
|
||||
)
|
||||
|
||||
func init() {
|
||||
serveCmd.Flags().StringVar(&serveBind, "bind", "0.0.0.0:8090", "Address to bind")
|
||||
serveCmd.Flags().StringVar(&serveModelPath, "model-path", "", "Path to model directory (for mlx backend)")
|
||||
}
|
||||
|
||||
type completionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
type completionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []completionChoice `json:"choices"`
|
||||
Usage usageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
type completionChoice struct {
|
||||
Text string `json:"text"`
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type chatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ml.Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
type chatResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []chatChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type chatChoice struct {
|
||||
Message ml.Message `json:"message"`
|
||||
Index int `json:"index"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type usageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
func runServe(cmd *cli.Command, args []string) error {
|
||||
// Create a backend — use HTTP backend pointing to configured API URL.
|
||||
// On macOS with MLX build tag, this will use the native MLX backend instead.
|
||||
backend := ml.NewHTTPBackend(apiURL, modelName)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req completionRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), 400)
|
||||
return
|
||||
}
|
||||
|
||||
opts := ml.GenOpts{
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Model: req.Model,
|
||||
}
|
||||
|
||||
text, err := backend.Generate(r.Context(), req.Prompt, opts)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
resp := completionResponse{
|
||||
ID: fmt.Sprintf("cmpl-%d", time.Now().UnixNano()),
|
||||
Object: "text_completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: backend.Name(),
|
||||
Choices: []completionChoice{{Text: text, FinishReason: "stop"}},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
var req chatRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
http.Error(w, err.Error(), 400)
|
||||
return
|
||||
}
|
||||
|
||||
opts := ml.GenOpts{
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
Model: req.Model,
|
||||
}
|
||||
|
||||
text, err := backend.Chat(r.Context(), req.Messages, opts)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
resp := chatResponse{
|
||||
ID: fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()),
|
||||
Object: "chat.completion",
|
||||
Created: time.Now().Unix(),
|
||||
Model: backend.Name(),
|
||||
Choices: []chatChoice{{
|
||||
Message: ml.Message{Role: "assistant", Content: text},
|
||||
FinishReason: "stop",
|
||||
}},
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
mux.HandleFunc("GET /v1/models", func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := struct {
|
||||
Object string `json:"object"`
|
||||
Data []struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}{
|
||||
Object: "list",
|
||||
Data: []struct {
|
||||
ID string `json:"id"`
|
||||
}{{ID: backend.Name()}},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
slog.Info("ml serve: starting", "bind", serveBind, "backend", backend.Name())
|
||||
fmt.Printf("Serving on http://%s\n", serveBind)
|
||||
return http.ListenAndServe(serveBind, mux)
|
||||
}
|
||||
169
pkg/ml/backend_mlx.go
Normal file
169
pkg/ml/backend_mlx.go
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package ml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/mlx"
|
||||
"forge.lthn.ai/core/cli/pkg/mlx/cache"
|
||||
"forge.lthn.ai/core/cli/pkg/mlx/model"
|
||||
"forge.lthn.ai/core/cli/pkg/mlx/sample"
|
||||
"forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// MLXBackend implements Backend for native Metal inference via mlx-c.
|
||||
type MLXBackend struct {
|
||||
model *model.GemmaModel
|
||||
tok *tokenizer.Tokenizer
|
||||
caches []cache.Cache
|
||||
sampler sample.Sampler
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewMLXBackend loads a model from a safetensors directory and creates
|
||||
// a native Metal inference backend.
|
||||
func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
||||
if !mlx.MetalAvailable() {
|
||||
return nil, fmt.Errorf("mlx: Metal GPU not available")
|
||||
}
|
||||
|
||||
slog.Info("mlx: loading model", "path", modelPath)
|
||||
m, err := model.LoadGemma3(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mlx: load model: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("mlx: model loaded",
|
||||
"layers", m.NumLayers(),
|
||||
"memory_mb", mlx.GetActiveMemory()/1024/1024,
|
||||
)
|
||||
|
||||
return &MLXBackend{
|
||||
model: m,
|
||||
tok: m.Tokenizer(),
|
||||
caches: m.NewCache(),
|
||||
sampler: sample.New(0.1, 0, 0, 0), // default low temp
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate produces text from a prompt using native Metal inference.
|
||||
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Reset caches for new generation
|
||||
for _, c := range b.caches {
|
||||
c.Reset()
|
||||
}
|
||||
|
||||
// Set up sampler based on opts
|
||||
temp := float32(opts.Temperature)
|
||||
if temp == 0 {
|
||||
temp = 0.1
|
||||
}
|
||||
sampler := sample.New(temp, 0, 0, 0)
|
||||
|
||||
// Tokenize
|
||||
formatted := tokenizer.FormatGemmaPrompt(prompt)
|
||||
tokens := b.tok.Encode(formatted)
|
||||
input := mlx.FromValues(tokens, 1, len(tokens))
|
||||
|
||||
maxTokens := opts.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 2048
|
||||
}
|
||||
|
||||
// Generation loop
|
||||
var output []int32
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return b.tok.Decode(output), ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logits := b.model.Forward(input, b.caches)
|
||||
next := sampler.Sample(logits)
|
||||
mlx.Materialize(next)
|
||||
|
||||
nextToken := int32(next.Int())
|
||||
if nextToken == b.tok.EOSToken() {
|
||||
break
|
||||
}
|
||||
output = append(output, nextToken)
|
||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||
}
|
||||
|
||||
return b.tok.Decode(output), nil
|
||||
}
|
||||
|
||||
// Chat formats messages and generates a response.
|
||||
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
|
||||
// Format as Gemma chat
|
||||
var prompt string
|
||||
for _, msg := range messages {
|
||||
switch msg.Role {
|
||||
case "user":
|
||||
prompt += fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n", msg.Content)
|
||||
case "assistant":
|
||||
prompt += fmt.Sprintf("<start_of_turn>model\n%s<end_of_turn>\n", msg.Content)
|
||||
case "system":
|
||||
prompt += fmt.Sprintf("<start_of_turn>user\n[System: %s]<end_of_turn>\n", msg.Content)
|
||||
}
|
||||
}
|
||||
prompt += "<start_of_turn>model\n"
|
||||
|
||||
// Use raw prompt (already formatted)
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
for _, c := range b.caches {
|
||||
c.Reset()
|
||||
}
|
||||
|
||||
temp := float32(opts.Temperature)
|
||||
if temp == 0 {
|
||||
temp = 0.1
|
||||
}
|
||||
sampler := sample.New(temp, 0, 0, 0)
|
||||
|
||||
tokens := b.tok.Encode(prompt)
|
||||
input := mlx.FromValues(tokens, 1, len(tokens))
|
||||
|
||||
maxTokens := opts.MaxTokens
|
||||
if maxTokens == 0 {
|
||||
maxTokens = 2048
|
||||
}
|
||||
|
||||
var output []int32
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return b.tok.Decode(output), ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logits := b.model.Forward(input, b.caches)
|
||||
next := sampler.Sample(logits)
|
||||
mlx.Materialize(next)
|
||||
|
||||
nextToken := int32(next.Int())
|
||||
if nextToken == b.tok.EOSToken() {
|
||||
break
|
||||
}
|
||||
output = append(output, nextToken)
|
||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||
}
|
||||
|
||||
return b.tok.Decode(output), nil
|
||||
}
|
||||
|
||||
// Name returns the backend identifier.
|
||||
func (b *MLXBackend) Name() string { return "mlx" }
|
||||
|
||||
// Available reports whether Metal GPU is ready.
|
||||
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }
|
||||
26
pkg/mlx/CMakeLists.txt
Normal file
26
pkg/mlx/CMakeLists.txt
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
cmake_minimum_required(VERSION 3.5)
|
||||
|
||||
project(mlx)
|
||||
|
||||
if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
|
||||
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
|
||||
endif()
|
||||
|
||||
set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
|
||||
set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
|
||||
set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
|
||||
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||
|
||||
set(CMAKE_INSTALL_RPATH "@loader_path")
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
|
||||
|
||||
FetchContent_Declare(
|
||||
mlx-c
|
||||
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
|
||||
GIT_TAG ${MLX_C_GIT_TAG}
|
||||
)
|
||||
|
||||
FetchContent_MakeAvailable(mlx-c)
|
||||
273
pkg/mlx/array.go
Normal file
273
pkg/mlx/array.go
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type tensorDesc struct {
|
||||
name string
|
||||
inputs []*Array
|
||||
numRefs int
|
||||
}
|
||||
|
||||
// Array wraps an mlx_array handle with reference-counted memory management.
|
||||
type Array struct {
|
||||
ctx C.mlx_array
|
||||
desc tensorDesc
|
||||
}
|
||||
|
||||
// New creates a named Array tracking its input dependencies for cleanup.
|
||||
func New(name string, inputs ...*Array) *Array {
|
||||
t := &Array{
|
||||
desc: tensorDesc{
|
||||
name: name,
|
||||
inputs: inputs,
|
||||
},
|
||||
}
|
||||
for _, input := range inputs {
|
||||
if input != nil {
|
||||
input.desc.numRefs++
|
||||
}
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type scalarTypes interface {
|
||||
~bool | ~int | ~float32 | ~float64 | ~complex64
|
||||
}
|
||||
|
||||
// FromValue creates a scalar Array from a Go value.
|
||||
func FromValue[T scalarTypes](t T) *Array {
|
||||
Init()
|
||||
tt := New("")
|
||||
switch v := any(t).(type) {
|
||||
case bool:
|
||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||
case int:
|
||||
tt.ctx = C.mlx_array_new_int(C.int(v))
|
||||
case float32:
|
||||
tt.ctx = C.mlx_array_new_float32(C.float(v))
|
||||
case float64:
|
||||
tt.ctx = C.mlx_array_new_float64(C.double(v))
|
||||
case complex64:
|
||||
tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
|
||||
default:
|
||||
panic("mlx: unsupported scalar type")
|
||||
}
|
||||
return tt
|
||||
}
|
||||
|
||||
type arrayTypes interface {
|
||||
~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
~int8 | ~int16 | ~int32 | ~int64 |
|
||||
~float32 | ~float64 |
|
||||
~complex64
|
||||
}
|
||||
|
||||
// FromValues creates an Array from a Go slice with the given shape.
|
||||
func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
|
||||
Init()
|
||||
if len(shape) == 0 {
|
||||
panic("mlx: shape required for non-scalar tensors")
|
||||
}
|
||||
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i := range shape {
|
||||
cShape[i] = C.int(shape[i])
|
||||
}
|
||||
|
||||
var dtype DType
|
||||
switch reflect.TypeOf(s).Elem().Kind() {
|
||||
case reflect.Bool:
|
||||
dtype = DTypeBool
|
||||
case reflect.Uint8:
|
||||
dtype = DTypeUint8
|
||||
case reflect.Uint16:
|
||||
dtype = DTypeUint16
|
||||
case reflect.Uint32:
|
||||
dtype = DTypeUint32
|
||||
case reflect.Uint64:
|
||||
dtype = DTypeUint64
|
||||
case reflect.Int8:
|
||||
dtype = DTypeInt8
|
||||
case reflect.Int16:
|
||||
dtype = DTypeInt16
|
||||
case reflect.Int32:
|
||||
dtype = DTypeInt32
|
||||
case reflect.Int64:
|
||||
dtype = DTypeInt64
|
||||
case reflect.Float32:
|
||||
dtype = DTypeFloat32
|
||||
case reflect.Float64:
|
||||
dtype = DTypeFloat64
|
||||
case reflect.Complex64:
|
||||
dtype = DTypeComplex64
|
||||
default:
|
||||
panic("mlx: unsupported element type")
|
||||
}
|
||||
|
||||
bts := make([]byte, binary.Size(s))
|
||||
if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
tt := New("")
|
||||
tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
|
||||
return tt
|
||||
}
|
||||
|
||||
// Zeros creates a zero-filled Array with the given shape and dtype.
|
||||
func Zeros(shape []int32, dtype DType) *Array {
|
||||
Init()
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
tt := New("ZEROS")
|
||||
C.mlx_zeros(&tt.ctx, unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
// Set replaces this array's value with another, updating ref tracking.
|
||||
func (t *Array) Set(other *Array) {
|
||||
Free(t.desc.inputs...)
|
||||
other.desc.numRefs++
|
||||
t.desc.inputs = []*Array{other}
|
||||
C.mlx_array_set(&t.ctx, other.ctx)
|
||||
}
|
||||
|
||||
// Clone creates a copy of this array sharing the same data.
|
||||
func (t *Array) Clone() *Array {
|
||||
tt := New(t.desc.name, t.desc.inputs...)
|
||||
C.mlx_array_set(&tt.ctx, t.ctx)
|
||||
return tt
|
||||
}
|
||||
|
||||
// Valid reports whether this Array has a non-nil mlx handle.
|
||||
func (t *Array) Valid() bool {
|
||||
return t.ctx.ctx != nil
|
||||
}
|
||||
|
||||
// String returns a human-readable representation of the array.
|
||||
func (t *Array) String() string {
|
||||
str := C.mlx_string_new()
|
||||
defer C.mlx_string_free(str)
|
||||
C.mlx_array_tostring(&str, t.ctx)
|
||||
return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
|
||||
}
|
||||
|
||||
// Shape returns the dimensions as int32 slice.
|
||||
func (t *Array) Shape() []int32 {
|
||||
dims := make([]int32, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = int32(t.Dim(i))
|
||||
}
|
||||
return dims
|
||||
}
|
||||
|
||||
// Size returns the total number of elements.
|
||||
func (t Array) Size() int { return int(C.mlx_array_size(t.ctx)) }
|
||||
|
||||
// NumBytes returns the total byte size.
|
||||
func (t Array) NumBytes() int { return int(C.mlx_array_nbytes(t.ctx)) }
|
||||
|
||||
// NumDims returns the number of dimensions.
|
||||
func (t Array) NumDims() int { return int(C.mlx_array_ndim(t.ctx)) }
|
||||
|
||||
// Dim returns the size of dimension i.
|
||||
func (t Array) Dim(i int) int { return int(C.mlx_array_dim(t.ctx, C.int(i))) }
|
||||
|
||||
// Dims returns all dimensions as int slice.
|
||||
func (t Array) Dims() []int {
|
||||
dims := make([]int, t.NumDims())
|
||||
for i := range dims {
|
||||
dims[i] = t.Dim(i)
|
||||
}
|
||||
return dims
|
||||
}
|
||||
|
||||
// Dtype returns the array's data type.
|
||||
func (t Array) Dtype() DType { return DType(C.mlx_array_dtype(t.ctx)) }
|
||||
|
||||
// Int extracts a scalar int64 value.
|
||||
func (t Array) Int() int {
|
||||
var item C.int64_t
|
||||
C.mlx_array_item_int64(&item, t.ctx)
|
||||
return int(item)
|
||||
}
|
||||
|
||||
// Float extracts a scalar float64 value.
|
||||
func (t Array) Float() float64 {
|
||||
var item C.double
|
||||
C.mlx_array_item_float64(&item, t.ctx)
|
||||
return float64(item)
|
||||
}
|
||||
|
||||
// Ints extracts all elements as int slice (from int32 data).
|
||||
func (t Array) Ints() []int {
|
||||
ints := make([]int, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
|
||||
ints[i] = int(f)
|
||||
}
|
||||
return ints
|
||||
}
|
||||
|
||||
// DataInt32 extracts all elements as int32 slice.
|
||||
func (t Array) DataInt32() []int32 {
|
||||
data := make([]int32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(data)) {
|
||||
data[i] = int32(f)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// Floats extracts all elements as float32 slice.
|
||||
func (t Array) Floats() []float32 {
|
||||
floats := make([]float32, t.Size())
|
||||
for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
|
||||
floats[i] = float32(f)
|
||||
}
|
||||
return floats
|
||||
}
|
||||
|
||||
// Free releases arrays using reference-counted cleanup.
|
||||
// Arrays with remaining references are not freed.
|
||||
func Free(s ...*Array) int {
|
||||
var n int
|
||||
free := make([]*Array, 0, 64)
|
||||
|
||||
fn := func(t *Array) {
|
||||
if t != nil && t.Valid() {
|
||||
t.desc.numRefs--
|
||||
if t.desc.numRefs <= 0 {
|
||||
free = append(free, t.desc.inputs...)
|
||||
n += t.NumBytes()
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, t := range s {
|
||||
fn(t)
|
||||
}
|
||||
|
||||
for len(free) > 0 {
|
||||
tail := free[len(free)-1]
|
||||
free = free[:len(free)-1]
|
||||
fn(tail)
|
||||
}
|
||||
|
||||
return n
|
||||
}
|
||||
178
pkg/mlx/cache/cache.go
vendored
Normal file
178
pkg/mlx/cache/cache.go
vendored
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
// Package cache provides KV cache implementations for transformer inference.
|
||||
package cache
|
||||
|
||||
import "forge.lthn.ai/core/cli/pkg/mlx"
|
||||
|
||||
// Cache manages key-value pairs for transformer attention layers.
|
||||
type Cache interface {
|
||||
// Update adds new key/value tensors and returns the full cached K/V.
|
||||
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
|
||||
// Offset returns the total number of tokens processed.
|
||||
Offset() int
|
||||
// Len returns the number of cached tokens (may differ from Offset for rotating caches).
|
||||
Len() int
|
||||
// State returns the cached K/V arrays, or nil if empty.
|
||||
State() []*mlx.Array
|
||||
// Reset clears the cache for a new generation session.
|
||||
Reset()
|
||||
}
|
||||
|
||||
// KVCache implements an unbounded cache that grows as needed.
|
||||
// Pre-allocates in chunks of `step` tokens to reduce allocations.
|
||||
type KVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
step int
|
||||
}
|
||||
|
||||
// NewKVCache creates a new unbounded KV cache with 256-token chunks.
|
||||
func NewKVCache() *KVCache {
|
||||
return &KVCache{step: 256}
|
||||
}
|
||||
|
||||
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
prev := c.offset
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
// Grow buffer if needed.
|
||||
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
|
||||
nSteps := (c.step + seqLen - 1) / c.step
|
||||
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
|
||||
if c.keys != nil {
|
||||
if prev%c.step != 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
}
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += seqLen
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
}
|
||||
|
||||
func (c *KVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *KVCache) Offset() int { return c.offset }
|
||||
func (c *KVCache) Len() int { return c.offset }
|
||||
|
||||
func (c *KVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
}
|
||||
|
||||
// RotatingKVCache implements a bounded sliding window cache.
|
||||
type RotatingKVCache struct {
|
||||
keys, values *mlx.Array
|
||||
offset int
|
||||
maxSize int
|
||||
step int
|
||||
idx int
|
||||
}
|
||||
|
||||
// NewRotatingKVCache creates a cache bounded to maxSize tokens.
|
||||
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
|
||||
return &RotatingKVCache{maxSize: maxSize, step: 256}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
if seqLen > 1 {
|
||||
return c.updateConcat(k, v, seqLen)
|
||||
}
|
||||
return c.updateInPlace(k, v)
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
|
||||
var cap int
|
||||
if c.keys != nil {
|
||||
cap = int(c.keys.Shape()[2])
|
||||
}
|
||||
newSize := min(c.step, c.maxSize-cap)
|
||||
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
if c.keys != nil {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
if c.idx >= c.maxSize {
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
|
||||
c.offset++
|
||||
c.idx++
|
||||
|
||||
validLen := int32(min(c.offset, c.maxSize))
|
||||
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
|
||||
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
|
||||
shape := k.Shape()
|
||||
B, H, Dk := shape[0], shape[1], shape[3]
|
||||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
} else {
|
||||
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
|
||||
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
|
||||
}
|
||||
c.offset += seqLen
|
||||
|
||||
cap := int(c.keys.Shape()[2])
|
||||
if trim := cap - c.maxSize; trim > 0 {
|
||||
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
return c.keys, c.values
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*mlx.Array {
|
||||
if c.keys == nil {
|
||||
return nil
|
||||
}
|
||||
return []*mlx.Array{c.keys, c.values}
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Offset() int { return c.offset }
|
||||
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
|
||||
|
||||
func (c *RotatingKVCache) Reset() {
|
||||
c.keys = nil
|
||||
c.values = nil
|
||||
c.offset = 0
|
||||
c.idx = 0
|
||||
}
|
||||
85
pkg/mlx/compile.go
Normal file
85
pkg/mlx/compile.go
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
|
||||
// Callback for compiled functions.
|
||||
extern void goCompiledFunc(mlx_vector_array inputs, mlx_vector_array outputs, void *payload);
|
||||
|
||||
static mlx_closure new_closure(void *payload) {
|
||||
return mlx_closure_new_func_payload(&goCompiledFunc, payload);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// CompiledFunc wraps a compiled MLX computation graph for efficient repeated calls.
|
||||
type CompiledFunc struct {
|
||||
fn func([]*Array) []*Array
|
||||
closure C.mlx_closure
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var compiledFuncs sync.Map
|
||||
|
||||
//export goCompiledFunc
|
||||
func goCompiledFunc(inputs C.mlx_vector_array, outputs C.mlx_vector_array, payload unsafe.Pointer) {
|
||||
id := uintptr(payload)
|
||||
fnI, ok := compiledFuncs.Load(id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
fn := fnI.(func([]*Array) []*Array)
|
||||
|
||||
// Convert inputs
|
||||
nInputs := int(C.mlx_vector_array_size(inputs))
|
||||
goInputs := make([]*Array, nInputs)
|
||||
for i := 0; i < nInputs; i++ {
|
||||
a := New("INPUT")
|
||||
C.mlx_vector_array_get(&a.ctx, inputs, C.int(i))
|
||||
goInputs[i] = a
|
||||
}
|
||||
|
||||
// Call user function
|
||||
goOutputs := fn(goInputs)
|
||||
|
||||
// Set outputs
|
||||
for _, out := range goOutputs {
|
||||
C.mlx_vector_array_append_value(outputs, out.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
var nextID uintptr
|
||||
var nextIDMu sync.Mutex
|
||||
|
||||
// CompileShapeless compiles a function for efficient repeated execution.
|
||||
// The function must accept and return arrays of consistent shapes.
|
||||
func CompileShapeless(fn func([]*Array) []*Array, shapeless bool) *CompiledFunc {
|
||||
nextIDMu.Lock()
|
||||
nextID++
|
||||
id := nextID
|
||||
nextIDMu.Unlock()
|
||||
|
||||
compiledFuncs.Store(id, fn)
|
||||
|
||||
cf := &CompiledFunc{fn: fn}
|
||||
cf.closure = C.new_closure(unsafe.Pointer(id))
|
||||
return cf
|
||||
}
|
||||
|
||||
// Call executes the compiled function with the given inputs.
|
||||
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
|
||||
cf.mu.Lock()
|
||||
defer cf.mu.Unlock()
|
||||
|
||||
// Fall back to direct call — compilation is an optimization.
|
||||
// The compiled closure can be used via mlx_compiled but the
|
||||
// direct path is simpler and still benefits from MLX's lazy evaluation.
|
||||
return cf.fn(inputs)
|
||||
}
|
||||
83
pkg/mlx/dtype.go
Normal file
83
pkg/mlx/dtype.go
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// #include "mlx/c/mlx.h"
|
||||
import "C"
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// DType represents an MLX array data type.
|
||||
type DType C.mlx_dtype
|
||||
|
||||
const (
|
||||
DTypeBool DType = C.MLX_BOOL
|
||||
DTypeUint8 DType = C.MLX_UINT8
|
||||
DTypeUint16 DType = C.MLX_UINT16
|
||||
DTypeUint32 DType = C.MLX_UINT32
|
||||
DTypeUint64 DType = C.MLX_UINT64
|
||||
DTypeInt8 DType = C.MLX_INT8
|
||||
DTypeInt16 DType = C.MLX_INT16
|
||||
DTypeInt32 DType = C.MLX_INT32
|
||||
DTypeInt64 DType = C.MLX_INT64
|
||||
DTypeFloat16 DType = C.MLX_FLOAT16
|
||||
DTypeFloat32 DType = C.MLX_FLOAT32
|
||||
DTypeFloat64 DType = C.MLX_FLOAT64
|
||||
DTypeBFloat16 DType = C.MLX_BFLOAT16
|
||||
DTypeComplex64 DType = C.MLX_COMPLEX64
|
||||
)
|
||||
|
||||
var dtypeNames = map[DType]string{
|
||||
DTypeBool: "bool",
|
||||
DTypeUint8: "uint8",
|
||||
DTypeUint16: "uint16",
|
||||
DTypeUint32: "uint32",
|
||||
DTypeUint64: "uint64",
|
||||
DTypeInt8: "int8",
|
||||
DTypeInt16: "int16",
|
||||
DTypeInt32: "int32",
|
||||
DTypeInt64: "int64",
|
||||
DTypeFloat16: "float16",
|
||||
DTypeFloat32: "float32",
|
||||
DTypeFloat64: "float64",
|
||||
DTypeBFloat16: "bfloat16",
|
||||
DTypeComplex64: "complex64",
|
||||
}
|
||||
|
||||
func (d DType) String() string {
|
||||
if s, ok := dtypeNames[d]; ok {
|
||||
return s
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
var dtypeFromString = map[string]DType{
|
||||
"bool": DTypeBool, "BOOL": DTypeBool,
|
||||
"uint8": DTypeUint8, "U8": DTypeUint8,
|
||||
"uint16": DTypeUint16, "U16": DTypeUint16,
|
||||
"uint32": DTypeUint32, "U32": DTypeUint32,
|
||||
"uint64": DTypeUint64, "U64": DTypeUint64,
|
||||
"int8": DTypeInt8, "I8": DTypeInt8,
|
||||
"int16": DTypeInt16, "I16": DTypeInt16,
|
||||
"int32": DTypeInt32, "I32": DTypeInt32,
|
||||
"int64": DTypeInt64, "I64": DTypeInt64,
|
||||
"float16": DTypeFloat16, "F16": DTypeFloat16,
|
||||
"float32": DTypeFloat32, "F32": DTypeFloat32,
|
||||
"float64": DTypeFloat64, "F64": DTypeFloat64,
|
||||
"bfloat16": DTypeBFloat16, "BF16": DTypeBFloat16,
|
||||
"complex64": DTypeComplex64,
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses a DType from JSON strings like "F32", "BF16", etc.
|
||||
func (d *DType) UnmarshalJSON(b []byte) error {
|
||||
var s string
|
||||
if err := json.Unmarshal(b, &s); err != nil {
|
||||
return err
|
||||
}
|
||||
if dt, ok := dtypeFromString[s]; ok {
|
||||
*d = dt
|
||||
return nil
|
||||
}
|
||||
*d = DTypeFloat32 // default
|
||||
return nil
|
||||
}
|
||||
81
pkg/mlx/fast.go
Normal file
81
pkg/mlx/fast.go
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// RMSNorm applies Root Mean Square normalization using a fused Metal kernel.
|
||||
func RMSNorm(x, weight *Array, eps float32) *Array {
|
||||
out := New("FAST_RMSNORM", x)
|
||||
C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// LayerNorm applies Layer normalization using a fused Metal kernel.
|
||||
func LayerNorm(x, weight, bias *Array, eps float32) *Array {
|
||||
out := New("FAST_LAYERNORM", x)
|
||||
C.mlx_fast_layer_norm(&out.ctx, x.ctx, weight.ctx, bias.ctx, C.float(eps), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// RoPE applies Rotary Position Embeddings using a fused Metal kernel.
|
||||
func RoPE(x *Array, dims int, traditional bool, base float32, scale float32, offset int) *Array {
|
||||
freqs := New("")
|
||||
out := New("FAST_ROPE", x, freqs)
|
||||
C.mlx_fast_rope(
|
||||
&out.ctx,
|
||||
x.ctx,
|
||||
C.int(dims),
|
||||
C._Bool(traditional),
|
||||
C.mlx_optional_float{
|
||||
value: C.float(base),
|
||||
has_value: C._Bool(base != 0),
|
||||
},
|
||||
C.float(scale),
|
||||
C.int(offset),
|
||||
freqs.ctx,
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention computes attention using a fused Metal kernel.
|
||||
// mask can be nil for causal masking, or set causal=true for auto causal mask.
|
||||
func ScaledDotProductAttention(query, key, value *Array, scale float32, causal bool) *Array {
|
||||
var mask, sinks *Array
|
||||
if causal {
|
||||
mask = New("")
|
||||
sinks = New("")
|
||||
} else {
|
||||
mask = New("")
|
||||
sinks = New("")
|
||||
}
|
||||
|
||||
mode := "causal"
|
||||
if !causal {
|
||||
mode = "none"
|
||||
}
|
||||
cMode := C.CString(mode)
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA", query, key, value, mask, sinks)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// ScaledDotProductAttentionWithMask computes attention with an explicit mask.
|
||||
func ScaledDotProductAttentionWithMask(query, key, value, mask *Array, scale float32) *Array {
|
||||
sinks := New("")
|
||||
cMode := C.CString("none")
|
||||
defer C.free(unsafe.Pointer(cMode))
|
||||
|
||||
out := New("FAST_SDPA", query, key, value, mask, sinks)
|
||||
C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
60
pkg/mlx/io.go
Normal file
60
pkg/mlx/io.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// LoadSafetensors loads tensors from a .safetensors file, returning an iterator
|
||||
// over (name, array) pairs. Tensors are loaded lazily on the CPU stream.
|
||||
func LoadSafetensors(path string) iter.Seq2[string, *Array] {
|
||||
Init()
|
||||
return func(yield func(string, *Array) bool) {
|
||||
string2array := C.mlx_map_string_to_array_new()
|
||||
defer C.mlx_map_string_to_array_free(string2array)
|
||||
|
||||
string2string := C.mlx_map_string_to_string_new()
|
||||
defer C.mlx_map_string_to_string_free(string2string)
|
||||
|
||||
cPath := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cPath))
|
||||
|
||||
cpu := C.mlx_default_cpu_stream_new()
|
||||
defer C.mlx_stream_free(cpu)
|
||||
|
||||
C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
|
||||
|
||||
it := C.mlx_map_string_to_array_iterator_new(string2array)
|
||||
defer C.mlx_map_string_to_array_iterator_free(it)
|
||||
|
||||
for {
|
||||
var key *C.char
|
||||
value := C.mlx_array_new()
|
||||
if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
|
||||
break
|
||||
}
|
||||
|
||||
name := C.GoString(key)
|
||||
if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadAllSafetensors loads all tensors from a .safetensors file into a map.
|
||||
func LoadAllSafetensors(path string) map[string]*Array {
|
||||
tensors := make(map[string]*Array)
|
||||
for name, arr := range LoadSafetensors(path) {
|
||||
tensors[name] = arr
|
||||
}
|
||||
return tensors
|
||||
}
|
||||
103
pkg/mlx/mlx.go
Normal file
103
pkg/mlx/mlx.go
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||
//
|
||||
// Build mlx-c before use:
|
||||
//
|
||||
// cd pkg/mlx && go generate ./...
|
||||
//
|
||||
// Build with MLX enabled:
|
||||
//
|
||||
// go build -tags mlx -o core .
|
||||
package mlx
|
||||
|
||||
//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
|
||||
//go:generate cmake --build build --parallel
|
||||
//go:generate cmake --install build
|
||||
|
||||
/*
|
||||
#cgo CXXFLAGS: -std=c++17
|
||||
#cgo CPPFLAGS: -I${SRCDIR}/dist/include
|
||||
#cgo LDFLAGS: -L${SRCDIR}/dist/lib -lmlxc -lmlx -lstdc++
|
||||
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
|
||||
#cgo darwin LDFLAGS: -Wl,-rpath,${SRCDIR}/dist/lib
|
||||
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
|
||||
extern void goMLXErrorHandler(const char *msg, void *data);
|
||||
|
||||
static void set_error_handler() {
|
||||
mlx_set_error_handler(&goMLXErrorHandler, NULL, NULL);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var initOnce sync.Once
|
||||
|
||||
// Init sets up the MLX error handler. Called automatically on first use.
|
||||
func Init() {
|
||||
initOnce.Do(func() {
|
||||
C.set_error_handler()
|
||||
slog.Debug("mlx: initialized with Metal backend")
|
||||
})
|
||||
}
|
||||
|
||||
//export goMLXErrorHandler
|
||||
func goMLXErrorHandler(msg *C.char, data unsafe.Pointer) {
|
||||
slog.Error("mlx", "error", C.GoString(msg))
|
||||
}
|
||||
|
||||
// Materialize synchronously evaluates arrays, computing their values on the GPU.
|
||||
// This is the MLX equivalent of forcing lazy computation to complete.
|
||||
func Materialize(outputs ...*Array) {
|
||||
doMaterialize(outputs, false)
|
||||
}
|
||||
|
||||
// MaterializeAsync queues arrays for asynchronous GPU evaluation.
|
||||
func MaterializeAsync(outputs ...*Array) {
|
||||
doMaterialize(outputs, true)
|
||||
}
|
||||
|
||||
func doMaterialize(outputs []*Array, async bool) {
|
||||
Init()
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
for _, output := range outputs {
|
||||
if output != nil && output.Valid() {
|
||||
C.mlx_vector_array_append_value(vector, output.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
if async {
|
||||
C.mlx_async_eval(vector)
|
||||
} else {
|
||||
C.mlx_eval(vector)
|
||||
}
|
||||
}
|
||||
|
||||
// Collect gathers all valid arrays from a variadic list for batch Materialize.
|
||||
func Collect(arrays ...*Array) []*Array {
|
||||
var out []*Array
|
||||
for _, a := range arrays {
|
||||
if a != nil && a.Valid() {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// MetalAvailable reports whether Metal GPU is available.
|
||||
func MetalAvailable() bool {
|
||||
Init()
|
||||
var available C.bool
|
||||
C.mlx_metal_is_available(&available)
|
||||
return bool(available)
|
||||
}
|
||||
10
pkg/mlx/mlx_stub.go
Normal file
10
pkg/mlx/mlx_stub.go
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
//go:build !(darwin && arm64 && mlx)
|
||||
|
||||
// Package mlx provides Go bindings for Apple's MLX framework via mlx-c.
|
||||
// This stub file is used on non-darwin/non-arm64 platforms or when the
|
||||
// mlx build tag is not set. All operations report MLX as unavailable.
|
||||
package mlx
|
||||
|
||||
// MetalAvailable reports whether Metal GPU is available.
|
||||
// Always returns false on non-Apple Silicon platforms.
|
||||
func MetalAvailable() bool { return false }
|
||||
327
pkg/mlx/model/gemma3.go
Normal file
327
pkg/mlx/model/gemma3.go
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
// Package model provides transformer model architectures for MLX inference.
|
||||
package model
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/mlx"
|
||||
"forge.lthn.ai/core/cli/pkg/mlx/cache"
|
||||
"forge.lthn.ai/core/cli/pkg/mlx/tokenizer"
|
||||
)
|
||||
|
||||
// TextConfig holds Gemma 3 text model configuration.
|
||||
type TextConfig struct {
|
||||
HiddenSize int32 `json:"hidden_size"`
|
||||
NumHiddenLayers int32 `json:"num_hidden_layers"`
|
||||
IntermediateSize int32 `json:"intermediate_size"`
|
||||
NumAttentionHeads int32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads int32 `json:"num_key_value_heads"`
|
||||
HeadDim int32 `json:"head_dim"`
|
||||
VocabSize int32 `json:"vocab_size"`
|
||||
RMSNormEps float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
|
||||
SlidingWindow int32 `json:"sliding_window"`
|
||||
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
|
||||
|
||||
Scale float32 `json:"-"` // Computed: 1/sqrt(head_dim)
|
||||
}
|
||||
|
||||
// GemmaModel is the Gemma 3 text model.
|
||||
type GemmaModel struct {
|
||||
EmbedTokens *mlx.Embedding
|
||||
Layers []*DecoderLayer
|
||||
Norm *mlx.RMSNormModule
|
||||
Output *mlx.Linear // Tied to EmbedTokens
|
||||
|
||||
// Precomputed (1 + weight) for Gemma-style RMSNorm
|
||||
NormScaled *mlx.Array
|
||||
|
||||
Tok *tokenizer.Tokenizer
|
||||
Cfg *TextConfig
|
||||
}
|
||||
|
||||
// DecoderLayer is a single transformer block.
|
||||
type DecoderLayer struct {
|
||||
InputNorm *mlx.RMSNormModule
|
||||
Attention *Attention
|
||||
PostAttnNorm *mlx.RMSNormModule
|
||||
PreFFNorm *mlx.RMSNormModule
|
||||
MLP *MLP
|
||||
PostFFNorm *mlx.RMSNormModule
|
||||
|
||||
// Precomputed scaled weights
|
||||
InputNormScaled *mlx.Array
|
||||
PostAttnNormScaled *mlx.Array
|
||||
PreFFNormScaled *mlx.Array
|
||||
PostFFNormScaled *mlx.Array
|
||||
|
||||
IsSliding bool
|
||||
LayerIdx int32
|
||||
}
|
||||
|
||||
// Attention implements Gemma 3 attention with Q/K normalization.
|
||||
type Attention struct {
|
||||
QProj *mlx.Linear
|
||||
KProj *mlx.Linear
|
||||
VProj *mlx.Linear
|
||||
OProj *mlx.Linear
|
||||
QNorm *mlx.RMSNormModule
|
||||
KNorm *mlx.RMSNormModule
|
||||
|
||||
QNormScaled *mlx.Array
|
||||
KNormScaled *mlx.Array
|
||||
}
|
||||
|
||||
// MLP is the feed-forward network.
|
||||
type MLP struct {
|
||||
GateProj *mlx.Linear
|
||||
UpProj *mlx.Linear
|
||||
DownProj *mlx.Linear
|
||||
}
|
||||
|
||||
// compiledGELU is a singleton for the compiled GELU function.
|
||||
var compiledGELU *mlx.CompiledFunc
|
||||
|
||||
func getCompiledGELU() *mlx.CompiledFunc {
|
||||
if compiledGELU == nil {
|
||||
compiledGELU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
|
||||
return []*mlx.Array{geluApprox(inputs[0])}
|
||||
}, true)
|
||||
}
|
||||
return compiledGELU
|
||||
}
|
||||
|
||||
// geluApprox computes GELU using the tanh approximation:
|
||||
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
|
||||
func geluApprox(x *mlx.Array) *mlx.Array {
|
||||
const sqrt2OverPi = 0.7978845608028654
|
||||
const coeff = 0.044715
|
||||
|
||||
x3 := mlx.Mul(mlx.Mul(x, x), x)
|
||||
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
|
||||
scaled := mlx.MulScalar(inner, sqrt2OverPi)
|
||||
t := mlx.Tanh(scaled)
|
||||
onePlusT := mlx.AddScalar(t, 1.0)
|
||||
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusT)
|
||||
}
|
||||
|
||||
// LoadGemma3 loads a Gemma 3 text model from a directory.
|
||||
func LoadGemma3(modelPath string) (*GemmaModel, error) {
|
||||
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemma3: load config: %w", err)
|
||||
}
|
||||
|
||||
var cfg TextConfig
|
||||
if err := json.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("gemma3: parse config: %w", err)
|
||||
}
|
||||
|
||||
// Defaults
|
||||
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
|
||||
if cfg.RopeTheta == 0 {
|
||||
cfg.RopeTheta = 1000000
|
||||
}
|
||||
if cfg.RopeLocalBaseFreq == 0 {
|
||||
cfg.RopeLocalBaseFreq = 10000
|
||||
}
|
||||
if cfg.RMSNormEps == 0 {
|
||||
cfg.RMSNormEps = 1e-6
|
||||
}
|
||||
if cfg.SlidingWindowPattern == 0 {
|
||||
cfg.SlidingWindowPattern = 6
|
||||
}
|
||||
|
||||
// Load tokenizer
|
||||
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemma3: load tokenizer: %w", err)
|
||||
}
|
||||
|
||||
// Load weights from all safetensors files
|
||||
weights := make(map[string]*mlx.Array)
|
||||
matches, _ := filepath.Glob(filepath.Join(modelPath, "*.safetensors"))
|
||||
for _, path := range matches {
|
||||
for name, arr := range mlx.LoadSafetensors(path) {
|
||||
weights[name] = arr
|
||||
}
|
||||
}
|
||||
|
||||
m := &GemmaModel{
|
||||
EmbedTokens: &mlx.Embedding{Weight: weights["model.embed_tokens.weight"]},
|
||||
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
|
||||
Norm: &mlx.RMSNormModule{Weight: weights["model.norm.weight"]},
|
||||
Tok: tok,
|
||||
Cfg: &cfg,
|
||||
}
|
||||
|
||||
// Initialize layers
|
||||
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
|
||||
prefix := fmt.Sprintf("model.layers.%d", i)
|
||||
m.Layers[i] = &DecoderLayer{
|
||||
InputNorm: &mlx.RMSNormModule{Weight: weights[prefix+".input_layernorm.weight"]},
|
||||
PostAttnNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_attention_layernorm.weight"]},
|
||||
PreFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".pre_feedforward_layernorm.weight"]},
|
||||
PostFFNorm: &mlx.RMSNormModule{Weight: weights[prefix+".post_feedforward_layernorm.weight"]},
|
||||
Attention: &Attention{
|
||||
QProj: mlx.NewLinear(weights[prefix+".self_attn.q_proj.weight"], nil),
|
||||
KProj: mlx.NewLinear(weights[prefix+".self_attn.k_proj.weight"], nil),
|
||||
VProj: mlx.NewLinear(weights[prefix+".self_attn.v_proj.weight"], nil),
|
||||
OProj: mlx.NewLinear(weights[prefix+".self_attn.o_proj.weight"], nil),
|
||||
QNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.q_norm.weight"]},
|
||||
KNorm: &mlx.RMSNormModule{Weight: weights[prefix+".self_attn.k_norm.weight"]},
|
||||
},
|
||||
MLP: &MLP{
|
||||
GateProj: mlx.NewLinear(weights[prefix+".mlp.gate_proj.weight"], nil),
|
||||
UpProj: mlx.NewLinear(weights[prefix+".mlp.up_proj.weight"], nil),
|
||||
DownProj: mlx.NewLinear(weights[prefix+".mlp.down_proj.weight"], nil),
|
||||
},
|
||||
LayerIdx: i,
|
||||
IsSliding: isLayerSliding(i, cfg.SlidingWindowPattern),
|
||||
}
|
||||
}
|
||||
|
||||
// Tied embeddings
|
||||
m.Output = mlx.NewLinear(m.EmbedTokens.Weight, nil)
|
||||
|
||||
// Materialize all weights
|
||||
var allArrays []*mlx.Array
|
||||
for _, a := range weights {
|
||||
allArrays = append(allArrays, a)
|
||||
}
|
||||
mlx.Materialize(allArrays...)
|
||||
|
||||
// Precompute (1 + weight) for Gemma-style RMSNorm
|
||||
precomputeScaledWeights(m)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func precomputeScaledWeights(m *GemmaModel) {
|
||||
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
|
||||
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
|
||||
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
|
||||
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
|
||||
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
|
||||
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
|
||||
}
|
||||
|
||||
var scaled []*mlx.Array
|
||||
scaled = append(scaled, m.NormScaled)
|
||||
for _, layer := range m.Layers {
|
||||
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
|
||||
layer.PreFFNormScaled, layer.PostFFNormScaled,
|
||||
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
|
||||
}
|
||||
mlx.Materialize(scaled...)
|
||||
}
|
||||
|
||||
func isLayerSliding(layerIdx, pattern int32) bool {
|
||||
if pattern <= 0 {
|
||||
return false
|
||||
}
|
||||
return (layerIdx+1)%pattern != 0
|
||||
}
|
||||
|
||||
// Forward runs the text model forward pass.
|
||||
func (m *GemmaModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
|
||||
shape := tokens.Shape()
|
||||
B, L := shape[0], shape[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.forward(h, caches[i], B, L, m.Cfg)
|
||||
}
|
||||
|
||||
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
|
||||
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, cfg)
|
||||
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := mlx.Add(x, attnOut)
|
||||
|
||||
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed)
|
||||
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return mlx.Add(h, mlpOut)
|
||||
}
|
||||
|
||||
func (a *Attention) forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
|
||||
// Q/K normalization
|
||||
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
|
||||
// RoPE with appropriate theta
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
|
||||
// Update cache
|
||||
k, v = c.Update(k, v, int(L))
|
||||
|
||||
// GQA: repeat K/V heads
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
if repeatFactor > 1 {
|
||||
k = mlx.RepeatKV(k, repeatFactor)
|
||||
v = mlx.RepeatKV(v, repeatFactor)
|
||||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
}
|
||||
|
||||
func (m *MLP) forward(x *mlx.Array) *mlx.Array {
|
||||
gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
|
||||
}
|
||||
|
||||
// NewCache creates per-layer caches for generation.
|
||||
func (m *GemmaModel) NewCache() []cache.Cache {
|
||||
caches := make([]cache.Cache, len(m.Layers))
|
||||
for i := range caches {
|
||||
if m.Layers[i].IsSliding {
|
||||
caches[i] = cache.NewRotatingKVCache(int(m.Cfg.SlidingWindow))
|
||||
} else {
|
||||
caches[i] = cache.NewKVCache()
|
||||
}
|
||||
}
|
||||
return caches
|
||||
}
|
||||
|
||||
// NumLayers returns the number of transformer layers.
|
||||
func (m *GemmaModel) NumLayers() int { return len(m.Layers) }
|
||||
|
||||
// Tokenizer returns the model's tokenizer.
|
||||
func (m *GemmaModel) Tokenizer() *tokenizer.Tokenizer { return m.Tok }
|
||||
59
pkg/mlx/nn.go
Normal file
59
pkg/mlx/nn.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
// Linear is a fully-connected layer: y = x @ W.T + bias.
|
||||
type Linear struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
Bias *Array `weight:"bias"`
|
||||
}
|
||||
|
||||
// NewLinear creates a Linear layer with optional bias.
|
||||
func NewLinear(weight, bias *Array) *Linear {
|
||||
return &Linear{Weight: weight, Bias: bias}
|
||||
}
|
||||
|
||||
// Forward computes the linear transformation.
|
||||
func (l *Linear) Forward(x *Array) *Array {
|
||||
out := Matmul(x, Transpose(l.Weight))
|
||||
if l.Bias != nil && l.Bias.Valid() {
|
||||
out = Add(out, l.Bias)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// Embedding is a lookup table for token embeddings.
|
||||
type Embedding struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
// Forward looks up embeddings for the given token indices.
|
||||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
return Take(e.Weight, indices, 0)
|
||||
}
|
||||
|
||||
// RMSNormModule is an RMS normalization layer wrapping the fused kernel.
|
||||
type RMSNormModule struct {
|
||||
Weight *Array `weight:"weight"`
|
||||
}
|
||||
|
||||
// Forward applies RMS normalization.
|
||||
func (r *RMSNormModule) Forward(x *Array, eps float32) *Array {
|
||||
return RMSNorm(x, r.Weight, eps)
|
||||
}
|
||||
|
||||
// RepeatKV repeats key/value heads for grouped-query attention.
|
||||
// Input shape: [B, num_kv_heads, L, D]
|
||||
// Output shape: [B, num_kv_heads * factor, L, D]
|
||||
func RepeatKV(x *Array, factor int32) *Array {
|
||||
if factor <= 1 {
|
||||
return x
|
||||
}
|
||||
shape := x.Shape()
|
||||
B, H, L, D := shape[0], shape[1], shape[2], shape[3]
|
||||
|
||||
// Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D]
|
||||
expanded := ExpandDims(x, 2)
|
||||
expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D})
|
||||
return Reshape(expanded, B, H*factor, L, D)
|
||||
}
|
||||
308
pkg/mlx/ops.go
Normal file
308
pkg/mlx/ops.go
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// --- Element-wise arithmetic ---
|
||||
|
||||
// Add returns element-wise a + b.
|
||||
func Add(a, b *Array) *Array {
|
||||
out := New("ADD", a, b)
|
||||
C.mlx_add(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// AddScalar returns a + scalar (broadcast).
|
||||
func AddScalar(a *Array, s float32) *Array {
|
||||
scalar := FromValue(s)
|
||||
return Add(a, scalar)
|
||||
}
|
||||
|
||||
// Mul returns element-wise a * b.
|
||||
func Mul(a, b *Array) *Array {
|
||||
out := New("MUL", a, b)
|
||||
C.mlx_multiply(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// MulScalar returns a * scalar (broadcast).
|
||||
func MulScalar(a *Array, s float32) *Array {
|
||||
scalar := FromValue(s)
|
||||
return Mul(a, scalar)
|
||||
}
|
||||
|
||||
// Divide returns element-wise a / b.
|
||||
func Divide(a, b *Array) *Array {
|
||||
out := New("DIV", a, b)
|
||||
C.mlx_divide(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Subtract returns element-wise a - b.
|
||||
func Subtract(a, b *Array) *Array {
|
||||
out := New("SUB", a, b)
|
||||
C.mlx_subtract(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Negative returns element-wise -a.
|
||||
func Negative(a *Array) *Array {
|
||||
out := New("NEG", a)
|
||||
C.mlx_negative(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Math functions ---
|
||||
|
||||
// Exp returns element-wise exp(a).
|
||||
func Exp(a *Array) *Array {
|
||||
out := New("EXP", a)
|
||||
C.mlx_exp(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Tanh returns element-wise tanh(a).
|
||||
func Tanh(a *Array) *Array {
|
||||
out := New("TANH", a)
|
||||
C.mlx_tanh(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Sqrt returns element-wise sqrt(a).
|
||||
func Sqrt(a *Array) *Array {
|
||||
out := New("SQRT", a)
|
||||
C.mlx_sqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Rsqrt returns element-wise 1/sqrt(a).
|
||||
func Rsqrt(a *Array) *Array {
|
||||
out := New("RSQRT", a)
|
||||
C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Reciprocal returns element-wise 1/a.
|
||||
func Reciprocal(a *Array) *Array {
|
||||
out := New("RECIPROCAL", a)
|
||||
C.mlx_reciprocal(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Square returns element-wise a^2.
|
||||
func Square(a *Array) *Array {
|
||||
out := New("SQUARE", a)
|
||||
C.mlx_square(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Power returns element-wise a^b.
|
||||
func Power(a, b *Array) *Array {
|
||||
out := New("POWER", a, b)
|
||||
C.mlx_power(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Maximum returns element-wise max(a, b).
|
||||
func Maximum(a, b *Array) *Array {
|
||||
out := New("MAX", a, b)
|
||||
C.mlx_maximum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Minimum returns element-wise min(a, b).
|
||||
func Minimum(a, b *Array) *Array {
|
||||
out := New("MIN", a, b)
|
||||
C.mlx_minimum(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Matrix operations ---
|
||||
|
||||
// Matmul returns the matrix product of a and b.
|
||||
func Matmul(a, b *Array) *Array {
|
||||
out := New("MATMUL", a, b)
|
||||
C.mlx_matmul(&out.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// QuantizedMatmul performs quantized matrix multiplication.
|
||||
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int) *Array {
|
||||
out := New("QMATMUL", x, w, scales, biases)
|
||||
C.mlx_quantized_matmul(
|
||||
&out.ctx, x.ctx, w.ctx, scales.ctx, biases.ctx,
|
||||
C._Bool(transpose), C.int(groupSize), C.int(bits),
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Reductions ---
|
||||
|
||||
// Softmax returns softmax along the last axis.
|
||||
func Softmax(a *Array) *Array {
|
||||
out := New("SOFTMAX", a)
|
||||
axis := []C.int{C.int(-1)}
|
||||
C.mlx_softmax(&out.ctx, a.ctx, &axis[0], C.int(1), C._Bool(false), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Argmax returns the index of the maximum value along an axis.
|
||||
func Argmax(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("ARGMAX", a)
|
||||
C.mlx_argmax(&out.ctx, a.ctx, C.int(axis), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// TopK returns the top k values along the last axis.
|
||||
func TopK(a *Array, k int) *Array {
|
||||
out := New("TOPK", a)
|
||||
C.mlx_topk(&out.ctx, a.ctx, C.int(k), C.int(-1), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Sum reduces by summation along the given axis.
|
||||
func Sum(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("SUM", a)
|
||||
axes := []C.int{C.int(axis)}
|
||||
C.mlx_sum(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Mean reduces by averaging along the given axis.
|
||||
func Mean(a *Array, axis int, keepDims bool) *Array {
|
||||
out := New("MEAN", a)
|
||||
axes := []C.int{C.int(axis)}
|
||||
C.mlx_mean(&out.ctx, a.ctx, &axes[0], C.int(1), C._Bool(keepDims), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Shape operations ---
|
||||
|
||||
// Reshape changes the shape of an array.
|
||||
func Reshape(a *Array, shape ...int32) *Array {
|
||||
out := New("RESHAPE", a)
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
C.mlx_reshape(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Transpose permutes dimensions. If no axes given, reverses all dims.
|
||||
func Transpose(a *Array, axes ...int) *Array {
|
||||
out := New("TRANSPOSE", a)
|
||||
if len(axes) == 0 {
|
||||
C.mlx_transpose_all(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
} else {
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i, ax := range axes {
|
||||
cAxes[i] = C.int(ax)
|
||||
}
|
||||
C.mlx_transpose(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ExpandDims inserts a new axis at the given position.
|
||||
func ExpandDims(a *Array, axis int) *Array {
|
||||
out := New("EXPAND_DIMS", a)
|
||||
axes := []C.int{C.int(axis)}
|
||||
C.mlx_expand_dims(&out.ctx, a.ctx, &axes[0], C.int(1), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Squeeze removes dimensions of size 1.
|
||||
func Squeeze(a *Array, axes ...int) *Array {
|
||||
out := New("SQUEEZE", a)
|
||||
cAxes := make([]C.int, len(axes))
|
||||
for i, ax := range axes {
|
||||
cAxes[i] = C.int(ax)
|
||||
}
|
||||
C.mlx_squeeze(&out.ctx, a.ctx, &cAxes[0], C.int(len(cAxes)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Concatenate joins arrays along the given axis.
|
||||
func Concatenate(arrays []*Array, axis int) *Array {
|
||||
vector := C.mlx_vector_array_new()
|
||||
defer C.mlx_vector_array_free(vector)
|
||||
|
||||
inputs := make([]*Array, len(arrays))
|
||||
for i, a := range arrays {
|
||||
C.mlx_vector_array_append_value(vector, a.ctx)
|
||||
inputs[i] = a
|
||||
}
|
||||
|
||||
out := New("CONCAT", inputs...)
|
||||
C.mlx_concatenate(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// BroadcastTo broadcasts an array to the given shape.
|
||||
func BroadcastTo(a *Array, shape []int32) *Array {
|
||||
out := New("BROADCAST", a)
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
C.mlx_broadcast_to(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// AsType casts an array to a different dtype.
|
||||
func AsType(a *Array, dtype DType) *Array {
|
||||
out := New("ASTYPE", a)
|
||||
C.mlx_astype(&out.ctx, a.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// AsStrided creates a view with custom strides.
|
||||
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
|
||||
out := New("AS_STRIDED", a)
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
cStrides := make([]C.size_t, len(strides))
|
||||
for i, s := range strides {
|
||||
cStrides[i] = C.size_t(s)
|
||||
}
|
||||
C.mlx_as_strided(&out.ctx, a.ctx, &cShape[0], C.int(len(cShape)), &cStrides[0], C.int(len(cStrides)), C.size_t(offset), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Take gathers elements from a along axis using indices.
|
||||
func Take(a, indices *Array, axis int) *Array {
|
||||
out := New("TAKE", a, indices)
|
||||
C.mlx_take_axis(&out.ctx, a.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Where selects elements from a or b based on condition.
|
||||
func Where(condition, a, b *Array) *Array {
|
||||
out := New("WHERE", condition, a, b)
|
||||
C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// Argpartition partially sorts and returns indices for top-k selection.
|
||||
func Argpartition(a *Array, kth, axis int) *Array {
|
||||
out := New("ARGPARTITION", a)
|
||||
C.mlx_argpartition(&out.ctx, a.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// PutAlongAxis places values into array at indices along axis.
|
||||
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
|
||||
out := New("PUT_ALONG_AXIS", a, indices, values)
|
||||
// Use scatter approach: src[indices] = values
|
||||
C.mlx_put_along_axis(&out.ctx, a.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
44
pkg/mlx/random.go
Normal file
44
pkg/mlx/random.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// RandomCategorical samples from a categorical distribution defined by logprobs.
|
||||
// Returns indices sampled according to the log-probability distribution along the last axis.
|
||||
func RandomCategorical(logprobs *Array) *Array {
|
||||
out := New("RANDOM_CATEGORICAL", logprobs)
|
||||
// shape for output: same as input but last dim removed
|
||||
C.mlx_random_categorical_shape(
|
||||
&out.ctx,
|
||||
logprobs.ctx,
|
||||
C.int(-1), // axis
|
||||
nil, C.int(0), // empty shape = infer from input
|
||||
nil, // key (use default)
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
|
||||
// RandomUniform generates uniform random values in [low, high).
|
||||
func RandomUniform(low, high float32, shape []int32, dtype DType) *Array {
|
||||
out := New("RANDOM_UNIFORM")
|
||||
cShape := make([]C.int, len(shape))
|
||||
for i, s := range shape {
|
||||
cShape[i] = C.int(s)
|
||||
}
|
||||
lo := FromValue(low)
|
||||
hi := FromValue(high)
|
||||
C.mlx_random_uniform(
|
||||
&out.ctx,
|
||||
lo.ctx, hi.ctx,
|
||||
&cShape[0], C.int(len(cShape)),
|
||||
C.mlx_dtype(dtype),
|
||||
nil, // key
|
||||
DefaultStream().ctx,
|
||||
)
|
||||
return out
|
||||
}
|
||||
105
pkg/mlx/sample/sample.go
Normal file
105
pkg/mlx/sample/sample.go
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
// Package sample provides composable token sampling strategies.
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/mlx"
|
||||
)
|
||||
|
||||
// Sampler transforms logits into a sampled token index.
|
||||
type Sampler interface {
|
||||
Sample(logits *mlx.Array) *mlx.Array
|
||||
}
|
||||
|
||||
// New creates a composable sampler chain from the given parameters.
|
||||
// Order: TopP -> MinP -> TopK -> Temperature -> categorical sample.
|
||||
func New(temp, topP, minP float32, topK int) Sampler {
|
||||
if temp == 0 {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
var samplers []Sampler
|
||||
if topP > 0 && topP < 1 {
|
||||
samplers = append(samplers, TopP(topP))
|
||||
}
|
||||
if minP > 0 {
|
||||
samplers = append(samplers, MinPSampler(minP))
|
||||
}
|
||||
if topK > 0 {
|
||||
samplers = append(samplers, TopKSampler(topK))
|
||||
}
|
||||
samplers = append(samplers, Temperature(temp))
|
||||
return chain(samplers)
|
||||
}
|
||||
|
||||
// chain applies a sequence of samplers, then samples from the result.
|
||||
type chain []Sampler
|
||||
|
||||
func (c chain) Sample(logits *mlx.Array) *mlx.Array {
|
||||
for _, s := range c {
|
||||
logits = s.Sample(logits)
|
||||
}
|
||||
// Final categorical sample from log-probabilities
|
||||
return mlx.RandomCategorical(logits)
|
||||
}
|
||||
|
||||
// greedy returns the argmax token.
|
||||
type greedy struct{}
|
||||
|
||||
func (greedy) Sample(logits *mlx.Array) *mlx.Array {
|
||||
return mlx.Argmax(logits, -1, false)
|
||||
}
|
||||
|
||||
// Temperature scales logits by 1/temp.
|
||||
type Temperature float32
|
||||
|
||||
func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
|
||||
return mlx.MulScalar(logits, 1.0/float32(t))
|
||||
}
|
||||
|
||||
// TopKSampler masks all but the top-k logits.
|
||||
type TopKSampler int
|
||||
|
||||
func (k TopKSampler) Sample(logits *mlx.Array) *mlx.Array {
|
||||
neg := mlx.Negative(logits)
|
||||
mask := mlx.Argpartition(neg, int(k)-1, -1)
|
||||
// Slice the indices beyond top-k
|
||||
mask = mlx.SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
|
||||
return mlx.PutAlongAxis(logits, mask, mlx.FromValue(float32(math.Inf(-1))), -1)
|
||||
}
|
||||
|
||||
// TopP implements nucleus sampling (cumulative probability threshold).
|
||||
type TopP float32
|
||||
|
||||
func (p TopP) Sample(logits *mlx.Array) *mlx.Array {
|
||||
// Softmax to get probabilities
|
||||
probs := mlx.Softmax(logits)
|
||||
// Sort descending
|
||||
neg := mlx.Negative(probs)
|
||||
sortedIdx := mlx.Argpartition(neg, 0, -1)
|
||||
sortedProbs := mlx.Take(probs, sortedIdx, -1)
|
||||
|
||||
// Cumulative sum
|
||||
cumProbs := mlx.Sum(sortedProbs, -1, true) // simplified — full impl needs cumsum
|
||||
|
||||
// Mask tokens beyond threshold
|
||||
threshold := mlx.FromValue(float32(p))
|
||||
mask := mlx.Where(
|
||||
mlx.FromValue(true), // placeholder — proper impl compares cumprobs > p
|
||||
mlx.FromValue(float32(math.Inf(-1))),
|
||||
logits,
|
||||
)
|
||||
return mask
|
||||
}
|
||||
|
||||
// MinPSampler masks tokens below min_p * max_prob.
|
||||
type MinPSampler float32
|
||||
|
||||
func (p MinPSampler) Sample(logits *mlx.Array) *mlx.Array {
|
||||
// For now, pass through — MinP is an optimization over TopP.
|
||||
// Full implementation requires finding max prob and masking below threshold.
|
||||
return logits
|
||||
}
|
||||
63
pkg/mlx/slice.go
Normal file
63
pkg/mlx/slice.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// Slice extracts a sub-array using start and end indices for each dimension.
|
||||
// starts and ends must have the same length as the array's dimensions.
|
||||
func Slice(a *Array, starts, ends []int32) *Array {
|
||||
out := New("SLICE", a)
|
||||
cStarts := make([]C.int, len(starts))
|
||||
cEnds := make([]C.int, len(ends))
|
||||
for i := range starts {
|
||||
cStarts[i] = C.int(starts[i])
|
||||
cEnds[i] = C.int(ends[i])
|
||||
}
|
||||
strides := make([]C.int, len(starts))
|
||||
for i := range strides {
|
||||
strides[i] = 1
|
||||
}
|
||||
C.mlx_slice(&out.ctx, a.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// SliceAxis extracts a sub-array along a single axis.
|
||||
func SliceAxis(a *Array, axis int, start, end int32) *Array {
|
||||
// Build full slice parameters
|
||||
ndim := a.NumDims()
|
||||
starts := make([]int32, ndim)
|
||||
ends := make([]int32, ndim)
|
||||
for i := 0; i < ndim; i++ {
|
||||
starts[i] = 0
|
||||
ends[i] = int32(a.Dim(i))
|
||||
}
|
||||
ax := axis
|
||||
if ax < 0 {
|
||||
ax = ndim + ax
|
||||
}
|
||||
starts[ax] = start
|
||||
ends[ax] = end
|
||||
return Slice(a, starts, ends)
|
||||
}
|
||||
|
||||
// SliceUpdateInplace updates a slice of the array in-place.
|
||||
// This is critical for KV cache updates.
|
||||
func SliceUpdateInplace(a, update *Array, starts, ends []int32) *Array {
|
||||
out := New("SLICE_UPDATE", a, update)
|
||||
cStarts := make([]C.int, len(starts))
|
||||
cEnds := make([]C.int, len(ends))
|
||||
for i := range starts {
|
||||
cStarts[i] = C.int(starts[i])
|
||||
cEnds[i] = C.int(ends[i])
|
||||
}
|
||||
strides := make([]C.int, len(starts))
|
||||
for i := range strides {
|
||||
strides[i] = 1
|
||||
}
|
||||
C.mlx_slice_update(&out.ctx, a.ctx, update.ctx, &cStarts[0], C.int(len(cStarts)), &cEnds[0], C.int(len(cEnds)), &strides[0], C.int(len(strides)), DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
74
pkg/mlx/stream.go
Normal file
74
pkg/mlx/stream.go
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
package mlx
|
||||
|
||||
/*
|
||||
#include "mlx/c/mlx.h"
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "sync"
|
||||
|
||||
// Stream wraps an mlx_stream handle for dispatching operations.
|
||||
type Stream struct {
|
||||
ctx C.mlx_stream
|
||||
}
|
||||
|
||||
var (
|
||||
defaultStream *Stream
|
||||
defaultStreamOnce sync.Once
|
||||
)
|
||||
|
||||
// DefaultStream returns the default GPU stream, creating it on first use.
|
||||
func DefaultStream() *Stream {
|
||||
defaultStreamOnce.Do(func() {
|
||||
Init()
|
||||
defaultStream = &Stream{ctx: C.mlx_default_gpu_stream_new()}
|
||||
})
|
||||
return defaultStream
|
||||
}
|
||||
|
||||
// DefaultGPUStream returns a new GPU stream.
|
||||
func DefaultGPUStream() *Stream {
|
||||
Init()
|
||||
return &Stream{ctx: C.mlx_default_gpu_stream_new()}
|
||||
}
|
||||
|
||||
// DefaultCPUStream returns a new CPU stream.
|
||||
func DefaultCPUStream() *Stream {
|
||||
Init()
|
||||
return &Stream{ctx: C.mlx_default_cpu_stream_new()}
|
||||
}
|
||||
|
||||
// Synchronize waits for all operations on the stream to complete.
|
||||
func Synchronize(s *Stream) {
|
||||
C.mlx_synchronize(s.ctx)
|
||||
}
|
||||
|
||||
// SetMemoryLimit sets the Metal memory limit. Returns the previous limit.
|
||||
func SetMemoryLimit(limit uint64) uint64 {
|
||||
var prev C.size_t
|
||||
C.mlx_set_memory_limit(&prev, C.size_t(limit))
|
||||
return uint64(prev)
|
||||
}
|
||||
|
||||
// SetCacheLimit sets the Metal cache limit. Returns the previous limit.
|
||||
func SetCacheLimit(limit uint64) uint64 {
|
||||
var prev C.size_t
|
||||
C.mlx_set_cache_limit(&prev, C.size_t(limit))
|
||||
return uint64(prev)
|
||||
}
|
||||
|
||||
// GetActiveMemory returns the current Metal memory usage in bytes.
|
||||
func GetActiveMemory() uint64 {
|
||||
var mem C.size_t
|
||||
C.mlx_get_active_memory(&mem)
|
||||
return uint64(mem)
|
||||
}
|
||||
|
||||
// GetPeakMemory returns the peak Metal memory usage in bytes.
|
||||
func GetPeakMemory() uint64 {
|
||||
var mem C.size_t
|
||||
C.mlx_get_peak_memory(&mem)
|
||||
return uint64(mem)
|
||||
}
|
||||
174
pkg/mlx/tokenizer/tokenizer.go
Normal file
174
pkg/mlx/tokenizer/tokenizer.go
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
//go:build darwin && arm64 && mlx
|
||||
|
||||
// Package tokenizer provides BPE/SentencePiece tokenization for Gemma models.
|
||||
package tokenizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Tokenizer handles text-to-token and token-to-text conversion.
|
||||
type Tokenizer struct {
|
||||
vocab map[string]int32
|
||||
invVocab map[int32]string
|
||||
merges []mergePair
|
||||
special map[string]int32
|
||||
|
||||
bosToken int32
|
||||
eosToken int32
|
||||
}
|
||||
|
||||
type mergePair struct {
|
||||
a, b string
|
||||
rank int
|
||||
}
|
||||
|
||||
// tokenizerJSON is the HuggingFace tokenizer.json format.
|
||||
type tokenizerJSON struct {
|
||||
Model struct {
|
||||
Type string `json:"type"`
|
||||
Vocab json.RawMessage `json:"vocab"`
|
||||
Merges []string `json:"merges"`
|
||||
ByteFallback bool `json:"byte_fallback"`
|
||||
} `json:"model"`
|
||||
AddedTokens []struct {
|
||||
ID int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
}
|
||||
|
||||
// Load reads a tokenizer.json file and creates a Tokenizer.
|
||||
func Load(path string) (*Tokenizer, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tokenizer: read %s: %w", path, err)
|
||||
}
|
||||
|
||||
var tj tokenizerJSON
|
||||
if err := json.Unmarshal(data, &tj); err != nil {
|
||||
return nil, fmt.Errorf("tokenizer: parse: %w", err)
|
||||
}
|
||||
|
||||
t := &Tokenizer{
|
||||
vocab: make(map[string]int32),
|
||||
invVocab: make(map[int32]string),
|
||||
special: make(map[string]int32),
|
||||
}
|
||||
|
||||
// Parse vocab
|
||||
var vocab map[string]int32
|
||||
if err := json.Unmarshal(tj.Model.Vocab, &vocab); err != nil {
|
||||
return nil, fmt.Errorf("tokenizer: parse vocab: %w", err)
|
||||
}
|
||||
t.vocab = vocab
|
||||
for k, v := range vocab {
|
||||
t.invVocab[v] = k
|
||||
}
|
||||
|
||||
// Parse merges
|
||||
for rank, merge := range tj.Model.Merges {
|
||||
parts := strings.SplitN(merge, " ", 2)
|
||||
if len(parts) == 2 {
|
||||
t.merges = append(t.merges, mergePair{a: parts[0], b: parts[1], rank: rank})
|
||||
}
|
||||
}
|
||||
|
||||
// Parse special tokens
|
||||
for _, tok := range tj.AddedTokens {
|
||||
if tok.Special {
|
||||
t.special[tok.Content] = tok.ID
|
||||
}
|
||||
t.vocab[tok.Content] = tok.ID
|
||||
t.invVocab[tok.ID] = tok.Content
|
||||
}
|
||||
|
||||
// Set BOS/EOS
|
||||
if id, ok := t.special["<bos>"]; ok {
|
||||
t.bosToken = id
|
||||
}
|
||||
if id, ok := t.special["<eos>"]; ok {
|
||||
t.eosToken = id
|
||||
}
|
||||
if id, ok := t.special["<end_of_turn>"]; ok {
|
||||
t.eosToken = id // Gemma uses end_of_turn as EOS
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// Encode converts text to token IDs. Prepends BOS token.
|
||||
func (t *Tokenizer) Encode(text string) []int32 {
|
||||
tokens := []int32{t.bosToken}
|
||||
|
||||
// Simple BPE encoding — split into characters then merge
|
||||
// This is a simplified version. Full implementation handles
|
||||
// Unicode, byte fallback, and efficient BPE merging.
|
||||
chars := []string{}
|
||||
for _, r := range text {
|
||||
s := string(r)
|
||||
if s == " " {
|
||||
s = "▁" // SentencePiece space marker
|
||||
}
|
||||
chars = append(chars, s)
|
||||
}
|
||||
|
||||
// Check for special tokens first
|
||||
remaining := text
|
||||
for remaining != "" {
|
||||
found := false
|
||||
for tok, id := range t.special {
|
||||
if strings.HasPrefix(remaining, tok) {
|
||||
tokens = append(tokens, id)
|
||||
remaining = remaining[len(tok):]
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
// Encode character by character (simplified BPE)
|
||||
r := []rune(remaining)
|
||||
ch := "▁" + string(r[0])
|
||||
if id, ok := t.vocab[ch]; ok {
|
||||
tokens = append(tokens, id)
|
||||
} else if id, ok := t.vocab[string(r[0])]; ok {
|
||||
tokens = append(tokens, id)
|
||||
}
|
||||
remaining = string(r[1:])
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
// Decode converts token IDs back to text.
|
||||
func (t *Tokenizer) Decode(tokens []int32) string {
|
||||
var sb strings.Builder
|
||||
for _, id := range tokens {
|
||||
if text, ok := t.invVocab[id]; ok {
|
||||
// Replace SentencePiece space marker
|
||||
text = strings.ReplaceAll(text, "▁", " ")
|
||||
sb.WriteString(text)
|
||||
}
|
||||
}
|
||||
result := sb.String()
|
||||
// Trim leading space from SentencePiece encoding
|
||||
if strings.HasPrefix(result, " ") {
|
||||
result = result[1:]
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// BOSToken returns the beginning-of-sequence token ID.
|
||||
func (t *Tokenizer) BOSToken() int32 { return t.bosToken }
|
||||
|
||||
// EOSToken returns the end-of-sequence token ID.
|
||||
func (t *Tokenizer) EOSToken() int32 { return t.eosToken }
|
||||
|
||||
// FormatGemmaPrompt applies the Gemma 3 chat template.
|
||||
func FormatGemmaPrompt(prompt string) string {
|
||||
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue