fix: add Metal cache management to prevent memory growth
- Add ClearCache() wrapping mlx_clear_cache
- Clear Metal allocator cache every 8 tokens during generation
- Set 16GB cache limit on backend init
- Prevents GPU memory from growing unbounded during inference

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent 098f496364
commit 9688e086ca

2 changed files with 21 additions and 1 deletion
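The three changes share one pattern: cap the Metal allocator cache when the backend is created, release it periodically while decoding, and release it once more when generation finishes or is cancelled. Below is a minimal sketch of that pattern using only the wrappers this commit touches or references; the import path, token count, and loop body are placeholders, not part of the change.

package main

import (
    "fmt"

    "example.com/project/mlx" // hypothetical import path for the cgo wrappers shown below
)

func main() {
    // Cap the Metal allocator cache once, as NewMLXBackend now does at init.
    mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB

    const maxTokens = 64 // placeholder
    for i := 0; i < maxTokens; i++ {
        // ... one forward/sample/decode step would run here ...

        // Return cached Metal buffers to the system every 8th token,
        // mirroring the loops in Generate and Chat.
        if i%8 == 7 {
            mlx.ClearCache()
        }
    }

    // Final release once decoding ends (the real code also clears on ctx cancellation).
    mlx.ClearCache()
    fmt.Printf("active=%d MiB, peak=%d MiB\n",
        mlx.GetActiveMemory()/1024/1024,
        mlx.GetPeakMemory()/1024/1024)
}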
@@ -37,6 +37,9 @@ func NewMLXBackend(modelPath string) (*MLXBackend, error) {
         return nil, fmt.Errorf("mlx: load model: %w", err)
     }
 
+    // Set Metal cache limit to prevent unbounded memory growth
+    mlx.SetCacheLimit(16 * 1024 * 1024 * 1024) // 16 GB
+
     slog.Info("mlx: model loaded",
         "layers", m.NumLayers(),
         "memory_mb", mlx.GetActiveMemory()/1024/1024,
@@ -82,12 +85,12 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
     for i := 0; i < maxTokens; i++ {
         select {
         case <-ctx.Done():
+            mlx.ClearCache()
             return b.tok.Decode(output), ctx.Err()
         default:
         }
 
         logits := b.model.Forward(input, b.caches)
-        // Take last position: [B, L, V] → [B, V]
         logits = lastPosition(logits)
         next := sampler.Sample(logits)
         mlx.Materialize(next)
@@ -98,8 +101,14 @@ func (b *MLXBackend) Generate(ctx context.Context, prompt string, opts GenOpts)
         }
         output = append(output, nextToken)
         input = mlx.FromValues([]int32{nextToken}, 1, 1)
 
+        // Periodically release Metal allocator cache to prevent memory growth
+        if i%8 == 7 {
+            mlx.ClearCache()
+        }
     }
 
+    mlx.ClearCache()
     return b.tok.Decode(output), nil
 }
@@ -158,6 +167,7 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
     for i := 0; i < maxTokens; i++ {
         select {
         case <-ctx.Done():
+            mlx.ClearCache()
             return b.tok.Decode(output), ctx.Err()
         default:
         }
@@ -173,8 +183,13 @@ func (b *MLXBackend) Chat(ctx context.Context, messages []Message, opts GenOpts)
         }
         output = append(output, nextToken)
         input = mlx.FromValues([]int32{nextToken}, 1, 1)
 
+        if i%8 == 7 {
+            mlx.ClearCache()
+        }
     }
 
+    mlx.ClearCache()
     return b.tok.Decode(output), nil
 }
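Both loops release the cache on every 8th token rather than on every token, which keeps the cache from growing while still letting MLX reuse buffers across a handful of steps; a final ClearCache after the loop (and on cancellation) returns whatever is left. For illustration, here is that predicate with the interval named; the commit itself hardcodes i%8 == 7 and contains no such constant.

package backend // hypothetical package name

// clearCacheEvery is the interval, in decoded tokens, between Metal cache releases.
const clearCacheEvery = 8

// shouldClearCache reports whether the cache should be released after step i.
// It fires at i = 7, 15, 23, ..., i.e. once per clearCacheEvery tokens.
func shouldClearCache(i int) bool {
    return i%clearCacheEvery == clearCacheEvery-1
}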
@@ -72,3 +72,8 @@ func GetPeakMemory() uint64 {
     C.mlx_get_peak_memory(&mem)
     return uint64(mem)
 }
+
+// ClearCache releases Metal memory held in the MLX allocator cache.
+func ClearCache() {
+    C.mlx_clear_cache()
+}
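If the C shim also exposes a cache-memory counter, a companion wrapper would make the effect of ClearCache directly observable. This is a sketch only, placed alongside the wrappers above: it assumes an mlx_get_cache_memory function with the same out-parameter convention as mlx_get_peak_memory, and a C.size_t variable as in GetPeakMemory; neither appears in this diff.

// GetCacheMemory reports the bytes currently held by the MLX allocator cache.
// Assumption: the C shim exposes mlx_get_cache_memory(size_t*), not shown in this commit.
func GetCacheMemory() uint64 {
    var mem C.size_t
    C.mlx_get_cache_memory(&mem)
    return uint64(mem)
}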