fix: add GC-based memory management for MLX array handles
The Go garbage collector cannot see Metal/C memory pressure, so the intermediate arrays from each forward pass accumulated without bound and caused OOM kills after 3-4 requests. Fix: runtime.SetFinalizer is set on every Array so that its C handle is released when the GC collects it, and runtime.GC() is forced every 4 tokens during generation (the Metal cache was previously only cleared every 8 tokens), as well as on cancellation and at the end of each request. Also adds SetMemoryLimit(24 GB) as a hard Metal ceiling. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e6ada25bd8
commit
1d4ec55d05
2 changed files with 33 additions and 7 deletions
|
|
@@ -6,6 +6,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"forge.lthn.ai/core/cli/pkg/mlx"
|
"forge.lthn.ai/core/cli/pkg/mlx"
|
||||||
|
|
@@ -37,8 +38,10 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
|
||||||
return nil, fmt.Errorf("mlx: load model: %w", err)
|
return nil, fmt.Errorf("mlx: load model: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set Metal cache limit to prevent unbounded memory growth
|
// Cap Metal memory: cache limit for allocator reuse, memory limit as hard ceiling.
|
||||||
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB
|
// This prevents runaway memory growth from killing the system.
|
||||||
|
mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB allocator cache
|
||||||
|
mlx.SetMemoryLimit(24 * 1024 * 1024 * 1024) // 24 GB hard cap
|
||||||
|
|
||||||
slog.Info("mlx: model loaded",
|
slog.Info("mlx: model loaded",
|
||||||
"layers", m.NumLayers(),
|
"layers", m.NumLayers(),
|
||||||
|
|
@@ -80,11 +83,13 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
||||||
maxTokens = 2048
|
maxTokens = 2048
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generation loop
|
// Generation loop — force Go GC every 4 tokens so finalizers release
|
||||||
|
// intermediate C array handles that Go GC cannot see as memory pressure.
|
||||||
var output []int32
|
var output []int32
|
||||||
for i := 0; i < maxTokens; i++ {
|
for i := 0; i < maxTokens; i++ {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
return b.tok.Decode(output), ctx.Err()
|
return b.tok.Decode(output), ctx.Err()
|
||||||
default:
|
default:
|
||||||
|
|
@@ -102,12 +107,15 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
|
||||||
output = append(output, nextToken)
|
output = append(output, nextToken)
|
||||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||||
|
|
||||||
// Periodically release Metal allocator cache to prevent memory growth
|
// Force GC to collect intermediate arrays + release Metal allocator cache
|
||||||
if i%8 == 7 {
|
if i%4 == 3 {
|
||||||
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Full cleanup between requests
|
||||||
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
return b.tok.Decode(output), nil
|
return b.tok.Decode(output), nil
|
||||||
}
|
}
|
||||||
|
|
@@ -167,6 +175,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
|
||||||
for i := 0; i < maxTokens; i++ {
|
for i := 0; i < maxTokens; i++ {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
return b.tok.Decode(output), ctx.Err()
|
return b.tok.Decode(output), ctx.Err()
|
||||||
default:
|
default:
|
||||||
|
|
@@ -184,11 +193,15 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
|
||||||
output = append(output, nextToken)
|
output = append(output, nextToken)
|
||||||
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
input = mlx.FromValues([]int32{nextToken}, 1, 1)
|
||||||
|
|
||||||
if i%8 == 7 {
|
// Force GC to collect intermediate arrays + release Metal allocator cache
|
||||||
|
if i%4 == 3 {
|
||||||
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Full cleanup between requests
|
||||||
|
runtime.GC()
|
||||||
mlx.ClearCache()
|
mlx.ClearCache()
|
||||||
return b.tok.Decode(output), nil
|
return b.tok.Decode(output), nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@@ -11,6 +11,7 @@ import "C"
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
@@ -28,6 +29,9 @@ type Array struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a named Array tracking its input dependencies for cleanup.
|
// New creates a named Array tracking its input dependencies for cleanup.
|
||||||
|
// A runtime finalizer is set so Go GC can release the C handle when
|
||||||
|
// the Array becomes unreachable — critical because Go GC cannot see
|
||||||
|
// Metal/C memory pressure.
|
||||||
func New(name string, inputs ...*Array) *Array {
|
func New(name string, inputs ...*Array) *Array {
|
||||||
t := &Array{
|
t := &Array{
|
||||||
desc: tensorDesc{
|
desc: tensorDesc{
|
||||||
|
|
@@ -40,9 +44,18 @@ func New(name string, inputs ...*Array) *Array {
|
||||||
input.desc.numRefs++
|
input.desc.numRefs++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
runtime.SetFinalizer(t, finalizeArray)
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finalizeArray is the finalizer that New registers on each Array via
// runtime.SetFinalizer; it releases the underlying C array handle.
// It does nothing when t is nil or the handle pointer is already nil, and it
// resets the pointer to nil after freeing so a repeated call cannot free the
// same handle twice.
// NOTE(review): the Go runtime runs finalizers on its own goroutine — confirm
// that mlx_array_free is safe to call from there.
|
||||||
|
func finalizeArray(t *Array) {
|
||||||
|
if t != nil && t.ctx.ctx != nil {
|
||||||
|
C.mlx_array_free(t.ctx)
|
||||||
|
t.ctx.ctx = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type scalarTypes interface {
|
type scalarTypes interface {
|
||||||
~bool | ~int | ~float32 | ~float64 | ~complex64
|
~bool | ~int | ~float32 | ~float64 | ~complex64
|
||||||
}
|
}
|
||||||
|
|
@@ -50,7 +63,7 @@ type scalarTypes interface {
|
||||||
// FromValue creates a scalar Array from a Go value.
|
// FromValue creates a scalar Array from a Go value.
|
||||||
func FromValue[T scalarTypes](t T) *Array {
|
func FromValue[T scalarTypes](t T) *Array {
|
||||||
Init()
|
Init()
|
||||||
tt := New("")
|
tt := New("") // finalizer set by New
|
||||||
switch v := any(t).(type) {
|
switch v := any(t).(type) {
|
||||||
case bool:
|
case bool:
|
||||||
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
tt.ctx = C.mlx_array_new_bool(C.bool(v))
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue