fix: add Detach/Free calls to reduce Metal GPU memory retention
Add deterministic memory cleanup across inference paths: detach logits after Eval to release graph references; free intermediate arrays in attention (gemma3, qwen3); add a cache Detach helper for KV-cache cleanup after generation; add new detach.cpp/detach.go CGO bindings for mlx_array_detach. Reduces 4B-model memory from 78GB to ~17GB (vs the 2.4GB mlx-lm baseline). Native Metal memory management still trails Python refcounting but is now viable for 1B models. Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
c1baeb9254
commit
71fe4bb5ac
11 changed files with 267 additions and 60 deletions
2
dist/include/metal_cpp/README.md
vendored
2
dist/include/metal_cpp/README.md
vendored
|
|
@ -22,7 +22,7 @@
|
|||
| macOS 15, iOS 18 | Add all the Metal APIs in macOS 15 and iOS 18. |
|
||||
| macOS 14, iOS 17 | Add support for the **MetalFX** framework. <br/>Add all the APIs in macOS 14 and iOS 17. |
|
||||
| macOS 13.3, iOS 16.4 | Add all the APIs in macOS 13.3 and iOS 16.4. |
|
||||
| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.<br />newArray optional `NS::SharedPtr<T>` type to assist with memory management.<br/>newArray convenience function to create a `CA::MetalLayer`.<br/>newArray `MTLSTR(str)` macro allows faster string creation from literals.<br/>Fix a problem with the signature of functions that take an array of pointers as input.<br/>Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.|
|
||||
| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.<br />New optional `NS::SharedPtr<T>` type to assist with memory management.<br/>New convenience function to create a `CA::MetalLayer`.<br/>New `MTLSTR(str)` macro allows faster string creation from literals.<br/>Fix a problem with the signature of functions that take an array of pointers as input.<br/>Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.|
|
||||
| macOS 12, iOS 15 | Initial release. |
|
||||
|
||||
## Memory Allocation Policy
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ type Cache interface {
|
|||
State() []*Array
|
||||
// Reset clears the cache for a new generation session.
|
||||
Reset()
|
||||
// Detach replaces internal K/V arrays with copies that have no graph parents.
|
||||
// Call after Eval to allow Metal memory from prior graph operations to be freed.
|
||||
Detach()
|
||||
}
|
||||
|
||||
// KVCache implements an unbounded cache that grows as needed.
|
||||
|
|
@ -90,6 +93,13 @@ func (c *KVCache) Reset() {
|
|||
c.offset = 0
|
||||
}
|
||||
|
||||
func (c *KVCache) Detach() {
|
||||
if c.keys == nil {
|
||||
return
|
||||
}
|
||||
Detach(c.keys, c.values)
|
||||
}
|
||||
|
||||
// RotatingKVCache implements a bounded sliding window cache.
|
||||
type RotatingKVCache struct {
|
||||
keys, values *Array
|
||||
|
|
@ -190,7 +200,10 @@ func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array)
|
|||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
return c.keys, c.values
|
||||
// Return Slice views so callers can Free them without destroying the cache.
|
||||
// (updateInPlace and KVCache.Update already return Slice views.)
|
||||
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dk}),
|
||||
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.idx), Dv})
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) State() []*Array {
|
||||
|
|
@ -209,3 +222,10 @@ func (c *RotatingKVCache) Reset() {
|
|||
c.offset = 0
|
||||
c.idx = 0
|
||||
}
|
||||
|
||||
func (c *RotatingKVCache) Detach() {
|
||||
if c.keys == nil {
|
||||
return
|
||||
}
|
||||
Detach(c.keys, c.values)
|
||||
}
|
||||
|
|
|
|||
25
internal/metal/copy_test.go
Normal file
25
internal/metal/copy_test.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package metal
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCopy_BreaksGraph(t *testing.T) {
|
||||
// Create a chain: a -> b -> c
|
||||
a := FromValue(float32(1.0))
|
||||
b := Add(a, FromValue(float32(2.0)))
|
||||
Eval(b)
|
||||
|
||||
// Copy should break the graph
|
||||
c := Copy(b)
|
||||
Eval(c)
|
||||
|
||||
// Free b — if Copy truly detaches, c should still be valid
|
||||
Free(b)
|
||||
|
||||
val := c.Float()
|
||||
if val != 3.0 {
|
||||
t.Fatalf("expected 3.0, got %f", val)
|
||||
}
|
||||
Free(a, c)
|
||||
}
|
||||
8
internal/metal/detach.cpp
Normal file
8
internal/metal/detach.cpp
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
#include "mlx/mlx.h"
#include "mlx/c/array.h"

// Severs an evaluated array's links to its graph parents. The opaque ctx
// field of the C handle is the underlying mlx::core::array; a null ctx is
// tolerated and treated as a no-op.
extern "C" void mlx_array_detach_impl(mlx_array arr) {
    if (arr.ctx == nullptr) {
        return;
    }
    static_cast<mlx::core::array *>(arr.ctx)->detach();
}
|
||||
22
internal/metal/detach.go
Normal file
22
internal/metal/detach.go
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package metal
|
||||
|
||||
/*
|
||||
#include "mlx/c/array.h"
|
||||
|
||||
// mlx_array_detach breaks an evaluated array's graph connections.
|
||||
// ctx is a mlx::core::array* — we call detach() via a C++ helper.
|
||||
void mlx_array_detach_impl(mlx_array arr);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
// Detach breaks an array's graph connections after evaluation.
|
||||
// This allows Metal memory from parent operations to be freed.
|
||||
func Detach(arrays ...*Array) {
|
||||
for _, a := range arrays {
|
||||
if a != nil && a.ctx.ctx != nil {
|
||||
C.mlx_array_detach_impl(a.ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -330,78 +330,121 @@ func (m *GemmaModel) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *
|
|||
B, L := shape[0], shape[1]
|
||||
|
||||
h := m.EmbedTokens.Forward(tokens)
|
||||
h = MulScalar(h, float32(math.Sqrt(float64(m.Cfg.HiddenSize))))
|
||||
s := float32(math.Sqrt(float64(m.Cfg.HiddenSize)))
|
||||
h2 := MulScalar(h, s)
|
||||
Free(h)
|
||||
h = h2
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
|
||||
hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg)
|
||||
Free(h)
|
||||
h = hNext
|
||||
}
|
||||
|
||||
return m.Output.Forward(RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps))
|
||||
normed := RMSNorm(h, m.NormScaled, m.Cfg.RMSNormEps)
|
||||
out := m.Output.Forward(normed)
|
||||
Free(h, normed)
|
||||
return out
|
||||
}
|
||||
|
||||
func (l *DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *TextConfig) *Array {
|
||||
normed := RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, l.IsSliding, mask, cfg)
|
||||
attnOut = RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
h := Add(x, attnOut)
|
||||
Free(normed)
|
||||
attnOutNormed := RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
|
||||
Free(attnOut)
|
||||
h := Add(x, attnOutNormed)
|
||||
Free(attnOutNormed)
|
||||
|
||||
normed = RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed)
|
||||
mlpOut = RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
return Add(h, mlpOut)
|
||||
normed2 := RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed2)
|
||||
Free(normed2)
|
||||
mlpOutNormed := RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
|
||||
Free(mlpOut)
|
||||
result := Add(h, mlpOutNormed)
|
||||
Free(h, mlpOutNormed)
|
||||
return result
|
||||
}
|
||||
|
||||
func (a *Attention) forward(x *Array, c Cache, B, L int32, isSliding bool, mask *Array, cfg *TextConfig) *Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
qProj := a.QProj.Forward(x)
|
||||
kProj := a.KProj.Forward(x)
|
||||
vProj := a.VProj.Forward(x)
|
||||
|
||||
// Virtual transpose [B,L,H*D] → [B,H,L,D] via stride manipulation.
|
||||
// Strides: batch = L*H*D (full sequence), head = D (adjacent heads in memory),
|
||||
// seq = H*D (jump one full row of heads), elem = 1 (contiguous within head).
|
||||
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
// AsStrided creates a view (C refcount keeps source alive), so Free source after.
|
||||
q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
Free(qProj)
|
||||
k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
Free(kProj)
|
||||
v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
Free(vProj)
|
||||
|
||||
// Q/K normalization
|
||||
oldQ := q
|
||||
q = RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
|
||||
Free(oldQ)
|
||||
oldK := k
|
||||
k = RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
|
||||
Free(oldK)
|
||||
|
||||
// RoPE with appropriate theta
|
||||
ropeTheta := cfg.RopeTheta
|
||||
if isSliding {
|
||||
ropeTheta = cfg.RopeLocalBaseFreq
|
||||
}
|
||||
oldQ = q
|
||||
q = RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
Free(oldQ)
|
||||
oldK = k
|
||||
k = RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
|
||||
Free(oldK)
|
||||
|
||||
// Update cache
|
||||
// Update cache — returns Slice views into cache buffer; free our pre-update handles.
|
||||
oldK, oldV := k, v
|
||||
k, v = c.Update(k, v, int(L))
|
||||
Free(oldK, oldV)
|
||||
|
||||
// GQA: repeat K/V heads
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
kAttn, vAttn := k, v
|
||||
if repeatFactor > 1 {
|
||||
k = RepeatKV(k, repeatFactor)
|
||||
v = RepeatKV(v, repeatFactor)
|
||||
kAttn = RepeatKV(k, repeatFactor)
|
||||
vAttn = RepeatKV(v, repeatFactor)
|
||||
Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies
|
||||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
var out *Array
|
||||
if mask != nil {
|
||||
out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
|
||||
out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale)
|
||||
} else {
|
||||
out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1)
|
||||
}
|
||||
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views
|
||||
|
||||
transposed := Transpose(out, 0, 2, 1, 3)
|
||||
Free(out)
|
||||
reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
Free(transposed)
|
||||
result := a.OProj.Forward(reshaped)
|
||||
Free(reshaped)
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *MLP) forward(x *Array) *Array {
|
||||
gate := getCompiledGELU().Call(m.GateProj.Forward(x))[0]
|
||||
return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
|
||||
gateProj := m.GateProj.Forward(x)
|
||||
gate := getCompiledGELU().Call(gateProj)[0]
|
||||
Free(gateProj)
|
||||
upProj := m.UpProj.Forward(x)
|
||||
activated := Mul(gate, upProj)
|
||||
Free(gate, upProj)
|
||||
result := m.DownProj.Forward(activated)
|
||||
Free(activated)
|
||||
return result
|
||||
}
|
||||
|
||||
// NewCache creates per-layer caches for generation.
|
||||
|
|
|
|||
|
|
@ -177,6 +177,13 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
m.lastErr = fmt.Errorf("prefill: %w", err)
|
||||
return
|
||||
}
|
||||
// Detach logits and cache arrays to release the entire prefill computation
|
||||
// graph. After Eval, data is materialised — graph connections only pin Metal
|
||||
// memory from intermediate tensors (34 layers × ~20 ops each).
|
||||
Detach(logits)
|
||||
for _, c := range caches {
|
||||
c.Detach()
|
||||
}
|
||||
prefillDur = time.Since(prefillStart)
|
||||
|
||||
// Track generated token IDs for repeat penalty.
|
||||
|
|
@ -248,6 +255,15 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
m.lastErr = fmt.Errorf("decode step %d: %w", i, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Detach logits and cache arrays to break the computation graph.
|
||||
// Without this, each step's logits holds shared_ptrs through the
|
||||
// entire forward pass (SDPA → Slice → cache), pinning hundreds of
|
||||
// Metal buffers per step that accumulate to tens of GB.
|
||||
Detach(logits)
|
||||
for _, c := range caches {
|
||||
c.Detach()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -264,9 +280,11 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
|
|||
defer freeCaches(caches)
|
||||
|
||||
// Single prefill pass — populates KV caches for all layers.
|
||||
input := FromValues(tokens, len(tokens))
|
||||
input = Reshape(input, 1, int32(len(tokens)))
|
||||
vInput := FromValues(tokens, len(tokens))
|
||||
input := Reshape(vInput, 1, int32(len(tokens)))
|
||||
Free(vInput)
|
||||
logits := m.model.Forward(input, caches)
|
||||
Free(input)
|
||||
if err := Eval(logits); err != nil {
|
||||
return nil, fmt.Errorf("prefill: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -91,11 +91,20 @@ func (l *LoRALinear) Forward(x *Array) *Array {
|
|||
baseOut := l.Base.baseForward(x)
|
||||
|
||||
// LoRA path: x @ A^T gives [B, L, rank], then @ B^T gives [B, L, out]
|
||||
loraOut := Matmul(x, Transpose(l.A))
|
||||
loraOut = Matmul(loraOut, Transpose(l.B))
|
||||
loraOut = MulScalar(loraOut, l.Scale)
|
||||
ta := Transpose(l.A)
|
||||
loraOut := Matmul(x, ta)
|
||||
Free(ta)
|
||||
|
||||
return Add(baseOut, loraOut)
|
||||
tb := Transpose(l.B)
|
||||
loraOut2 := Matmul(loraOut, tb)
|
||||
Free(loraOut, tb)
|
||||
|
||||
loraOut3 := MulScalar(loraOut2, l.Scale)
|
||||
Free(loraOut2)
|
||||
|
||||
res := Add(baseOut, loraOut3)
|
||||
Free(baseOut, loraOut3)
|
||||
return res
|
||||
}
|
||||
|
||||
// TrainableParams returns the LoRA A and B arrays for gradient computation.
|
||||
|
|
|
|||
|
|
@ -50,10 +50,14 @@ func (l *Linear) baseForward(x *Array) *Array {
|
|||
if l.Scales != nil {
|
||||
out = QuantizedMatmul(x, l.Weight, l.Scales, l.Biases, true, l.GroupSize, l.Bits)
|
||||
} else {
|
||||
out = Matmul(x, Transpose(l.Weight))
|
||||
wT := Transpose(l.Weight)
|
||||
out = Matmul(x, wT)
|
||||
Free(wT)
|
||||
}
|
||||
if l.Bias != nil && l.Bias.Valid() {
|
||||
oldOut := out
|
||||
out = Add(out, l.Bias)
|
||||
Free(oldOut)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
|
@ -72,7 +76,9 @@ type Embedding struct {
|
|||
func (e *Embedding) Forward(indices *Array) *Array {
|
||||
if e.Scales != nil {
|
||||
w := Dequantize(e.Weight, e.Scales, e.Biases, e.GroupSize, e.Bits)
|
||||
return Take(w, indices, 0)
|
||||
res := Take(w, indices, 0)
|
||||
Free(w)
|
||||
return res
|
||||
}
|
||||
return Take(e.Weight, indices, 0)
|
||||
}
|
||||
|
|
@ -110,6 +116,10 @@ func RepeatKV(x *Array, factor int32) *Array {
|
|||
|
||||
// Expand: [B, H, 1, L, D] then broadcast to [B, H, factor, L, D]
|
||||
expanded := ExpandDims(x, 2)
|
||||
expanded = BroadcastTo(expanded, []int32{B, H, factor, L, D})
|
||||
return Reshape(expanded, B, H*factor, L, D)
|
||||
broadcasted := BroadcastTo(expanded, []int32{B, H, factor, L, D})
|
||||
Free(expanded)
|
||||
|
||||
res := Reshape(broadcasted, B, H*factor, L, D)
|
||||
Free(broadcasted)
|
||||
return res
|
||||
}
|
||||
|
|
|
|||
|
|
@ -63,6 +63,15 @@ func Negative(a *Array) *Array {
|
|||
return out
|
||||
}
|
||||
|
||||
// Copy creates a deep copy of an array, breaking the computation graph chain.
|
||||
// The returned array has the same data but no references to parent graph nodes,
|
||||
// allowing Metal memory from prior graph operations to be freed.
|
||||
func Copy(a *Array) *Array {
|
||||
out := newArray("COPY", a)
|
||||
C.mlx_copy(&out.ctx, a.ctx, DefaultStream().ctx)
|
||||
return out
|
||||
}
|
||||
|
||||
// --- Math functions ---
|
||||
|
||||
// Exp returns element-wise exp(a).
|
||||
|
|
@ -81,7 +90,10 @@ func Sigmoid(a *Array) *Array {
|
|||
|
||||
// SiLU returns element-wise x * sigmoid(x) (Swish activation).
|
||||
func SiLU(a *Array) *Array {
|
||||
return Mul(a, Sigmoid(a))
|
||||
s := Sigmoid(a)
|
||||
res := Mul(a, s)
|
||||
Free(s)
|
||||
return res
|
||||
}
|
||||
|
||||
// Tanh returns element-wise tanh(a).
|
||||
|
|
|
|||
|
|
@ -261,74 +261,114 @@ func (m *Qwen3Model) ForwardMasked(tokens *Array, mask *Array, caches []Cache) *
|
|||
h := m.EmbedTokens.Forward(tokens)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
h = layer.forward(h, caches[i], B, L, mask, m.Cfg)
|
||||
hNext := layer.forward(h, caches[i], B, L, mask, m.Cfg)
|
||||
Free(h)
|
||||
h = hNext
|
||||
}
|
||||
|
||||
return m.Output.Forward(m.Norm.Forward(h, m.Cfg.RMSNormEps))
|
||||
normed := m.Norm.Forward(h, m.Cfg.RMSNormEps)
|
||||
out := m.Output.Forward(normed)
|
||||
Free(h, normed)
|
||||
return out
|
||||
}
|
||||
|
||||
func (l *Qwen3DecoderLayer) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
|
||||
// Pre-attention norm → attention → residual add
|
||||
normed := l.InputNorm.Forward(x, cfg.RMSNormEps)
|
||||
attnOut := l.Attention.forward(normed, c, B, L, mask, cfg)
|
||||
Free(normed)
|
||||
h := Add(x, attnOut)
|
||||
Free(attnOut)
|
||||
|
||||
// Pre-MLP norm → MLP → residual add
|
||||
normed = l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed)
|
||||
return Add(h, mlpOut)
|
||||
normed2 := l.PostAttnNorm.Forward(h, cfg.RMSNormEps)
|
||||
mlpOut := l.MLP.forward(normed2)
|
||||
Free(normed2)
|
||||
result := Add(h, mlpOut)
|
||||
Free(h, mlpOut)
|
||||
return result
|
||||
}
|
||||
|
||||
func (a *Qwen3Attention) forward(x *Array, c Cache, B, L int32, mask *Array, cfg *Qwen3Config) *Array {
|
||||
q := a.QProj.Forward(x)
|
||||
k := a.KProj.Forward(x)
|
||||
v := a.VProj.Forward(x)
|
||||
qProj := a.QProj.Forward(x)
|
||||
kProj := a.KProj.Forward(x)
|
||||
vProj := a.VProj.Forward(x)
|
||||
|
||||
// Reshape to [B, num_heads, L, head_dim]
|
||||
q = AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
// Reshape to [B, num_heads, L, head_dim] via stride manipulation.
|
||||
// AsStrided creates a view (C refcount keeps source alive), so Free source after.
|
||||
q := AsStrided(qProj, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
|
||||
k = AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
Free(qProj)
|
||||
k := AsStrided(kProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
v = AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
Free(kProj)
|
||||
v := AsStrided(vProj, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
|
||||
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
|
||||
Free(vProj)
|
||||
|
||||
// Q/K RMS normalization (Qwen 3 has this; Qwen 2 does not)
|
||||
if a.QNorm != nil && a.QNorm.Weight != nil {
|
||||
oldQ := q
|
||||
q = a.QNorm.Forward(q, cfg.RMSNormEps)
|
||||
Free(oldQ)
|
||||
}
|
||||
if a.KNorm != nil && a.KNorm.Weight != nil {
|
||||
oldK := k
|
||||
k = a.KNorm.Forward(k, cfg.RMSNormEps)
|
||||
Free(oldK)
|
||||
}
|
||||
|
||||
// RoPE — single theta for all layers (no sliding window)
|
||||
oldQ := q
|
||||
q = RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
Free(oldQ)
|
||||
oldK := k
|
||||
k = RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
|
||||
Free(oldK)
|
||||
|
||||
// Update KV cache
|
||||
// Update KV cache — returns Slice views into cache buffer; free our pre-update handles.
|
||||
oldK, oldV := k, v
|
||||
k, v = c.Update(k, v, int(L))
|
||||
Free(oldK, oldV)
|
||||
|
||||
// GQA: repeat K/V heads to match Q heads
|
||||
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
|
||||
kAttn, vAttn := k, v
|
||||
if repeatFactor > 1 {
|
||||
k = RepeatKV(k, repeatFactor)
|
||||
v = RepeatKV(v, repeatFactor)
|
||||
kAttn = RepeatKV(k, repeatFactor)
|
||||
vAttn = RepeatKV(v, repeatFactor)
|
||||
Free(k, v) // Free Slice views from cache.Update; RepeatKV holds copies
|
||||
}
|
||||
|
||||
// Scaled dot-product attention
|
||||
var out *Array
|
||||
if mask != nil {
|
||||
out = ScaledDotProductAttentionWithMask(q, k, v, mask, cfg.Scale)
|
||||
out = ScaledDotProductAttentionWithMask(q, kAttn, vAttn, mask, cfg.Scale)
|
||||
} else {
|
||||
out = ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
|
||||
out = ScaledDotProductAttention(q, kAttn, vAttn, cfg.Scale, L > 1)
|
||||
}
|
||||
out = Reshape(Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
return a.OProj.Forward(out)
|
||||
Free(q, kAttn, vAttn) // Always free — when repeatFactor==1 this frees the Slice views
|
||||
|
||||
transposed := Transpose(out, 0, 2, 1, 3)
|
||||
Free(out)
|
||||
reshaped := Reshape(transposed, B, L, cfg.NumAttentionHeads*cfg.HeadDim)
|
||||
Free(transposed)
|
||||
result := a.OProj.Forward(reshaped)
|
||||
Free(reshaped)
|
||||
return result
|
||||
}
|
||||
|
||||
// forward computes SwiGLU: down(silu(gate(x)) * up(x)).
|
||||
func (m *Qwen3MLP) forward(x *Array) *Array {
|
||||
gate := SiLU(m.GateProj.Forward(x))
|
||||
return m.DownProj.Forward(Mul(gate, m.UpProj.Forward(x)))
|
||||
gateProj := m.GateProj.Forward(x)
|
||||
gate := SiLU(gateProj)
|
||||
Free(gateProj)
|
||||
upProj := m.UpProj.Forward(x)
|
||||
activated := Mul(gate, upProj)
|
||||
Free(gate, upProj)
|
||||
result := m.DownProj.Forward(activated)
|
||||
Free(activated)
|
||||
return result
|
||||
}
|
||||
|
||||
// NewCache creates per-layer KV caches. Qwen 3 uses global attention only.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue