diff --git a/dist/include/metal_cpp/README.md b/dist/include/metal_cpp/README.md
index 52ae7b5..03d628c 100644
--- a/dist/include/metal_cpp/README.md
+++ b/dist/include/metal_cpp/README.md
@@ -22,7 +22,7 @@
| macOS 15, iOS 18 | Add all the Metal APIs in macOS 15 and iOS 18. |
| macOS 14, iOS 17 | Add support for the **MetalFX** framework.
Add all the APIs in macOS 14 and iOS 17. |
| macOS 13.3, iOS 16.4 | Add all the APIs in macOS 13.3 and iOS 16.4. |
-| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.
newArray optional `NS::SharedPtr` type to assist with memory management.
newArray convenience function to create a `CA::MetalLayer`.
newArray `MTLSTR(str)` macro allows faster string creation from literals.
Fix a problem with the signature of functions that take an array of pointers as input.
Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.|
+| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.
New optional `NS::SharedPtr` type to assist with memory management.
New convenience function to create a `CA::MetalLayer`.
New `MTLSTR(str)` macro allows faster string creation from literals.
Fix a problem with the signature of functions that take an array of pointers as input.
Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.|
| macOS 12, iOS 15 | Initial release. |
## Memory Allocation Policy
diff --git a/internal/metal/cache.go b/internal/metal/cache.go
index 9f5f581..d4c7a7c 100644
--- a/internal/metal/cache.go
+++ b/internal/metal/cache.go
@@ -14,6 +14,9 @@ type Cache interface {
State() []*Array
// Reset clears the cache for a new generation session.
Reset()
+ // Detach replaces internal K/V arrays with copies that have no graph parents.
+ // Call after Eval to allow Metal memory from prior graph operations to be freed.
+ Detach()
}
// KVCache implements an unbounded cache that grows as needed.
@@ -90,6 +93,13 @@ func (c *KVCache) Reset() {
c.offset = 0
}
+func (c *KVCache) Detach() {
+ if c.keys == nil {
+ return
+ }
+ Detach(c.keys, c.values)
+}
+
// RotatingKVCache implements a bounded sliding window cache.
type RotatingKVCache struct {
keys, values *Array
@@ -190,7 +200,10 @@ func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array)
}
c.idx = int(c.keys.Shape()[2])
- return c.keys, c.values
+ // Return Slice views so callers can Free them without destroying the cache.
+ // (updateInPlace and KVCache.Update already return Slice views.)
+ return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dk}),
+ Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dv})
}
func (c *RotatingKVCache) State() []*Array {
@@ -209,3 +222,10 @@ func (c *RotatingKVCache) Reset() {
c.offset = 0
c.idx = 0
}
+
+func (c *RotatingKVCache) Detach() {
+ if c.keys == nil {
+ return
+ }
+ Detach(c.keys, c.values)
+}
diff --git a/internal/metal/copy_test.go b/internal/metal/copy_test.go
new file mode 100644
index 0000000..3fdb5d1
--- /dev/null
+++ b/internal/metal/copy_test.go
@@ -0,0 +1,25 @@
+//go:build darwin && arm64
+
+package metal
+
+import "testing"
+
+func TestCopy_BreaksGraph(t *testing.T) {
+ // Create a chain: a -> b -> c
+ a := FromValue(float32(1.0))
+ b := Add(a, FromValue(float32(2.0)))
+ Eval(b)
+
+ // Copy should break the graph
+ c := Copy(b)
+ Eval(c)
+
+ // Free b — if Copy truly detaches, c should still be valid
+ Free(b)
+
+ val := c.Float()
+ if val != 3.0 {
+ t.Fatalf("expected 3.0, got %f", val)
+ }
+ Free(a, c)
+}
diff --git a/internal/metal/detach.cpp b/internal/metal/detach.cpp
new file mode 100644
index 0000000..8223d0f
--- /dev/null
+++ b/internal/metal/detach.cpp
@@ -0,0 +1,8 @@
+#include "mlx/mlx.h"
+#include "mlx/c/array.h"
+
+extern "C" void mlx_array_detach_impl(mlx_array arr) {
+ if (arr.ctx) {
+    static_cast<mlx::core::array *>(arr.ctx)->detach();
+ }
+}
diff --git a/internal/metal/detach.go b/internal/metal/detach.go
new file mode 100644
index 0000000..470bc70
--- /dev/null
+++ b/internal/metal/detach.go
@@ -0,0 +1,22 @@
+//go:build darwin && arm64
+
+package metal
+
+/*
+#include "mlx/c/array.h"
+
+// mlx_array_detach breaks an evaluated array's graph connections.
+// ctx is a mlx::core::array* — we call detach() via a C++ helper.
+void mlx_array_detach_impl(mlx_array arr);
+*/
+import "C"
+
+// Detach breaks an array's graph connections after evaluation.
+// This allows Metal memory from parent operations to be freed.
+func Detach(arrays ...*Array) {
+ for _, a := range arrays {
+ if a != nil && a.ctx.ctx != nil {
+ C.mlx_array_detach_impl(a.ctx)
+ }
+ }
+}
diff --git a/internal/metal/gemma3.go b/internal/metal/gemma3.go
index f37b14e..326e157 100644
--- a/internal/metal/gemma3.go
+++ b/internal/metal/gemma3.go
@@ -330,78 +330,121 @@ func (m *GemmaModel) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *
B, L := shape[0], shape[1]
h := m.EmbedTokens.Forward(tokens)
- h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
+ s := float32(math.Sqrt(float64(m.Cfg.HiddenSize)))
+ h2 := MulScalar(h, s)
+ Free(h)
+ h = h2
for i, layer := range m.Layers {
- h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
+ hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg)
+ Free(h)
+ h = hNext
}
- return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
+ normed := RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)
+ out := m.Output.Forward(normed)
+ Free(h, normed)
+ return out
}
func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *TextConfig) *Array {
normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, mask, cfg)
- attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
- h := Add(x, attnOut)
+ Free(normed)
+ attnOutNormed := RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
+ Free(attnOut)
+ h := Add(x, attnOutNormed)
+ Free(attnOutNormed)
- normed = RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
- mlpOut := l.MLP.forward(normed)
- mlpOut = RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
- return Add(h, mlpOut)
+ normed2 := RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
+ mlpOut := l.MLP.forward(normed2)
+ Free(normed2)
+ mlpOutNormed := RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
+ Free(mlpOut)
+ result := Add(h, mlpOutNormed)
+ Free(h, mlpOutNormed)
+ return result
}
func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, mask *Array, cfg *TextConfig) *Array {
- q := a.QProj.Forward(x)
- k := a.KProj.Forward(x)
- v := a.VProj.Forward(x)
+ qProj := a.QProj.Forward(x)
+ kProj := a.KProj.Forward(x)
+ vProj := a.VProj.Forward(x)
// Virtual transpose [B,L,H*D] → [B,H,L,D] via stride manipulation.
- // Strides: batch = L*H*D (full sequence), head = D (adjacent heads in memory),
- // seq = H*D (jump one full row of heads), elem = 1 (contiguous within head).
- q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
+ // AsStrided creates a view (C refcount keeps source alive), so Free source after.
+ q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
- k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
+ Free(qProj)
+ k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
- v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
+ Free(kProj)
+ v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
+ Free(vProj)
// Q/K normalization
+ oldQ := q
q = RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
+ Free(oldQ)
+ oldK := k
k = RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
+ Free(oldK)
// RoPE with appropriate theta
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
+ oldQ = q
q = RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
+ Free(oldQ)
+ oldK = k
k = RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
+ Free(oldK)
- // Update cache
+ // Update cache — returns Slice views into cache buffer; free our pre-update handles.
+ oldK, oldV := k, v
k, v = c.Update(k, v, int(L))
+ Free(oldK, oldV)
// GQA: repeat K/V heads
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
+ kAttn, vAttn := k, v
if repeatFactor > 1 {
- k = RepeatKV(k, repeatFactor)
- v = RepeatKV(v, repeatFactor)
+ kAttn = RepeatKV(k, repeatFactor)
+ vAttn = RepeatKV(v, repeatFactor)
+ Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies
}
// Scaled dot-product attention
var out *Array
if mask != nil {
- out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
+ out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale)
} else {
- out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
+ out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1)
}
- out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
- return a.OProj.Forward(out)
+ Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views
+
+ transposed := Transpose(out, 0, 2, 1, 3)
+ Free(out)
+ reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim)
+ Free(transposed)
+ result := a.OProj.Forward(reshaped)
+ Free(reshaped)
+ return result
}
func (m *MLP) forward(x *Array) *Array {
- gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
- return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
+ gateProj := m.GateProj.Forward(x)
+ gate := getCompiledGELU().Call(gateProj)[0]
+ Free(gateProj)
+ upProj := m.UpProj.Forward(x)
+ activated := Mul(gate, upProj)
+ Free(gate, upProj)
+ result := m.DownProj.Forward(activated)
+ Free(activated)
+ return result
}
// NewCache creates per-layer caches for generation.
diff --git a/internal/metal/generate.go b/internal/metal/generate.go
index 066cc64..44b218b 100644
--- a/internal/metal/generate.go
+++ b/internal/metal/generate.go
@@ -177,6 +177,13 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
m.lastErr = fmt.Errorf("prefill: %w", err)
return
}
+ // Detach logits and cache arrays to release the entire prefill computation
+ // graph. After Eval, data is materialised — graph connections only pin Metal
+ // memory from intermediate tensors (34 layers × ~20 ops each).
+ Detach(logits)
+ for _, c := range caches {
+ c.Detach()
+ }
prefillDur = time.Since(prefillStart)
// Track generated token IDs for repeat penalty.
@@ -248,6 +255,15 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
m.lastErr = fmt.Errorf("decode step %d: %w", i, err)
return
}
+
+ // Detach logits and cache arrays to break the computation graph.
+ // Without this, each step's logits holds shared_ptrs through the
+ // entire forward pass (SDPA → Slice → cache), pinning hundreds of
+ // Metal buffers per step that accumulate to tens of GB.
+ Detach(logits)
+ for _, c := range caches {
+ c.Detach()
+ }
}
}
}
@@ -264,9 +280,11 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
defer freeCaches(caches)
// Single prefill pass — populates KV caches for all layers.
- input := FromValues(tokens, len(tokens))
- input = Reshape(input, 1, int32(len(tokens)))
+ vInput := FromValues(tokens, len(tokens))
+ input := Reshape(vInput, 1, int32(len(tokens)))
+ Free(vInput)
logits := m.model.Forward(input, caches)
+ Free(input)
if err := Eval(logits); err != nil {
return nil, fmt.Errorf("prefill: %w", err)
}
diff --git a/internal/metal/lora.go b/internal/metal/lora.go
index 2bcb661..0fee1d2 100644
--- a/internal/metal/lora.go
+++ b/internal/metal/lora.go
@@ -91,11 +91,20 @@ func (l *LoRALinear) Forward(x *Array) *Array {
baseOut := l.Base.baseForward(x)
// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
- loraOut := Matmul(x, Transpose(l.A))
- loraOut = Matmul(loraOut, Transpose(l.B))
- loraOut = MulScalar(loraOut, l.Scale)
+ ta := Transpose(l.A)
+ loraOut := Matmul(x, ta)
+ Free(ta)
- return Add(baseOut, loraOut)
+ tb := Transpose(l.B)
+ loraOut2 := Matmul(loraOut, tb)
+ Free(loraOut, tb)
+
+ loraOut3 := MulScalar(loraOut2, l.Scale)
+ Free(loraOut2)
+
+ res := Add(baseOut, loraOut3)
+ Free(baseOut, loraOut3)
+ return res
}
// TrainableParams returns the LoRA A and B arrays for gradient computation.
diff --git a/internal/metal/nn.go b/internal/metal/nn.go
index a29a75d..41b52ba 100644
--- a/internal/metal/nn.go
+++ b/internal/metal/nn.go
@@ -50,10 +50,14 @@ func (l *Linear) baseForward(x *Array) *Array {
if l.Scales != nil {
out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
} else {
- out = Matmul(x, Transpose(l.Weight))
+ wT := Transpose(l.Weight)
+ out = Matmul(x, wT)
+ Free(wT)
}
if l.Bias != nil && l.Bias.Valid() {
+ oldOut := out
out = Add(out, l.Bias)
+ Free(oldOut)
}
return out
}
@@ -72,7 +76,9 @@ type Embedding struct {
func (e *Embedding) Forward(indices *Array) *Array {
if e.Scales != nil {
w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
- return Take(w, indices, 0)
+ res := Take(w, indices, 0)
+ Free(w)
+ return res
}
return Take(e.Weight, indices, 0)
}
@@ -110,6 +116,10 @@ func RepeatKV(x *Array, factor int32) *Array {
// Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D]
expanded := ExpandDims(x, 2)
- expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D})
- return Reshape(expanded, B, H*factor, L, D)
+ broadcasted := BroadcastTo(expanded, []int32{B, H, factor, L, D})
+ Free(expanded)
+
+ res := Reshape(broadcasted, B, H*factor, L, D)
+ Free(broadcasted)
+ return res
}
diff --git a/internal/metal/ops.go b/internal/metal/ops.go
index 53795d2..0723988 100644
--- a/internal/metal/ops.go
+++ b/internal/metal/ops.go
@@ -63,6 +63,15 @@ func Negative(a *Array) *Array {
return out
}
+// Copy creates a deep copy of an array, breaking the computation graph chain.
+// The returned array has the same data but no references to parent graph nodes,
+// allowing Metal memory from prior graph operations to be freed.
+func Copy(a *Array) *Array {
+ out := newArray("COPY", a)
+ C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
+ return out
+}
+
// --- Math functions ---
// Exp returns element-wise exp(a).
@@ -81,7 +90,10 @@ func Sigmoid(a *Array) *Array {
// SiLU returns element-wise x * sigmoid(x) (Swish activation).
func SiLU(a *Array) *Array {
- return Mul(a, Sigmoid(a))
+ s := Sigmoid(a)
+ res := Mul(a, s)
+ Free(s)
+ return res
}
// Tanh returns element-wise tanh(a).
diff --git a/internal/metal/qwen3.go b/internal/metal/qwen3.go
index f8daef9..f7a7abf 100644
--- a/internal/metal/qwen3.go
+++ b/internal/metal/qwen3.go
@@ -261,74 +261,114 @@ func (m *Qwen3Model) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
- h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
+ hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg)
+ Free(h)
+ h = hNext
}
- return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
+ normed := m.Norm.Forward(h, m.Cfg.RMSNormEps)
+ out := m.Output.Forward(normed)
+ Free(h, normed)
+ return out
}
func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
// Pre-attention norm → attention → residual add
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
attnOut := l.Attention.forward(normed, c, B, L, mask, cfg)
+ Free(normed)
h := Add(x, attnOut)
+ Free(attnOut)
// Pre-MLP norm → MLP → residual add
- normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
- mlpOut := l.MLP.forward(normed)
- return Add(h, mlpOut)
+ normed2 := l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
+ mlpOut := l.MLP.forward(normed2)
+ Free(normed2)
+ result := Add(h, mlpOut)
+ Free(h, mlpOut)
+ return result
}
func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
- q := a.QProj.Forward(x)
- k := a.KProj.Forward(x)
- v := a.VProj.Forward(x)
+ qProj := a.QProj.Forward(x)
+ kProj := a.KProj.Forward(x)
+ vProj := a.VProj.Forward(x)
- // Reshape to [B, num_heads, L, head_dim]
- q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
+ // Reshape to [B, num_heads, L, head_dim] via stride manipulation.
+ // AsStrided creates a view (C refcount keeps source alive), so Free source after.
+ q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
- k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
+ Free(qProj)
+ k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
- v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
+ Free(kProj)
+ v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
+ Free(vProj)
// Q/K RMS normalization (Qwen 3 has this; Qwen 2 does not)
if a.QNorm != nil && a.QNorm.Weight != nil {
+ oldQ := q
q = a.QNorm.Forward(q, cfg.RMSNormEps)
+ Free(oldQ)
}
if a.KNorm != nil && a.KNorm.Weight != nil {
+ oldK := k
k = a.KNorm.Forward(k, cfg.RMSNormEps)
+ Free(oldK)
}
// RoPE — single theta for all layers (no sliding window)
+ oldQ := q
q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
+ Free(oldQ)
+ oldK := k
k = RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
+ Free(oldK)
- // Update KV cache
+ // Update KV cache — returns Slice views into cache buffer; free our pre-update handles.
+ oldK, oldV := k, v
k, v = c.Update(k, v, int(L))
+ Free(oldK, oldV)
// GQA: repeat K/V heads to match Q heads
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
+ kAttn, vAttn := k, v
if repeatFactor > 1 {
- k = RepeatKV(k, repeatFactor)
- v = RepeatKV(v, repeatFactor)
+ kAttn = RepeatKV(k, repeatFactor)
+ vAttn = RepeatKV(v, repeatFactor)
+ Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies
}
// Scaled dot-product attention
var out *Array
if mask != nil {
- out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
+ out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale)
} else {
- out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
+ out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1)
}
- out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
- return a.OProj.Forward(out)
+ Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views
+
+ transposed := Transpose(out, 0, 2, 1, 3)
+ Free(out)
+ reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim)
+ Free(transposed)
+ result := a.OProj.Forward(reshaped)
+ Free(reshaped)
+ return result
}
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
func (m *Qwen3MLP) forward(x *Array) *Array {
- gate := SiLU(m.GateProj.Forward(x))
- return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
+ gateProj := m.GateProj.Forward(x)
+ gate := SiLU(gateProj)
+ Free(gateProj)
+ upProj := m.UpProj.Forward(x)
+ activated := Mul(gate, upProj)
+ Free(gate, upProj)
+ result := m.DownProj.Forward(activated)
+ Free(activated)
+ return result
}
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.