- Fix remaining 187 pkg/ files referencing core/cli → core/go
- Move SDK library code from internal/cmd/sdk/ → pkg/sdk/ (new package)
- Create pkg/rag/helpers.go with convenience functions from internal/cmd/rag/
- Fix pkg/mcp/tools_rag.go to use pkg/rag instead of internal/cmd/rag
- Fix pkg/build/buildcmd/cmd_sdk.go and pkg/release/sdk.go to use pkg/sdk
- Remove all non-library content: main.go, internal/, cmd/, docker/, scripts/, tasks/, tools/, .core/, .forgejo/, .woodpecker/, Taskfile.yml
- Run go mod tidy to trim unused dependencies

core/go is now a pure Go package suite (library only).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Claude <developers@lethean.io>
Reviewed-on: #3

//go:build darwin && arm64 && mlx

package ml

import (
	"context"
	"fmt"
	"log/slog"
	"runtime"
	"sync"

	"forge.lthn.ai/core/go/pkg/mlx"
	"forge.lthn.ai/core/go/pkg/mlx/cache"
	"forge.lthn.ai/core/go/pkg/mlx/model"
	"forge.lthn.ai/core/go/pkg/mlx/sample"
	"forge.lthn.ai/core/go/pkg/mlx/tokenizer"
)

// MLXBackend implements Backend for native Metal inference via mlx-c.
type MLXBackend struct {
	model      *model.GemmaModel
	tok        *tokenizer.Tokenizer
	caches     []cache.Cache
	sampler    sample.Sampler
	mu         sync.Mutex
	modelBytes uint64 // model size at load time, for memory budget
}

// NewMLXBackend loads a model from a safetensors directory and creates
// a native Metal inference backend.
func NewMLXBackend(modelPath string) (*MLXBackend, error) {
	if !mlx.MetalAvailable() {
		return nil, fmt.Errorf("mlx: Metal GPU not available")
	}

	slog.Info("mlx: loading model", "path", modelPath)
	m, err := model.LoadGemma3(modelPath)
	if err != nil {
		return nil, fmt.Errorf("mlx: load model: %w", err)
	}

	// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
	// This prevents runaway memory growth from killing the system.
	mlx.SetCacheLimit(16 * 1024 * 1024 * 1024)  // 16 GB allocator cache
	mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap

	modelMB := mlx.GetActiveMemory() / 1024 / 1024
	slog.Info("mlx: model loaded",
		"layers", m.NumLayers(),
		"memory_mb", modelMB,
	)

	return &MLXBackend{
		model:      m,
		tok:        m.Tokenizer(),
		caches:     m.NewCache(),
		sampler:    sample.New(0.1, 0, 0, 0), // default low temp
		modelBytes: mlx.GetActiveMemory(),
	}, nil
}

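// The 16 GB / 24 GB caps above are compile-time constants. A minimal sketch of
// how a caller could re-tune them after construction, using the same mlx
// accessors already imported here (the budget values are illustrative, not
// recommendations):
//
//	backend, err := NewMLXBackend(modelPath)
//	if err != nil {
//		return err
//	}
//	mlx.SetCacheLimit(8 * 1024 * 1024 * 1024)   // smaller allocator cache
//	mlx.SetMemoryLimit(12 * 1024 * 1024 * 1024) // tighter hard ceiling
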
// Generate produces text from a prompt using native Metal inference.
func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts) (string, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	// Reset caches for new generation
	for _, c := range b.caches {
		c.Reset()
	}

	// Set up sampler based on opts
	temp := float32(opts.Temperature)
	if temp == 0 {
		temp = 0.1
	}
	sampler := sample.New(temp, 0, 0, 0)

	// Tokenize
	formatted := tokenizer.FormatGemmaPrompt(prompt)
	tokens := b.tok.Encode(formatted)
	input := mlx.FromValues(tokens, 1, len(tokens))

	maxTokens := opts.MaxTokens
	if maxTokens == 0 {
		maxTokens = 2048
	}

	// Generation loop — force Go GC every 4 tokens so finalizers release
	// intermediate C array handles that Go GC cannot see as memory pressure.
	var output []int32
	for i := 0; i < maxTokens; i++ {
		select {
		case <-ctx.Done():
			runtime.GC()
			mlx.ClearCache()
			return b.tok.Decode(output), ctx.Err()
		default:
		}

		logits := b.model.Forward(input, b.caches)
		logits = lastPosition(logits)
		next := sampler.Sample(logits)
		mlx.Materialize(next)

		nextToken := int32(next.Int())
		if nextToken == b.tok.EOSToken() {
			break
		}
		output = append(output, nextToken)
		input = mlx.FromValues([]int32{nextToken}, 1, 1)

		// Force GC to collect intermediate arrays + release Metal allocator cache
		if i%4 == 3 {
			runtime.GC()
			mlx.ClearCache()
		}
	}

	// Cleanup between requests
	runtime.GC()
	mlx.ClearCache()
	b.checkMemory()
	return b.tok.Decode(output), nil
}

// lastPosition extracts the last sequence position from [B, L, V] logits → [B, V].
func lastPosition(logits *mlx.Array) *mlx.Array {
	shape := logits.Shape()
	if len(shape) == 3 && shape[1] > 1 {
		L := shape[1]
		logits = mlx.Slice(logits, []int32{0, L - 1, 0}, []int32{shape[0], L, shape[2]})
		logits = mlx.Reshape(logits, shape[0], shape[2])
	} else if len(shape) == 3 && shape[1] == 1 {
		logits = mlx.Reshape(logits, shape[0], shape[2])
	}
	return logits
}

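// Illustrative shape walkthrough for lastPosition, with V standing in for the
// vocabulary size: the prompt pass produces [1, L, V] logits; slicing position
// L-1 yields [1, 1, V], which reshapes to [1, V] for the sampler. Decode steps
// already arrive as [1, 1, V] and take only the reshape branch.
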
// Chat formats messages and generates a response.
func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts) (string, error) {
	// Format as Gemma chat
	var prompt string
	for _, msg := range messages {
		switch msg.Role {
		case "user":
			prompt += fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n", msg.Content)
		case "assistant":
			prompt += fmt.Sprintf("<start_of_turn>model\n%s<end_of_turn>\n", msg.Content)
		case "system":
			prompt += fmt.Sprintf("<start_of_turn>user\n[System: %s]<end_of_turn>\n", msg.Content)
		}
	}
	prompt += "<start_of_turn>model\n"

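	// Illustrative render of a single user turn ("Hello"), with system turns
	// folded into a user turn per the switch above:
	//
	//	<start_of_turn>user
	//	Hello<end_of_turn>
	//	<start_of_turn>model
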
	// Use raw prompt (already formatted)
	b.mu.Lock()
	defer b.mu.Unlock()

	for _, c := range b.caches {
		c.Reset()
	}

	temp := float32(opts.Temperature)
	if temp == 0 {
		temp = 0.1
	}
	sampler := sample.New(temp, 0, 0, 0)

	tokens := b.tok.Encode(prompt)
	input := mlx.FromValues(tokens, 1, len(tokens))

	maxTokens := opts.MaxTokens
	if maxTokens == 0 {
		maxTokens = 2048
	}

	var output []int32
	for i := 0; i < maxTokens; i++ {
		select {
		case <-ctx.Done():
			runtime.GC()
			mlx.ClearCache()
			return b.tok.Decode(output), ctx.Err()
		default:
		}

		logits := b.model.Forward(input, b.caches)
		logits = lastPosition(logits)
		next := sampler.Sample(logits)
		mlx.Materialize(next)

		nextToken := int32(next.Int())
		if nextToken == b.tok.EOSToken() {
			break
		}
		output = append(output, nextToken)
		input = mlx.FromValues([]int32{nextToken}, 1, 1)

		// Force GC to collect intermediate arrays + release Metal allocator cache
		if i%4 == 3 {
			runtime.GC()
			mlx.ClearCache()
		}
	}

	// Cleanup between requests
	runtime.GC()
	mlx.ClearCache()
	b.checkMemory()
	return b.tok.Decode(output), nil
}

// checkMemory logs Metal memory usage and forces cleanup if it exceeds budget.
func (b *MLXBackend) checkMemory() {
	active := mlx.GetActiveMemory()
	budget := b.modelBytes * 3 // 3× model size = danger zone
	if active > budget {
		slog.Warn("mlx: memory over budget, forcing cleanup",
			"active_mb", active/1024/1024,
			"model_mb", b.modelBytes/1024/1024,
			"peak_mb", mlx.GetPeakMemory()/1024/1024,
		)
		runtime.GC()
		runtime.GC() // double GC to run finalizers
		mlx.ClearCache()
	}
}

// Name returns the backend identifier.
func (b *MLXBackend) Name() string { return "mlx" }

// Available reports whether Metal GPU is ready.
func (b *MLXBackend) Available() bool { return mlx.MetalAvailable() }

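// End-to-end usage sketch (hypothetical model path; error handling elided;
// Message and GenOpts are the types referenced by Chat and Generate above):
//
//	backend, err := NewMLXBackend("/models/gemma-3-4b-it")
//	if err != nil {
//		return err
//	}
//	reply, err := backend.Chat(ctx, []Message{
//		{Role: "user", Content: "Summarise the last build failure."},
//	}, GenOpts{Temperature: 0.3, MaxTokens: 512})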