fix: add Detach/Free calls to reduce Metal GPU memory retention

Add deterministic memory cleanup across inference paths:
- Detach logits after Eval to release graph references
- Free intermediate arrays in attention (gemma3, qwen3)
- Add cache Detach helper for KV cache cleanup after generation
- New detach.cpp/go CGO bindings for mlx_array_detach

Reduces 4B model memory from 78GB to ~17GB (vs 2.4GB mlx-lm baseline).
Native Metal memory management still trails Python refcounting but is
now viable for 1B models.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 18:36:57 +00:00
parent c1baeb9254
commit 71fe4bb5ac
11 changed files with 267 additions and 60 deletions

View file

@ -22,7 +22,7 @@
| macOS 15, iOS 18 | Add all the Metal APIs in macOS 15 and iOS 18. |
| macOS 14, iOS 17 | Add support for the **MetalFX** framework. <br/>Add all the APIs in macOS 14 and iOS 17. |
| macOS 13.3, iOS 16.4 | Add all the APIs in macOS 13.3 and iOS 16.4. |
| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.<br />newArray optional `NS::SharedPtr<T>` type to assist with memory management.<br/>newArray convenience function to create a `CA::MetalLayer`.<br/>newArray `MTLSTR(str)` macro allows faster string creation from literals.<br/>Fix a problem with the signature of functions that take an array of pointers as input.<br/>Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.|
| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.<br />New optional `NS::SharedPtr<T>` type to assist with memory management.<br/>New convenience function to create a `CA::MetalLayer`.<br/>New `MTLSTR(str)` macro allows faster string creation from literals.<br/>Fix a problem with the signature of functions that take an array of pointers as input.<br/>Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.|
| macOS 12, iOS 15 | Initial release. |
## Memory Allocation Policy

View file

@ -14,6 +14,9 @@ type Cache interface {
State() []*Array
// Reset clears the cache for a new generation session.
Reset()
// Detach replaces internal K/V arrays with copies that have no graph parents.
// Call after Eval to allow Metal memory from prior graph operations to be freed.
Detach()
}
// KVCache implements an unbounded cache that grows as needed.
@ -90,6 +93,13 @@ func (c *KVCache) Reset() {
c.offset = 0
}
// Detach replaces the cache's internal K/V arrays with graph-free
// equivalents so Metal buffers pinned by earlier graph operations can be
// reclaimed. Safe to call on an empty cache.
func (c *KVCache) Detach() {
	if c.keys != nil {
		Detach(c.keys, c.values)
	}
}
// RotatingKVCache implements a bounded sliding window cache.
type RotatingKVCache struct {
keys, values *Array
@ -190,7 +200,10 @@ func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array)
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
// Return Slice views so callers can Free them without destroying the cache.
// (updateInPlace and KVCache.Update already return Slice views.)
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dk}),
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dv})
}
func (c *RotatingKVCache) State() []*Array {
@ -209,3 +222,10 @@ func (c *RotatingKVCache) Reset() {
c.offset = 0
c.idx = 0
}
// Detach severs the graph links held by the sliding-window cache's K/V
// arrays, allowing Metal memory from prior operations to be freed.
// A cache that has never been populated is left untouched.
func (c *RotatingKVCache) Detach() {
	if c.keys != nil {
		Detach(c.keys, c.values)
	}
}

View file

@ -0,0 +1,25 @@
//go:build darwin && arm64
package metal
import "testing"
// TestCopy_BreaksGraph verifies that Copy yields an array that remains
// readable after its graph source is freed — i.e. that Copy truly detaches
// the result from the computation graph.
func TestCopy_BreaksGraph(t *testing.T) {
	// Create a chain: a, two -> b
	a := FromValue(float32(1.0))
	two := FromValue(float32(2.0)) // keep a handle so the intermediate is freed, not leaked
	b := Add(a, two)
	if err := Eval(b); err != nil {
		t.Fatalf("eval b: %v", err)
	}
	// Copy should break the graph
	c := Copy(b)
	if err := Eval(c); err != nil {
		t.Fatalf("eval c: %v", err)
	}
	// Free b — if Copy truly detaches, c should still be valid
	Free(b, two)
	val := c.Float()
	if val != 3.0 {
		t.Fatalf("expected 3.0, got %f", val)
	}
	Free(a, c)
}

View file

@ -0,0 +1,8 @@
#include "mlx/mlx.h"
#include "mlx/c/array.h"
// mlx_array_detach_impl breaks an evaluated array's graph connections in
// place. arr.ctx holds a mlx::core::array*; calling detach() on it drops
// the shared_ptrs to its graph parents so Metal buffers owned by upstream
// operations become eligible for release. No-op for an empty handle.
// NOTE(review): caller must ensure the array has been evaluated first —
// detaching an unevaluated lazy array would discard its recipe; confirm
// the Go wrapper only calls this post-Eval.
extern "C" void mlx_array_detach_impl(mlx_array arr) {
if (arr.ctx) {
static_cast<mlx::core::array*>(arr.ctx)->detach();
}
}

22
internal/metal/detach.go Normal file
View file

@ -0,0 +1,22 @@
//go:build darwin && arm64
package metal
/*
#include "mlx/c/array.h"
// mlx_array_detach breaks an evaluated array's graph connections.
// ctx is a mlx::core::array* — we call detach() via a C++ helper.
void mlx_array_detach_impl(mlx_array arr);
*/
import "C"
// Detach breaks an array's graph connections after evaluation.
// This allows Metal memory from parent operations to be freed.
// Detach breaks each given array's graph connections after evaluation so
// that Metal memory held by parent operations can be freed. Nil entries
// and arrays with an empty C handle are skipped.
func Detach(arrays ...*Array) {
	for _, arr := range arrays {
		if arr == nil || arr.ctx.ctx == nil {
			continue
		}
		C.mlx_array_detach_impl(arr.ctx)
	}
}

View file

@ -330,78 +330,121 @@ func (m *GemmaModel) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *
B, L := shape[0], shape[1]
h := m.EmbedTokens.Forward(tokens)
h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
s := float32(math.Sqrt(float64(m.Cfg.HiddenSize)))
h2 := MulScalar(h, s)
Free(h)
h = h2
for i, layer := range m.Layers {
h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg)
Free(h)
h = hNext
}
return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
normed := RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)
out := m.Output.Forward(normed)
Free(h, normed)
return out
}
// forward runs one Gemma decoder layer: pre-norm attention with post-norm
// and residual add, then pre-norm MLP with post-norm and residual add.
// Every intermediate array is freed as soon as its last consumer has run so
// the computation graph does not pin Metal buffers.
// (This rendered diff interleaved the pre-change lines with the new ones;
// the merged post-change body is restored here.)
func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *TextConfig) *Array {
	normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
	attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, mask, cfg)
	Free(normed)
	attnOutNormed := RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
	Free(attnOut)
	h := Add(x, attnOutNormed)
	Free(attnOutNormed)
	normed2 := RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
	mlpOut := l.MLP.forward(normed2)
	Free(normed2)
	mlpOutNormed := RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
	Free(mlpOut)
	result := Add(h, mlpOutNormed)
	Free(h, mlpOutNormed)
	return result
}
func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, mask *Array, cfg *TextConfig) *Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
qProj := a.QProj.Forward(x)
kProj := a.KProj.Forward(x)
vProj := a.VProj.Forward(x)
// Virtual transpose [B,L,H*D] → [B,H,L,D] via stride manipulation.
// Strides: batch = L*H*D (full sequence), head = D (adjacent heads in memory),
// seq = H*D (jump one full row of heads), elem = 1 (contiguous within head).
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
// AsStrided creates a view (C refcount keeps source alive), so Free source after.
q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
Free(qProj)
k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
Free(kProj)
v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
Free(vProj)
// Q/K normalization
oldQ := q
q = RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
Free(oldQ)
oldK := k
k = RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
Free(oldK)
// RoPE with appropriate theta
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
oldQ = q
q = RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
Free(oldQ)
oldK = k
k = RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
Free(oldK)
// Update cache
// Update cache — returns Slice views into cache buffer; free our pre-update handles.
oldK, oldV := k, v
k, v = c.Update(k, v, int(L))
Free(oldK, oldV)
// GQA: repeat K/V heads
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
kAttn, vAttn := k, v
if repeatFactor > 1 {
k = RepeatKV(k, repeatFactor)
v = RepeatKV(v, repeatFactor)
kAttn = RepeatKV(k, repeatFactor)
vAttn = RepeatKV(v, repeatFactor)
Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies
}
// Scaled dot-product attention
var out *Array
if mask != nil {
out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale)
} else {
out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1)
}
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views
transposed := Transpose(out, 0, 2, 1, 3)
Free(out)
reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim)
Free(transposed)
result := a.OProj.Forward(reshaped)
Free(reshaped)
return result
}
// forward computes the GELU-gated MLP: down(gelu(gate(x)) * up(x)).
// Each intermediate is freed as soon as its last consumer has run.
// (This rendered diff retained the superseded one-liner body alongside the
// new lines; the merged post-change body is restored here.)
func (m *MLP) forward(x *Array) *Array {
	gateProj := m.GateProj.Forward(x)
	gate := getCompiledGELU().Call(gateProj)[0]
	Free(gateProj)
	upProj := m.UpProj.Forward(x)
	activated := Mul(gate, upProj)
	Free(gate, upProj)
	result := m.DownProj.Forward(activated)
	Free(activated)
	return result
}
// NewCache creates per-layer caches for generation.

View file

@ -177,6 +177,13 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
m.lastErr = fmt.Errorf("prefill: %w", err)
return
}
// Detach logits and cache arrays to release the entire prefill computation
// graph. After Eval, data is materialised — graph connections only pin Metal
// memory from intermediate tensors (34 layers × ~20 ops each).
Detach(logits)
for _, c := range caches {
c.Detach()
}
prefillDur = time.Since(prefillStart)
// Track generated token IDs for repeat penalty.
@ -248,6 +255,15 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
m.lastErr = fmt.Errorf("decode step %d: %w", i, err)
return
}
// Detach logits and cache arrays to break the computation graph.
// Without this, each step's logits holds shared_ptrs through the
// entire forward pass (SDPA → Slice → cache), pinning hundreds of
// Metal buffers per step that accumulate to tens of GB.
Detach(logits)
for _, c := range caches {
c.Detach()
}
}
}
}
@ -264,9 +280,11 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
defer freeCaches(caches)
// Single prefill pass — populates KV caches for all layers.
input := FromValues(tokens, len(tokens))
input = Reshape(input, 1, int32(len(tokens)))
vInput := FromValues(tokens, len(tokens))
input := Reshape(vInput, 1, int32(len(tokens)))
Free(vInput)
logits := m.model.Forward(input, caches)
Free(input)
if err := Eval(logits); err != nil {
return nil, fmt.Errorf("prefill: %w", err)
}

View file

@ -91,11 +91,20 @@ func (l *LoRALinear) Forward(x *Array) *Array {
baseOut := l.Base.baseForward(x)
// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
loraOut := Matmul(x, Transpose(l.A))
loraOut = Matmul(loraOut, Transpose(l.B))
loraOut = MulScalar(loraOut, l.Scale)
ta := Transpose(l.A)
loraOut := Matmul(x, ta)
Free(ta)
return Add(baseOut, loraOut)
tb := Transpose(l.B)
loraOut2 := Matmul(loraOut, tb)
Free(loraOut, tb)
loraOut3 := MulScalar(loraOut2, l.Scale)
Free(loraOut2)
res := Add(baseOut, loraOut3)
Free(baseOut, loraOut3)
return res
}
// TrainableParams returns the LoRA A and B arrays for gradient computation.

View file

@ -50,10 +50,14 @@ func (l *Linear) baseForward(x *Array) *Array {
if l.Scales != nil {
out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
} else {
out = Matmul(x, Transpose(l.Weight))
wT := Transpose(l.Weight)
out = Matmul(x, wT)
Free(wT)
}
if l.Bias != nil && l.Bias.Valid() {
oldOut := out
out = Add(out, l.Bias)
Free(oldOut)
}
return out
}
@ -72,7 +76,9 @@ type Embedding struct {
func (e *Embedding) Forward(indices *Array) *Array {
if e.Scales != nil {
w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
return Take(w, indices, 0)
res := Take(w, indices, 0)
Free(w)
return res
}
return Take(e.Weight, indices, 0)
}
@ -110,6 +116,10 @@ func RepeatKV(x *Array, factor int32) *Array {
// Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D]
expanded := ExpandDims(x, 2)
expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D})
return Reshape(expanded, B, H*factor, L, D)
broadcasted := BroadcastTo(expanded, []int32{B, H, factor, L, D})
Free(expanded)
res := Reshape(broadcasted, B, H*factor, L, D)
Free(broadcasted)
return res
}

View file

@ -63,6 +63,15 @@ func Negative(a *Array) *Array {
return out
}
// Copy returns a deep copy of a with no links back into the computation
// graph: the result carries the same data but references no parent graph
// nodes, so Metal memory from prior graph operations can be freed.
func Copy(a *Array) *Array {
	result := newArray("COPY", a)
	C.mlx_copy(&result.ctx, a.ctx, DefaultStream().ctx)
	return result
}
// --- Math functions ---
// Exp returns element-wise exp(a).
@ -81,7 +90,10 @@ func Sigmoid(a *Array) *Array {
// SiLU returns element-wise x * sigmoid(x) (Swish activation).
// The intermediate sigmoid array is freed eagerly so it does not pin Metal
// memory once the product has been formed. (This rendered diff retained the
// superseded `return Mul(a, Sigmoid(a))` line, which leaked the Sigmoid
// intermediate and left unreachable code; the merged body is restored.)
func SiLU(a *Array) *Array {
	s := Sigmoid(a)
	res := Mul(a, s)
	Free(s)
	return res
}
// Tanh returns element-wise tanh(a).

View file

@ -261,74 +261,114 @@ func (m *Qwen3Model) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg)
Free(h)
h = hNext
}
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
normed := m.Norm.Forward(h, m.Cfg.RMSNormEps)
out := m.Output.Forward(normed)
Free(h, normed)
return out
}
// forward runs one Qwen 3 decoder layer: pre-norm attention with residual
// add, then pre-norm MLP with residual add. Intermediates are freed as soon
// as their last consumer has run so graph references do not pin Metal memory.
// (This rendered diff interleaved the pre-change lines with the new ones;
// the merged post-change body is restored here.)
func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
	// Pre-attention norm → attention → residual add
	normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
	attnOut := l.Attention.forward(normed, c, B, L, mask, cfg)
	Free(normed)
	h := Add(x, attnOut)
	Free(attnOut)
	// Pre-MLP norm → MLP → residual add
	normed2 := l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
	mlpOut := l.MLP.forward(normed2)
	Free(normed2)
	result := Add(h, mlpOut)
	Free(h, mlpOut)
	return result
}
func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
qProj := a.QProj.Forward(x)
kProj := a.KProj.Forward(x)
vProj := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
// Reshape to [B, num_heads, L, head_dim] via stride manipulation.
// AsStrided creates a view (C refcount keeps source alive), so Free source after.
q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
Free(qProj)
k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
Free(kProj)
v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
Free(vProj)
// Q/K RMS normalization (Qwen 3 has this; Qwen 2 does not)
if a.QNorm != nil && a.QNorm.Weight != nil {
oldQ := q
q = a.QNorm.Forward(q, cfg.RMSNormEps)
Free(oldQ)
}
if a.KNorm != nil && a.KNorm.Weight != nil {
oldK := k
k = a.KNorm.Forward(k, cfg.RMSNormEps)
Free(oldK)
}
// RoPE — single theta for all layers (no sliding window)
oldQ := q
q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
Free(oldQ)
oldK := k
k = RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
Free(oldK)
// Update KV cache
// Update KV cache — returns Slice views into cache buffer; free our pre-update handles.
oldK, oldV := k, v
k, v = c.Update(k, v, int(L))
Free(oldK, oldV)
// GQA: repeat K/V heads to match Q heads
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
kAttn, vAttn := k, v
if repeatFactor > 1 {
k = RepeatKV(k, repeatFactor)
v = RepeatKV(v, repeatFactor)
kAttn = RepeatKV(k, repeatFactor)
vAttn = RepeatKV(v, repeatFactor)
Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies
}
// Scaled dot-product attention
var out *Array
if mask != nil {
out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale)
} else {
out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1)
}
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views
transposed := Transpose(out, 0, 2, 1, 3)
Free(out)
reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim)
Free(transposed)
result := a.OProj.Forward(reshaped)
Free(reshaped)
return result
}
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
// Each intermediate is freed as soon as its last consumer has run.
// (This rendered diff retained the superseded two-line body alongside the
// new lines; the merged post-change body is restored here.)
func (m *Qwen3MLP) forward(x *Array) *Array {
	gateProj := m.GateProj.Forward(x)
	gate := SiLU(gateProj)
	Free(gateProj)
	upProj := m.UpProj.Forward(x)
	activated := Mul(gate, upProj)
	Free(gate, upProj)
	result := m.DownProj.Forward(activated)
	Free(activated)
	return result
}
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.