fix: add GC-based memory management for MLX array handles
Go GC cannot see Metal/C memory pressure, so intermediate arrays from each forward pass accumulated without bound, causing OOM kills after 3-4 requests. Fix: runtime.SetFinalizer on every Array releases C handles when GC collects them, and runtime.GC() is forced every 4 tokens during generation. Also adds SetMemoryLimit(24GB) as a hard Metal ceiling. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
9688e086ca
commit
8cdafc8d66
2 changed files with 33 additions and 7 deletions
|
|
@ -6,6 +6,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"runtime"
|
||||
"sync"
|
||||
|
||||
"forge.lthn.ai/core/cli/pkg/mlx"
|
||||
|
|
@ -37,8 +38,10 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
|||
return nil, fmt.Errorf("mlx: load model: %w", err)
|
||||
}
|
||||
|
||||
// Set Metal cache limit to prevent unbounded memory growth
|
||||
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB
|
||||
// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
|
||||
// This prevents runaway memory growth from killing the system.
|
||||
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
|
||||
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
|
||||
|
||||
slog.Info("mlx: model loaded",
|
||||
"layers", m.NumLayers(),
|
||||
|
|
@ -80,11 +83,13 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
|||
maxTokens = 2048
|
||||
}
|
||||
|
||||
// Generation loop
|
||||
// Generation loop — force Go GC every 4 tokens so finalizers release
|
||||
// intermediate C array handles that Go GC cannot see as memory pressure.
|
||||
var output []int32
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
return b.tok.Decode(output), ctx.Err()
|
||||
default:
|
||||
|
|
@ -102,12 +107,15 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
|||
output = append(output, nextToken)
|
||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||
|
||||
// Periodically release Metal allocator cache to prevent memory growth
|
||||
if i%8 == 7 {
|
||||
// Force GC to collect intermediate arrays + release Metal allocator cache
|
||||
if i%4 == 3 {
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Full cleanup between requests
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
return b.tok.Decode(output), nil
|
||||
}
|
||||
|
|
@ -167,6 +175,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
|
|||
for i := 0; i < maxTokens; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
return b.tok.Decode(output), ctx.Err()
|
||||
default:
|
||||
|
|
@ -184,11 +193,15 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
|
|||
output = append(output, nextToken)
|
||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||
|
||||
if i%8 == 7 {
|
||||
// Force GC to collect intermediate arrays + release Metal allocator cache
|
||||
if i%4 == 3 {
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
// Full cleanup between requests
|
||||
runtime.GC()
|
||||
mlx.ClearCache()
|
||||
return b.tok.Decode(output), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import "C"
|
|||
import (
|
||||
"encoding/binary"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
|
@ -28,6 +29,9 @@ type Array struct {
|
|||
}
|
||||
|
||||
// New creates a named Array tracking its input dependencies for cleanup.
|
||||
// A runtime finalizer is set so Go GC can release the C handle when
|
||||
// the Array becomes unreachable — critical because Go GC cannot see
|
||||
// Metal/C memory pressure.
|
||||
func New(name string, inputs ...*Array) *Array {
|
||||
t := &Array{
|
||||
desc: tensorDesc{
|
||||
|
|
@ -40,9 +44,18 @@ func New(name string, inputs ...*Array) *Array {
|
|||
input.desc.numRefs++
|
||||
}
|
||||
}
|
||||
runtime.SetFinalizer(t, finalizeArray)
|
||||
return t
|
||||
}
|
||||
|
||||
// finalizeArray is called by Go GC to release the underlying C array handle.
|
||||
func finalizeArray(t *Array) {
|
||||
if t != nil && t.ctx.ctx != nil {
|
||||
C.mlx_array_free(t.ctx)
|
||||
t.ctx.ctx = nil
|
||||
}
|
||||
}
|
||||
|
||||
type scalarTypes interface {
|
||||
~bool | ~int | ~float32 | ~float64 | ~complex64
|
||||
}
|
||||
|
|
@ -50,7 +63,7 @@ type scalarTypes interface {
|
|||
// FromValue creates a scalar Array from a Go value.
|
||||
func FromValue[T scalarTypes](t T) *Array {
|
||||
Init()
|
||||
tt := New("")
|
||||
tt := New("") // finalizer set by New
|
||||
switch v := any(t).(type) {
|
||||
case bool:
|
||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue