fix: add deterministic GPU memory cleanup across inference paths
Some checks failed
Security Scan / security (push) Successful in 15s
Test / Vet & Build (push) Failing after 32s

- defer freeCaches() in Generate and InspectAttention
- Free orphaned arrays during KVCache growth and slice updates
- Free per-token scalar intermediates in samplers and ops
- Free intermediate arrays in applyRepeatPenalty

Found by 3-way review: Claude explorer, Codex (gpt-5.3), Gemini Ultra.
Gemini implemented the fixes.

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 04:37:51 +00:00
parent 208f76b067
commit 51ac442a09
4 changed files with 109 additions and 49 deletions

View file

@@ -50,20 +50,25 @@ func (c *KVCache) Update(k, v *Array, seqLen int) (*Array, *Array) {
newV := Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
oldK, oldV := c.keys, c.values
if prev%c.step != 0 {
c.keys = Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
oldK = Slice(oldK, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
oldV = Slice(oldV, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
Free(c.keys, c.values)
}
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
c.values = Concatenate([]*Array{c.values, newV}, 2)
c.keys = Concatenate([]*Array{oldK, newK}, 2)
c.values = Concatenate([]*Array{oldV, newV}, 2)
Free(oldK, oldV, newK, newV)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
oldK, oldV := c.keys, c.values
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
Free(oldK, oldV)
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
@@ -127,8 +132,10 @@ func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) {
newK := Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
c.values = Concatenate([]*Array{c.values, newV}, 2)
oldK, oldV := c.keys, c.values
c.keys = Concatenate([]*Array{oldK, newK}, 2)
c.values = Concatenate([]*Array{oldV, newV}, 2)
Free(oldK, oldV, newK, newV)
} else {
c.keys, c.values = newK, newV
}
@@ -138,8 +145,10 @@ func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) {
c.idx = 0
}
oldK, oldV := c.keys, c.values
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
Free(oldK, oldV)
c.offset++
c.idx++
@@ -163,17 +172,21 @@ func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array)
Dv := v.Shape()[3]
if c.keys == nil {
c.keys, c.values = k, v
c.keys, c.values = k.Clone(), v.Clone()
} else {
c.keys = Concatenate([]*Array{c.keys, k}, 2)
c.values = Concatenate([]*Array{c.values, v}, 2)
oldK, oldV := c.keys, c.values
c.keys = Concatenate([]*Array{oldK, k}, 2)
c.values = Concatenate([]*Array{oldV, v}, 2)
Free(oldK, oldV)
}
c.offset += seqLen
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
oldK, oldV := c.keys, c.values
c.keys = Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
Free(oldK, oldV)
}
c.idx = int(c.keys.Shape()[2])

View file

@@ -139,6 +139,8 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
tokens := m.tokenizer.Encode(prompt)
promptLen := len(tokens)
caches := m.newCaches()
defer freeCaches(caches)
sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK)
var genCount int
var prefillDur time.Duration
@@ -165,9 +167,11 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
// Prefill: process entire prompt
prefillStart := time.Now()
input := FromValues(tokens, len(tokens))
input = Reshape(input, 1, int32(len(tokens)))
vInput := FromValues(tokens, len(tokens))
input := Reshape(vInput, 1, int32(len(tokens)))
logits := m.model.Forward(input, caches)
Free(vInput, input)
if err := Eval(logits); err != nil {
m.lastErr = fmt.Errorf("prefill: %w", err)
return
@@ -177,6 +181,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
// Track generated token IDs for repeat penalty.
var history []int32
defer func() {
Free(logits)
}()
for i := 0; i < cfg.MaxTokens; i++ {
select {
case <-ctx.Done():
@@ -186,41 +194,55 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
}
// Sample from last position logits
lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2)))
l1 := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
lastPos := Reshape(l1, 1, int32(l1.Dim(2)))
Free(l1)
// Apply repeat penalty before sampling.
if cfg.RepeatPenalty > 1.0 && len(history) > 0 {
oldLastPos := lastPos
lastPos = applyRepeatPenalty(lastPos, history, cfg.RepeatPenalty)
Free(oldLastPos)
}
next := sampler.Sample(lastPos)
if err := Eval(next); err != nil {
m.lastErr = fmt.Errorf("sample step %d: %w", i, err)
Free(lastPos, next)
return
}
id := int32(next.Int())
history = append(history, id)
Free(lastPos)
// Check stop conditions
if id == m.tokenizer.EOSToken() {
Free(next)
return
}
if slices.Contains(cfg.StopTokens, id) {
Free(next)
return
}
genCount++
text := m.tokenizer.DecodeToken(id)
if !yield(Token{ID: id, Text: text}) {
Free(next)
return
}
Free(next)
// Next step input
nextInput := FromValues([]int32{id}, 1)
nextInput = Reshape(nextInput, 1, 1)
vNextInput := FromValues([]int32{id}, 1)
nextInput := Reshape(vNextInput, 1, 1)
Free(vNextInput)
oldLogits := logits
logits = m.model.Forward(nextInput, caches)
Free(nextInput, oldLogits)
if err := Eval(logits); err != nil {
m.lastErr = fmt.Errorf("decode step %d: %w", i, err)
return
@@ -238,6 +260,7 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
}
caches := m.newCaches()
defer freeCaches(caches)
// Single prefill pass — populates KV caches for all layers.
input := FromValues(tokens, len(tokens))
@@ -271,11 +294,14 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
validLen := min(caches[i].Len(), seqLen)
kSliced := Slice(kArray, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], int32(validLen), shape[3]})
if err := Eval(kSliced); err != nil {
Free(kSliced)
continue
}
// Extract all floats then reshape per head.
flat := kSliced.Floats() // len = 1 * H * validLen * D
Free(kSliced)
keys[i] = make([][]float32, numHeads)
stride := validLen * headDim
for h := 0; h < numHeads; h++ {
@@ -332,13 +358,15 @@ func applyRepeatPenalty(logits *Array, history []int32, penalty float32) *Array
penaltyVal := FromValue(penalty)
// Positive logits: divide by penalty. Negative logits: multiply by penalty.
penalised := Where(
Greater(gathered, zero),
Mul(gathered, invPenalty),
Mul(gathered, penaltyVal),
)
gt := Greater(gathered, zero)
m1 := Mul(gathered, invPenalty)
m2 := Mul(gathered, penaltyVal)
penalised := Where(gt, m1, m2)
Free(gt, m1, m2)
return PutAlongAxis(logits, idx, penalised, -1)
res := PutAlongAxis(logits, idx, penalised, -1)
Free(idx, gathered, zero, invPenalty, penaltyVal, penalised)
return res
}
// newCaches creates per-layer KV caches. If contextLen is set, all unbounded

