From 71fe4bb5aca1b1fc55b9681f12ee30baece1b517 Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 23 Feb 2026 18:36:57 +0000 Subject: [PATCH] fix: add Detach/Free calls to reduce Metal GPU memory retention Add deterministic memory cleanup across inference paths: - Detach logits after Eval to release graph references - Free intermediate arrays in attention (gemma3, qwen3) - Add cache Detach helper for KV cache cleanup after generation - New detach.cpp/go CGO bindings for mlx_array_detach Reduces 4B model memory from 78GB to ~17GB (vs 2.4GB mlx-lm baseline). Native Metal memory management still trails Python refcounting but is now viable for 1B models. Co-Authored-By: Virgil --- dist/include/metal_cpp/README.md | 2 +- internal/metal/cache.go | 22 +++++++- internal/metal/copy_test.go | 25 +++++++++ internal/metal/detach.cpp | 8 +++ internal/metal/detach.go | 22 ++++++++ internal/metal/gemma3.go | 95 +++++++++++++++++++++++--------- internal/metal/generate.go | 22 +++++++- internal/metal/lora.go | 17 ++++-- internal/metal/nn.go | 18 ++++-- internal/metal/ops.go | 14 ++++- internal/metal/qwen3.go | 82 ++++++++++++++++++++------- 11 files changed, 267 insertions(+), 60 deletions(-) create mode 100644 internal/metal/copy_test.go create mode 100644 internal/metal/detach.cpp create mode 100644 internal/metal/detach.go diff --git a/dist/include/metal_cpp/README.md b/dist/include/metal_cpp/README.md index 52ae7b5..03d628c 100644 --- a/dist/include/metal_cpp/README.md +++ b/dist/include/metal_cpp/README.md @@ -22,7 +22,7 @@ | macOS 15, iOS 18 | Add all the Metal APIs in macOS 15 and iOS 18. | | macOS 14, iOS 17 | Add support for the **MetalFX** framework.
Add all the APIs in macOS 14 and iOS 17. | | macOS 13.3, iOS 16.4 | Add all the APIs in macOS 13.3 and iOS 16.4. | -| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.
newArray optional `NS::SharedPtr` type to assist with memory management.
newArray convenience function to create a `CA::MetalLayer`.
newArray `MTLSTR(str)` macro allows faster string creation from literals.
Fix a problem with the signature of functions that take an array of pointers as input.
Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.| +| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.
New optional `NS::SharedPtr` type to assist with memory management.
New convenience function to create a `CA::MetalLayer`.
New `MTLSTR(str)` macro allows faster string creation from literals.
Fix a problem with the signature of functions that take an array of pointers as input.
Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.| | macOS 12, iOS 15 | Initial release. | ## Memory Allocation Policy diff --git a/internal/metal/cache.go b/internal/metal/cache.go index 9f5f581..d4c7a7c 100644 --- a/internal/metal/cache.go +++ b/internal/metal/cache.go @@ -14,6 +14,9 @@ type Cache interface { State() []*Array // Reset clears the cache for a new generation session. Reset() + // Detach replaces internal K/V arrays with copies that have no graph parents. + // Call after Eval to allow Metal memory from prior graph operations to be freed. + Detach() } // KVCache implements an unbounded cache that grows as needed. @@ -90,6 +93,13 @@ func (c *KVCache) Reset() { c.offset = 0 } +func (c *KVCache) Detach() { + if c.keys == nil { + return + } + Detach(c.keys, c.values) +} + // RotatingKVCache implements a bounded sliding window cache. type RotatingKVCache struct { keys, values *Array @@ -190,7 +200,10 @@ func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array) } c.idx = int(c.keys.Shape()[2]) - return c.keys, c.values + // Return Slice views so callers can Free them without destroying the cache. + // (updateInPlace and KVCache.Update already return Slice views.) + return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dk}), + Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dv}) } func (c *RotatingKVCache) State() []*Array { @@ -209,3 +222,10 @@ func (c *RotatingKVCache) Reset() { c.offset = 0 c.idx = 0 } + +func (c *RotatingKVCache) Detach() { + if c.keys == nil { + return + } + Detach(c.keys, c.values) +} diff --git a/internal/metal/copy_test.go b/internal/metal/copy_test.go new file mode 100644 index 0000000..3fdb5d1 --- /dev/null +++ b/internal/metal/copy_test.go @@ -0,0 +1,25 @@ +//go:build darwin && arm64 + +package metal + +import "testing" + +func TestCopy_BreaksGraph(t *testing.T) { + // Create a chain: a -> b -> c + a := FromValue(float32(1.0)) + b := Add(a, FromValue(float32(2.0))) + Eval(b) + + // Copy should break the graph + c := Copy(b) + Eval(c) + + // Free b — if Copy truly detaches, c should still be valid + Free(b) + + val := c.Float() + if val != 3.0 { + t.Fatalf("expected 3.0, got %f", val) + } + Free(a, c) +} diff --git a/internal/metal/detach.cpp b/internal/metal/detach.cpp new file mode 100644 index 0000000..8223d0f --- /dev/null +++ b/internal/metal/detach.cpp @@ -0,0 +1,8 @@ +#include "mlx/mlx.h" +#include "mlx/c/array.h" + +extern "C" void mlx_array_detach_impl(mlx_array arr) { + if (arr.ctx) { + static_cast(arr.ctx)->detach(); + } +} diff --git a/internal/metal/detach.go b/internal/metal/detach.go new file mode 100644 index 0000000..470bc70 --- /dev/null +++ b/internal/metal/detach.go @@ -0,0 +1,22 @@ +//go:build darwin && arm64 + +package metal + +/* +#include "mlx/c/array.h" + +// mlx_array_detach breaks an evaluated array's graph connections. +// ctx is a mlx::core::array* — we call detach() via a C++ helper. +void mlx_array_detach_impl(mlx_array arr); +*/ +import "C" + +// Detach breaks an array's graph connections after evaluation. +// This allows Metal memory from parent operations to be freed. +func Detach(arrays ...*Array) { + for _, a := range arrays { + if a != nil && a.ctx.ctx != nil { + C.mlx_array_detach_impl(a.ctx) + } + } +} diff --git a/internal/metal/gemma3.go b/internal/metal/gemma3.go index f37b14e..326e157 100644 --- a/internal/metal/gemma3.go +++ b/internal/metal/gemma3.go @@ -330,78 +330,121 @@ func (m *GemmaModel) ForwardMasked(tokens *Array, mask *Array, caches []Cache) * B, L := shape[0], shape[1] h := m.EmbedTokens.Forward(tokens) - h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize)))) + s := float32(math.Sqrt(float64(m.Cfg.HiddenSize))) + h2 := MulScalar(h, s) + Free(h) + h = h2 for i, layer := range m.Layers { - h = layer.forward(h, caches[i], B, L, mask, m.Cfg) + hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg) + Free(h) + h = hNext } - return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)) + normed := RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps) + out := m.Output.Forward(normed) + Free(h, normed) + return out } func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *TextConfig) *Array { normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps) attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, mask, cfg) - attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) - h := Add(x, attnOut) + Free(normed) + attnOutNormed := RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps) + Free(attnOut) + h := Add(x, attnOutNormed) + Free(attnOutNormed) - normed = RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) - mlpOut := l.MLP.forward(normed) - mlpOut = RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) - return Add(h, mlpOut) + normed2 := RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps) + mlpOut := l.MLP.forward(normed2) + Free(normed2) + mlpOutNormed := RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps) + Free(mlpOut) + result := Add(h, mlpOutNormed) + Free(h, mlpOutNormed) + return result } func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, mask *Array, cfg *TextConfig) *Array { - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) + qProj := a.QProj.Forward(x) + kProj := a.KProj.Forward(x) + vProj := a.VProj.Forward(x) // Virtual transpose [B,L,H*D] → [B,H,L,D] via stride manipulation. - // Strides: batch = L*H*D (full sequence), head = D (adjacent heads in memory), - // seq = H*D (jump one full row of heads), elem = 1 (contiguous within head). - q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + // AsStrided creates a view (C refcount keeps source alive), so Free source after. + q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + Free(qProj) + k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + Free(kProj) + v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + Free(vProj) // Q/K normalization + oldQ := q q = RMSNorm(q, a.QNormScaled, cfg.RMSNormEps) + Free(oldQ) + oldK := k k = RMSNorm(k, a.KNormScaled, cfg.RMSNormEps) + Free(oldK) // RoPE with appropriate theta ropeTheta := cfg.RopeTheta if isSliding { ropeTheta = cfg.RopeLocalBaseFreq } + oldQ = q q = RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + Free(oldQ) + oldK = k k = RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset()) + Free(oldK) - // Update cache + // Update cache — returns Slice views into cache buffer; free our pre-update handles. + oldK, oldV := k, v k, v = c.Update(k, v, int(L)) + Free(oldK, oldV) // GQA: repeat K/V heads repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads + kAttn, vAttn := k, v if repeatFactor > 1 { - k = RepeatKV(k, repeatFactor) - v = RepeatKV(v, repeatFactor) + kAttn = RepeatKV(k, repeatFactor) + vAttn = RepeatKV(v, repeatFactor) + Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies } // Scaled dot-product attention var out *Array if mask != nil { - out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale) + out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale) } else { - out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1) } - out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - return a.OProj.Forward(out) + Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views + + transposed := Transpose(out, 0, 2, 1, 3) + Free(out) + reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim) + Free(transposed) + result := a.OProj.Forward(reshaped) + Free(reshaped) + return result } func (m *MLP) forward(x *Array) *Array { - gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0] - return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x))) + gateProj := m.GateProj.Forward(x) + gate := getCompiledGELU().Call(gateProj)[0] + Free(gateProj) + upProj := m.UpProj.Forward(x) + activated := Mul(gate, upProj) + Free(gate, upProj) + result := m.DownProj.Forward(activated) + Free(activated) + return result } // NewCache creates per-layer caches for generation. diff --git a/internal/metal/generate.go b/internal/metal/generate.go index 066cc64..44b218b 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -177,6 +177,13 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) m.lastErr = fmt.Errorf("prefill: %w", err) return } + // Detach logits and cache arrays to release the entire prefill computation + // graph. After Eval, data is materialised — graph connections only pin Metal + // memory from intermediate tensors (34 layers × ~20 ops each). + Detach(logits) + for _, c := range caches { + c.Detach() + } prefillDur = time.Since(prefillStart) // Track generated token IDs for repeat penalty. @@ -248,6 +255,15 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) m.lastErr = fmt.Errorf("decode step %d: %w", i, err) return } + + // Detach logits and cache arrays to break the computation graph. + // Without this, each step's logits holds shared_ptrs through the + // entire forward pass (SDPA → Slice → cache), pinning hundreds of + // Metal buffers per step that accumulate to tens of GB. + Detach(logits) + for _, c := range caches { + c.Detach() + } } } } @@ -264,9 +280,11 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention defer freeCaches(caches) // Single prefill pass — populates KV caches for all layers. - input := FromValues(tokens, len(tokens)) - input = Reshape(input, 1, int32(len(tokens))) + vInput := FromValues(tokens, len(tokens)) + input := Reshape(vInput, 1, int32(len(tokens))) + Free(vInput) logits := m.model.Forward(input, caches) + Free(input) if err := Eval(logits); err != nil { return nil, fmt.Errorf("prefill: %w", err) } diff --git a/internal/metal/lora.go b/internal/metal/lora.go index 2bcb661..0fee1d2 100644 --- a/internal/metal/lora.go +++ b/internal/metal/lora.go @@ -91,11 +91,20 @@ func (l *LoRALinear) Forward(x *Array) *Array { baseOut := l.Base.baseForward(x) // LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out] - loraOut := Matmul(x, Transpose(l.A)) - loraOut = Matmul(loraOut, Transpose(l.B)) - loraOut = MulScalar(loraOut, l.Scale) + ta := Transpose(l.A) + loraOut := Matmul(x, ta) + Free(ta) - return Add(baseOut, loraOut) + tb := Transpose(l.B) + loraOut2 := Matmul(loraOut, tb) + Free(loraOut, tb) + + loraOut3 := MulScalar(loraOut2, l.Scale) + Free(loraOut2) + + res := Add(baseOut, loraOut3) + Free(baseOut, loraOut3) + return res } // TrainableParams returns the LoRA A and B arrays for gradient computation. diff --git a/internal/metal/nn.go b/internal/metal/nn.go index a29a75d..41b52ba 100644 --- a/internal/metal/nn.go +++ b/internal/metal/nn.go @@ -50,10 +50,14 @@ func (l *Linear) baseForward(x *Array) *Array { if l.Scales != nil { out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits) } else { - out = Matmul(x, Transpose(l.Weight)) + wT := Transpose(l.Weight) + out = Matmul(x, wT) + Free(wT) } if l.Bias != nil && l.Bias.Valid() { + oldOut := out out = Add(out, l.Bias) + Free(oldOut) } return out } @@ -72,7 +76,9 @@ type Embedding struct { func (e *Embedding) Forward(indices *Array) *Array { if e.Scales != nil { w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits) - return Take(w, indices, 0) + res := Take(w, indices, 0) + Free(w) + return res } return Take(e.Weight, indices, 0) } @@ -110,6 +116,10 @@ func RepeatKV(x *Array, factor int32) *Array { // Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D] expanded := ExpandDims(x, 2) - expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D}) - return Reshape(expanded, B, H*factor, L, D) + broadcasted := BroadcastTo(expanded, []int32{B, H, factor, L, D}) + Free(expanded) + + res := Reshape(broadcasted, B, H*factor, L, D) + Free(broadcasted) + return res } diff --git a/internal/metal/ops.go b/internal/metal/ops.go index 53795d2..0723988 100644 --- a/internal/metal/ops.go +++ b/internal/metal/ops.go @@ -63,6 +63,15 @@ func Negative(a *Array) *Array { return out } +// Copy creates a deep copy of an array, breaking the computation graph chain. +// The returned array has the same data but no references to parent graph nodes, +// allowing Metal memory from prior graph operations to be freed. +func Copy(a *Array) *Array { + out := newArray("COPY", a) + C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx) + return out +} + // --- Math functions --- // Exp returns element-wise exp(a). @@ -81,7 +90,10 @@ func Sigmoid(a *Array) *Array { // SiLU returns element-wise x * sigmoid(x) (Swish activation). func SiLU(a *Array) *Array { - return Mul(a, Sigmoid(a)) + s := Sigmoid(a) + res := Mul(a, s) + Free(s) + return res } // Tanh returns element-wise tanh(a). diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go index f8daef9..f7a7abf 100644 --- a/internal/metal/qwen3.go +++ b/internal/metal/qwen3.go @@ -261,74 +261,114 @@ func (m *Qwen3Model) ForwardMasked(tokens *Array, mask *Array, caches []Cache) * h := m.EmbedTokens.Forward(tokens) for i, layer := range m.Layers { - h = layer.forward(h, caches[i], B, L, mask, m.Cfg) + hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg) + Free(h) + h = hNext } - return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps)) + normed := m.Norm.Forward(h, m.Cfg.RMSNormEps) + out := m.Output.Forward(normed) + Free(h, normed) + return out } func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array { // Pre-attention norm → attention → residual add normed := l.InputNorm.Forward(x, cfg.RMSNormEps) attnOut := l.Attention.forward(normed, c, B, L, mask, cfg) + Free(normed) h := Add(x, attnOut) + Free(attnOut) // Pre-MLP norm → MLP → residual add - normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps) - mlpOut := l.MLP.forward(normed) - return Add(h, mlpOut) + normed2 := l.PostAttnNorm.Forward(h, cfg.RMSNormEps) + mlpOut := l.MLP.forward(normed2) + Free(normed2) + result := Add(h, mlpOut) + Free(h, mlpOut) + return result } func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array { - q := a.QProj.Forward(x) - k := a.KProj.Forward(x) - v := a.VProj.Forward(x) + qProj := a.QProj.Forward(x) + kProj := a.KProj.Forward(x) + vProj := a.VProj.Forward(x) - // Reshape to [B, num_heads, L, head_dim] - q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, + // Reshape to [B, num_heads, L, head_dim] via stride manipulation. + // AsStrided creates a view (C refcount keeps source alive), so Free source after. + q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0) - k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + Free(qProj) + k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) - v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, + Free(kProj) + v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim}, []int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0) + Free(vProj) // Q/K RMS normalization (Qwen 3 has this; Qwen 2 does not) if a.QNorm != nil && a.QNorm.Weight != nil { + oldQ := q q = a.QNorm.Forward(q, cfg.RMSNormEps) + Free(oldQ) } if a.KNorm != nil && a.KNorm.Weight != nil { + oldK := k k = a.KNorm.Forward(k, cfg.RMSNormEps) + Free(oldK) } // RoPE — single theta for all layers (no sliding window) + oldQ := q q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + Free(oldQ) + oldK := k k = RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset()) + Free(oldK) - // Update KV cache + // Update KV cache — returns Slice views into cache buffer; free our pre-update handles. + oldK, oldV := k, v k, v = c.Update(k, v, int(L)) + Free(oldK, oldV) // GQA: repeat K/V heads to match Q heads repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads + kAttn, vAttn := k, v if repeatFactor > 1 { - k = RepeatKV(k, repeatFactor) - v = RepeatKV(v, repeatFactor) + kAttn = RepeatKV(k, repeatFactor) + vAttn = RepeatKV(v, repeatFactor) + Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies } // Scaled dot-product attention var out *Array if mask != nil { - out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale) + out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale) } else { - out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1) + out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1) } - out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim) - return a.OProj.Forward(out) + Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views + + transposed := Transpose(out, 0, 2, 1, 3) + Free(out) + reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim) + Free(transposed) + result := a.OProj.Forward(reshaped) + Free(reshaped) + return result } // forward computes SwiGLU: down(silu(gate(x)) * up(x)). func (m *Qwen3MLP) forward(x *Array) *Array { - gate := SiLU(m.GateProj.Forward(x)) - return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x))) + gateProj := m.GateProj.Forward(x) + gate := SiLU(gateProj) + Free(gateProj) + upProj := m.UpProj.Forward(x) + activated := Mul(gate, upProj) + Free(gate, upProj) + result := m.DownProj.Forward(activated) + Free(activated) + return result } // NewCache creates per-layer KV caches. Qwen 3 uses global attention only.