go/pkg/ml/backend_mlx.go
Snider adaa4131f9 refactor: strip to pure package library (#3)
- Fix remaining 187 pkg/ files referencing core/cli → core/go
- Move SDK library code from internal/cmd/sdk/ → pkg/sdk/ (new package)
- Create pkg/rag/helpers.go with convenience functions from internal/cmd/rag/
- Fix pkg/mcp/tools_rag.go to use pkg/rag instead of internal/cmd/rag
- Fix pkg/build/buildcmd/cmd_sdk.go and pkg/release/sdk.go to use pkg/sdk
- Remove all non-library content: main.go, internal/, cmd/, docker/,
  scripts/, tasks/, tools/, .core/, .forgejo/, .woodpecker/, Taskfile.yml
- Run go mod tidy to trim unused dependencies

core/go is now a pure Go package suite (library only).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Co-authored-by: Claude <developers@lethean.io>
Reviewed-on: #3
2026-02-16 14:23:45 +00:00

234 lines
5.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build darwin && arm64 && mlx
package ml
import (
"context"
"fmt"
"log/slog"
"runtime"
"sync"
"forge.lthn.ai/core/go/pkg/mlx"
"forge.lthn.ai/core/go/pkg/mlx/cache"
"forge.lthn.ai/core/go/pkg/mlx/model"
"forge.lthn.ai/core/go/pkg/mlx/sample"
"forge.lthn.ai/core/go/pkg/mlx/tokenizer"
)
// MLXBackend implements Backend for native Metal inference via mlx-c.
//
// Generate and Chat take mu before resetting the caches and hold it until
// they return, so requests against one backend run one at a time.
type MLXBackend struct {
// model is the Gemma model returned by model.LoadGemma3.
model *model.GemmaModel
// tok is the model's tokenizer; it encodes prompts and decodes generated token IDs.
tok *tokenizer.Tokenizer
// caches is the cache set created by the model's NewCache; it is reset at
// the start of every request and passed to each Forward call.
caches []cache.Cache
// sampler is the default low-temperature sampler built at load time.
// NOTE(review): Generate and Chat in this file build a per-request sampler
// and never read this field — confirm whether anything else uses it.
sampler sample.Sampler
// mu serializes Generate and Chat.
mu sync.Mutex
modelBytes uint64 // model size at load time, for memory budget
}
// NewMLXBackend loads a model from the safetensors directory at modelPath
// and returns a backend that runs native Metal inference on it.
//
// An error is returned when Metal is not available or when the model fails
// to load. After a successful load the Metal allocator cache is capped at
// 16 GB and total Metal memory at 24 GB, so runaway memory growth cannot
// take the system down. The active memory reading taken while building the
// backend is stored as the baseline for the per-request memory budget.
func NewMLXBackend(modelPath string) (*MLXBackend, error) {
	const gib = 1024 * 1024 * 1024

	if !mlx.MetalAvailable() {
		return nil, fmt.Errorf("mlx: Metal GPU not available")
	}

	slog.Info("mlx: loading model", "path", modelPath)
	loaded, err := model.LoadGemma3(modelPath)
	if err != nil {
		return nil, fmt.Errorf("mlx: load model: %w", err)
	}

	// The cache limit governs allocator reuse; the memory limit is the hard
	// ceiling.
	mlx.SetCacheLimit(16 * gib)
	mlx.SetMemoryLimit(24 * gib)

	activeMB := mlx.GetActiveMemory() / 1024 / 1024
	slog.Info("mlx: model loaded",
		"layers", loaded.NumLayers(),
		"memory_mb", activeMB,
	)

	backend := &MLXBackend{
		model:      loaded,
		tok:        loaded.Tokenizer(),
		caches:     loaded.NewCache(),
		sampler:    sample.New(0.1, 0, 0, 0), // default low temperature
		modelBytes: mlx.GetActiveMemory(),
	}
	return backend, nil
}
// Generate produces text from a single prompt using native Metal inference.
//
// The prompt is wrapped with tokenizer.FormatGemmaPrompt before it is
// tokenized. Generation stops at the tokenizer's EOS token, after
// opts.MaxTokens tokens, or when ctx is done. When ctx is done the text
// decoded so far is returned together with ctx.Err().
//
// A non-positive opts.Temperature falls back to 0.1 and a non-positive
// opts.MaxTokens falls back to 2048.
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
	return b.generateRaw(ctx, tokenizer.FormatGemmaPrompt(prompt), opts)
}