View file

@@ -22,7 +22,9 @@ func Add(a, b *Array) *Array {
// AddScalar returns a + scalar (broadcast).
func AddScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return Add(a, scalar)
res := Add(a, scalar)
Free(scalar)
return res
}
// Mul returns element-wise a * b.
@@ -35,7 +37,9 @@ func Mul(a, b *Array) *Array {
// MulScalar returns a * scalar (broadcast).
func MulScalar(a *Array, s float32) *Array {
scalar := FromValue(s)
return Mul(a, scalar)
res := Mul(a, scalar)
Free(scalar)
return res
}
// Divide returns element-wise a / b.

View file

@@ -36,11 +36,20 @@ func newSampler(temp, topP, minP float32, topK int) Sampler {
type chain []Sampler
func (c chain) Sample(logits *Array) *Array {
curr := logits
for _, s := range c {
logits = s.Sample(logits)
next := s.Sample(curr)
if curr != logits {
Free(curr)
}
curr = next
}
// Final categorical sample from log-probabilities
return RandomCategorical(logits)
res := RandomCategorical(curr)
if curr != logits {
Free(curr)
}
return res
}
// greedy returns the argmax token.
@@ -62,10 +71,15 @@ type TopKSampler int
func (k TopKSampler) Sample(logits *Array) *Array {
neg := Negative(logits)
mask := Argpartition(neg, int(k)-1, -1)
maskIdx := Argpartition(neg, int(k)-1, -1)
Free(neg)
// Slice the indices beyond top-k
mask = SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1)
mask := SliceAxis(maskIdx, -1, int32(k), int32(logits.Dim(-1)))
Free(maskIdx)
inf := FromValue(float32(math.Inf(-1)))
res := PutAlongAxis(logits, mask, inf, -1)
Free(mask, inf)
return res
}
// TopP implements nucleus (top-p) sampling.
@@ -79,6 +93,7 @@ func (p TopP) Sample(logits *Array) *Array {
// Sort descending via argsort of negated probs
neg := Negative(probs)
sortIdx := Argsort(neg, -1)
Free(neg)
sortedProbs := TakeAlongAxis(probs, sortIdx, -1)
// Cumulative sum of sorted probabilities
@@ -87,26 +102,26 @@ func (p TopP) Sample(logits *Array) *Array {
// Mask in sorted space: keep tokens where cumprob (excluding current) <= threshold
shiftedCum := Subtract(cumProbs, sortedProbs)
threshold := FromValue(float32(p))
sortedMask := Where(
Greater(shiftedCum, threshold),
FromValue(float32(math.Inf(-1))),
FromValue(float32(0)),
)
inf := FromValue(float32(math.Inf(-1)))
zero := FromValue(float32(0))
gt := Greater(shiftedCum, threshold)
sortedMask := Where(gt, inf, zero)
Free(gt, inf, zero, threshold, shiftedCum, cumProbs, sortedProbs)
// Scatter mask back to original positions
mask := PutAlongAxis(
Zeros(logits.Shape(), DTypeFloat32),
sortIdx,
sortedMask,
-1,
)
emptyMask := Zeros(logits.Shape(), DTypeFloat32)
mask := PutAlongAxis(emptyMask, sortIdx, sortedMask, -1)
Free(emptyMask, sortIdx, sortedMask)
// Apply mask: -inf where excluded, original logit where kept
return Where(
Greater(FromValue(float32(0)), mask),
FromValue(float32(math.Inf(-1))),
logits,
)
zeroArr := FromValue(float32(0))
gt0 := Greater(zeroArr, mask)
inf2 := FromValue(float32(math.Inf(-1)))
res := Where(gt0, inf2, logits)
Free(zeroArr, gt0, inf2, mask, probs)
return res
}
// MinPSampler masks tokens whose probability is below min_p * max_prob.
@@ -121,12 +136,12 @@ func (p MinPSampler) Sample(logits *Array) *Array {
// Threshold = min_p * max_prob
threshold := MulScalar(maxProb, float32(p))
Free(maxProb)
// Mask tokens below threshold
mask := Where(
Greater(threshold, probs),
FromValue(float32(math.Inf(-1))),
logits,
)
inf := FromValue(float32(math.Inf(-1)))
gt := Greater(threshold, probs)
mask := Where(gt, inf, logits)
Free(probs, threshold, inf, gt)
return mask
}