// Chat formats messages as Gemma chat turns and generates the model's reply.
//
// "user" and "assistant" messages become user and model turns. A "system"
// message is emitted as a user turn wrapped in "[System: ...]". Messages with
// any other role are skipped. An open model turn is appended so the model
// continues from there. Stopping conditions, defaults and the cancellation
// result are the same as for Generate.
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
	// Append into one buffer rather than growing a string with += per message.
	var prompt []byte
	for _, msg := range messages {
		switch msg.Role {
		case "user":
			prompt = fmt.Appendf(prompt, "<start_of_turn>user\n%s<end_of_turn>\n", msg.Content)
		case "assistant":
			prompt = fmt.Appendf(prompt, "<start_of_turn>model\n%s<end_of_turn>\n", msg.Content)
		case "system":
			prompt = fmt.Appendf(prompt, "<start_of_turn>user\n[System: %s]<end_of_turn>\n", msg.Content)
		}
	}
	prompt = append(prompt, "<start_of_turn>model\n"...)

	// The prompt is already in Gemma format, so it is used as-is.
	return b.generateRaw(ctx, string(prompt), opts)
}

// generateRaw runs the token-by-token generation loop shared by Generate and
// Chat on a prompt that is already in its final, formatted form.
//
// It holds b.mu for the whole request, resets the caches, and samples one
// token per Forward call until EOS, the token limit, or ctx being done.
func (b *MLXBackend) generateRaw(ctx context.Context, prompt string, opts GenOpts) (string, error) {
	const (
		defaultTemperature = 0.1
		defaultMaxTokens   = 2048
		gcInterval         = 4 // tokens between forced GC + cache clears
	)

	b.mu.Lock()
	defer b.mu.Unlock()

	// Every request starts from empty caches.
	for _, c := range b.caches {
		c.Reset()
	}

	// Zero means "unset"; negative values are treated the same way instead
	// of being handed to the sampler or silently producing no output.
	temp := float32(opts.Temperature)
	if temp <= 0 {
		temp = defaultTemperature
	}
	sampler := sample.New(temp, 0, 0, 0)

	maxTokens := opts.MaxTokens
	if maxTokens <= 0 {
		maxTokens = defaultMaxTokens
	}

	tokens := b.tok.Encode(prompt)
	input := mlx.FromValues(tokens, 1, len(tokens))

	var output []int32
	for i := 0; i < maxTokens; i++ {
		if err := ctx.Err(); err != nil {
			runtime.GC()
			mlx.ClearCache()
			return b.tok.Decode(output), err
		}

		logits := lastPosition(b.model.Forward(input, b.caches))
		next := sampler.Sample(logits)
		mlx.Materialize(next)

		nextToken := int32(next.Int())
		if nextToken == b.tok.EOSToken() {
			break
		}
		output = append(output, nextToken)
		// After the first pass only the newly sampled token is fed back in.
		input = mlx.FromValues([]int32{nextToken}, 1, 1)

		// Force a Go GC periodically so finalizers release intermediate C
		// array handles that the Go GC cannot see as memory pressure, then
		// release the Metal allocator cache.
		if i%gcInterval == gcInterval-1 {
			runtime.GC()
			mlx.ClearCache()
		}
	}

	// Cleanup between requests.
	runtime.GC()
	mlx.ClearCache()
	b.checkMemory()
	return b.tok.Decode(output), nil
}

// lastPosition extracts the last sequence position from [B, L, V] logits,
// returning [B, V]. Logits that are not rank 3, or whose sequence dimension
// is below 1, are returned unchanged.
func lastPosition(logits *mlx.Array) *mlx.Array {
	dims := logits.Shape()
	if len(dims) != 3 {
		return logits
	}
	batch, seqLen, vocab := dims[0], dims[1], dims[2]
	if seqLen < 1 {
		return logits
	}
	if seqLen > 1 {
		// Keep only the final position along the sequence axis.
		logits = mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{batch, seqLen, vocab})
	}
	return mlx.Reshape(logits, batch, vocab)
}
// checkMemory compares active Metal memory with a budget of three times the
// model size recorded at load time. When the budget is exceeded it logs a
// warning with active, model and peak usage in MB, runs the Go GC twice and
// clears the MLX cache. Otherwise it does nothing.
func (b *MLXBackend) checkMemory() {
	const budgetFactor = 3 // 3x model size = danger zone

	inUse := mlx.GetActiveMemory()
	if inUse <= b.modelBytes*budgetFactor {
		return
	}

	slog.Warn("mlx: memory over budget, forcing cleanup",
		"active_mb", inUse/1024/1024,
		"model_mb", b.modelBytes/1024/1024,
		"peak_mb", mlx.GetPeakMemory()/1024/1024,
	)

	// Two passes: the second is there so finalizers get run.
	for pass := 0; pass < 2; pass++ {
		runtime.GC()
	}
	mlx.ClearCache()
}
// Name reports this backend's identifier, which is always "mlx".
func (b *MLXBackend) Name() string {
	return "mlx"
}
// Available reports whether the Metal GPU is ready, as indicated by
// mlx.MetalAvailable.
func (b *MLXBackend) Available() bool {
	ready := mlx.MetalAvailable()
	return ready
}