fix: add deterministic GPU memory cleanup across inference paths
- defer freeCaches() in Generate and InspectAttention
- Free orphaned arrays during KVCache growth and slice updates
- Free per-token scalar intermediates in samplers and ops
- Free intermediate arrays in applyRepeatPenalty

Found by 3-way review: Claude explorer, Codex (gpt-5.3), Gemini Ultra. Gemini implemented the fixes.

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
208f76b067
commit
51ac442a09
4 changed files with 109 additions and 49 deletions
|
|
@ -50,20 +50,25 @@ func (c *KVCache) Update(k, v *Array, seqLen int) (*Array, *Array) {
|
|||
newV := Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
|
||||
|
||||
if c.keys != nil {
|
||||
oldK, oldV := c.keys, c.values
|
||||
if prev%c.step != 0 {
|
||||
c.keys = Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
c.values = Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
oldK = Slice(oldK, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
|
||||
oldV = Slice(oldV, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
|
||||
Free(c.keys, c.values)
|
||||
}
|
||||
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
|
||||
c.values = Concatenate([]*Array{c.values, newV}, 2)
|
||||
c.keys = Concatenate([]*Array{oldK, newK}, 2)
|
||||
c.values = Concatenate([]*Array{oldV, newV}, 2)
|
||||
Free(oldK, oldV, newK, newV)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
}
|
||||
|
||||
c.offset += seqLen
|
||||
oldK, oldV := c.keys, c.values
|
||||
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
|
||||
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
Free(oldK, oldV)
|
||||
|
||||
return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
|
||||
Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
|
||||
|
|
@ -127,8 +132,10 @@ func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) {
|
|||
newK := Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
|
||||
newV := Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
|
||||
if c.keys != nil {
|
||||
c.keys = Concatenate([]*Array{c.keys, newK}, 2)
|
||||
c.values = Concatenate([]*Array{c.values, newV}, 2)
|
||||
oldK, oldV := c.keys, c.values
|
||||
c.keys = Concatenate([]*Array{oldK, newK}, 2)
|
||||
c.values = Concatenate([]*Array{oldV, newV}, 2)
|
||||
Free(oldK, oldV, newK, newV)
|
||||
} else {
|
||||
c.keys, c.values = newK, newV
|
||||
}
|
||||
|
|
@ -138,8 +145,10 @@ func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) {
|
|||
c.idx = 0
|
||||
}
|
||||
|
||||
oldK, oldV := c.keys, c.values
|
||||
c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
|
||||
c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
|
||||
Free(oldK, oldV)
|
||||
|
||||
c.offset++
|
||||
c.idx++
|
||||
|
|
@ -163,17 +172,21 @@ func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array)
|
|||
Dv := v.Shape()[3]
|
||||
|
||||
if c.keys == nil {
|
||||
c.keys, c.values = k, v
|
||||
c.keys, c.values = k.Clone(), v.Clone()
|
||||
} else {
|
||||
c.keys = Concatenate([]*Array{c.keys, k}, 2)
|
||||
c.values = Concatenate([]*Array{c.values, v}, 2)
|
||||
oldK, oldV := c.keys, c.values
|
||||
c.keys = Concatenate([]*Array{oldK, k}, 2)
|
||||
c.values = Concatenate([]*Array{oldV, v}, 2)
|
||||
Free(oldK, oldV)
|
||||
}
|
||||
c.offset += seqLen
|
||||
|
||||
cap := int(c.keys.Shape()[2])
|
||||
if trim := cap - c.maxSize; trim > 0 {
|
||||
oldK, oldV := c.keys, c.values
|
||||
c.keys = Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
|
||||
c.values = Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
|
||||
Free(oldK, oldV)
|
||||
}
|
||||
|
||||
c.idx = int(c.keys.Shape()[2])
|
||||
|
|
|
|||
|
|
@ -139,6 +139,8 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
tokens := m.tokenizer.Encode(prompt)
|
||||
promptLen := len(tokens)
|
||||
caches := m.newCaches()
|
||||
defer freeCaches(caches)
|
||||
|
||||
sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK)
|
||||
var genCount int
|
||||
var prefillDur time.Duration
|
||||
|
|
@ -165,9 +167,11 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
|
||||
// Prefill: process entire prompt
|
||||
prefillStart := time.Now()
|
||||
input := FromValues(tokens, len(tokens))
|
||||
input = Reshape(input, 1, int32(len(tokens)))
|
||||
vInput := FromValues(tokens, len(tokens))
|
||||
input := Reshape(vInput, 1, int32(len(tokens)))
|
||||
logits := m.model.Forward(input, caches)
|
||||
Free(vInput, input)
|
||||
|
||||
if err := Eval(logits); err != nil {
|
||||
m.lastErr = fmt.Errorf("prefill: %w", err)
|
||||
return
|
||||
|
|
@ -177,6 +181,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
// Track generated token IDs for repeat penalty.
|
||||
var history []int32
|
||||
|
||||
defer func() {
|
||||
Free(logits)
|
||||
}()
|
||||
|
||||
for i := 0; i < cfg.MaxTokens; i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
@ -186,41 +194,55 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
}
|
||||
|
||||
// Sample from last position logits
|
||||
lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
|
||||
lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2)))
|
||||
l1 := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1)))
|
||||
lastPos := Reshape(l1, 1, int32(l1.Dim(2)))
|
||||
Free(l1)
|
||||
|
||||
// Apply repeat penalty before sampling.
|
||||
if cfg.RepeatPenalty > 1.0 && len(history) > 0 {
|
||||
oldLastPos := lastPos
|
||||
lastPos = applyRepeatPenalty(lastPos, history, cfg.RepeatPenalty)
|
||||
Free(oldLastPos)
|
||||
}
|
||||
|
||||
next := sampler.Sample(lastPos)
|
||||
if err := Eval(next); err != nil {
|
||||
m.lastErr = fmt.Errorf("sample step %d: %w", i, err)
|
||||
Free(lastPos, next)
|
||||
return
|
||||
}
|
||||
|
||||
id := int32(next.Int())
|
||||
history = append(history, id)
|
||||
Free(lastPos)
|
||||
|
||||
// Check stop conditions
|
||||
if id == m.tokenizer.EOSToken() {
|
||||
Free(next)
|
||||
return
|
||||
}
|
||||
if slices.Contains(cfg.StopTokens, id) {
|
||||
Free(next)
|
||||
return
|
||||
}
|
||||
|
||||
genCount++
|
||||
text := m.tokenizer.DecodeToken(id)
|
||||
if !yield(Token{ID: id, Text: text}) {
|
||||
Free(next)
|
||||
return
|
||||
}
|
||||
Free(next)
|
||||
|
||||
// Next step input
|
||||
nextInput := FromValues([]int32{id}, 1)
|
||||
nextInput = Reshape(nextInput, 1, 1)
|
||||
vNextInput := FromValues([]int32{id}, 1)
|
||||
nextInput := Reshape(vNextInput, 1, 1)
|
||||
Free(vNextInput)
|
||||
|
||||
oldLogits := logits
|
||||
logits = m.model.Forward(nextInput, caches)
|
||||
Free(nextInput, oldLogits)
|
||||
|
||||
if err := Eval(logits); err != nil {
|
||||
m.lastErr = fmt.Errorf("decode step %d: %w", i, err)
|
||||
return
|
||||
|
|
@ -238,6 +260,7 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
|
|||
}
|
||||
|
||||
caches := m.newCaches()
|
||||
defer freeCaches(caches)
|
||||
|
||||
// Single prefill pass — populates KV caches for all layers.
|
||||
input := FromValues(tokens, len(tokens))
|
||||
|
|
@ -271,11 +294,14 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention
|
|||
validLen := min(caches[i].Len(), seqLen)
|
||||
kSliced := Slice(kArray, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], int32(validLen), shape[3]})
|
||||
if err := Eval(kSliced); err != nil {
|
||||
Free(kSliced)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract all floats then reshape per head.
|
||||
flat := kSliced.Floats() // len = 1 * H * validLen * D
|
||||
Free(kSliced)
|
||||
|
||||
keys[i] = make([][]float32, numHeads)
|
||||
stride := validLen * headDim
|
||||
for h := 0; h < numHeads; h++ {
|
||||
|
|
@ -332,13 +358,15 @@ func applyRepeatPenalty(logits *Array, history []int32, penalty float32) *Array
|
|||
penaltyVal := FromValue(penalty)
|
||||
|
||||
// Positive logits: divide by penalty. Negative logits: multiply by penalty.
|
||||
penalised := Where(
|
||||
Greater(gathered, zero),
|
||||
Mul(gathered, invPenalty),
|
||||
Mul(gathered, penaltyVal),
|
||||
)
|
||||
gt := Greater(gathered, zero)
|
||||
m1 := Mul(gathered, invPenalty)
|
||||
m2 := Mul(gathered, penaltyVal)
|
||||
penalised := Where(gt, m1, m2)
|
||||
Free(gt, m1, m2)
|
||||
|
||||
return PutAlongAxis(logits, idx, penalised, -1)
|
||||
res := PutAlongAxis(logits, idx, penalised, -1)
|
||||
Free(idx, gathered, zero, invPenalty, penaltyVal, penalised)
|
||||
return res
|
||||
}
|
||||
|
||||
// newCaches creates per-layer KV caches. If contextLen is set, all unbounded
|
||||
|
|
|
|||
|
|
@ -22,7 +22,9 @@ func Add(a, b *Array) *Array {
|
|||
// AddScalar returns a + scalar (broadcast).
|
||||
func AddScalar(a *Array, s float32) *Array {
|
||||
scalar := FromValue(s)
|
||||
return Add(a, scalar)
|
||||
res := Add(a, scalar)
|
||||
Free(scalar)
|
||||
return res
|
||||
}
|
||||
|
||||
// Mul returns element-wise a * b.
|
||||
|
|
@ -35,7 +37,9 @@ func Mul(a, b *Array) *Array {
|
|||
// MulScalar returns a * scalar (broadcast).
|
||||
func MulScalar(a *Array, s float32) *Array {
|
||||
scalar := FromValue(s)
|
||||
return Mul(a, scalar)
|
||||
res := Mul(a, scalar)
|
||||
Free(scalar)
|
||||
return res
|
||||
}
|
||||
|
||||
// Divide returns element-wise a / b.
|
||||
|
|
|
|||
|
|
@ -36,11 +36,20 @@ func newSampler(temp, topP, minP float32, topK int) Sampler {
|
|||
type chain []Sampler
|
||||
|
||||
func (c chain) Sample(logits *Array) *Array {
|
||||
curr := logits
|
||||
for _, s := range c {
|
||||
logits = s.Sample(logits)
|
||||
next := s.Sample(curr)
|
||||
if curr != logits {
|
||||
Free(curr)
|
||||
}
|
||||
curr = next
|
||||
}
|
||||
// Final categorical sample from log-probabilities
|
||||
return RandomCategorical(logits)
|
||||
res := RandomCategorical(curr)
|
||||
if curr != logits {
|
||||
Free(curr)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// greedy returns the argmax token.
|
||||
|
|
@ -62,10 +71,15 @@ type TopKSampler int
|
|||
|
||||
func (k TopKSampler) Sample(logits *Array) *Array {
|
||||
neg := Negative(logits)
|
||||
mask := Argpartition(neg, int(k)-1, -1)
|
||||
maskIdx := Argpartition(neg, int(k)-1, -1)
|
||||
Free(neg)
|
||||
// Slice the indices beyond top-k
|
||||
mask = SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1)))
|
||||
return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1)
|
||||
mask := SliceAxis(maskIdx, -1, int32(k), int32(logits.Dim(-1)))
|
||||
Free(maskIdx)
|
||||
inf := FromValue(float32(math.Inf(-1)))
|
||||
res := PutAlongAxis(logits, mask, inf, -1)
|
||||
Free(mask, inf)
|
||||
return res
|
||||
}
|
||||
|
||||
// TopP implements nucleus (top-p) sampling.
|
||||
|
|
@ -79,6 +93,7 @@ func (p TopP) Sample(logits *Array) *Array {
|
|||
// Sort descending via argsort of negated probs
|
||||
neg := Negative(probs)
|
||||
sortIdx := Argsort(neg, -1)
|
||||
Free(neg)
|
||||
sortedProbs := TakeAlongAxis(probs, sortIdx, -1)
|
||||
|
||||
// Cumulative sum of sorted probabilities
|
||||
|
|
@ -87,26 +102,26 @@ func (p TopP) Sample(logits *Array) *Array {
|
|||
// Mask in sorted space: keep tokens where cumprob (excluding current) <= threshold
|
||||
shiftedCum := Subtract(cumProbs, sortedProbs)
|
||||
threshold := FromValue(float32(p))
|
||||
sortedMask := Where(
|
||||
Greater(shiftedCum, threshold),
|
||||
FromValue(float32(math.Inf(-1))),
|
||||
FromValue(float32(0)),
|
||||
)
|
||||
inf := FromValue(float32(math.Inf(-1)))
|
||||
zero := FromValue(float32(0))
|
||||
|
||||
gt := Greater(shiftedCum, threshold)
|
||||
sortedMask := Where(gt, inf, zero)
|
||||
Free(gt, inf, zero, threshold, shiftedCum, cumProbs, sortedProbs)
|
||||
|
||||
// Scatter mask back to original positions
|
||||
mask := PutAlongAxis(
|
||||
Zeros(logits.Shape(), DTypeFloat32),
|
||||
sortIdx,
|
||||
sortedMask,
|
||||
-1,
|
||||
)
|
||||
emptyMask := Zeros(logits.Shape(), DTypeFloat32)
|
||||
mask := PutAlongAxis(emptyMask, sortIdx, sortedMask, -1)
|
||||
Free(emptyMask, sortIdx, sortedMask)
|
||||
|
||||
// Apply mask: -inf where excluded, original logit where kept
|
||||
return Where(
|
||||
Greater(FromValue(float32(0)), mask),
|
||||
FromValue(float32(math.Inf(-1))),
|
||||
logits,
|
||||
)
|
||||
zeroArr := FromValue(float32(0))
|
||||
gt0 := Greater(zeroArr, mask)
|
||||
inf2 := FromValue(float32(math.Inf(-1)))
|
||||
res := Where(gt0, inf2, logits)
|
||||
Free(zeroArr, gt0, inf2, mask, probs)
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// MinPSampler masks tokens whose probability is below min_p * max_prob.
|
||||
|
|
@ -121,12 +136,12 @@ func (p MinPSampler) Sample(logits *Array) *Array {
|
|||
|
||||
// Threshold = min_p * max_prob
|
||||
threshold := MulScalar(maxProb, float32(p))
|
||||
Free(maxProb)
|
||||
|
||||
// Mask tokens below threshold
|
||||
mask := Where(
|
||||
Greater(threshold, probs),
|
||||
FromValue(float32(math.Inf(-1))),
|
||||
logits,
|
||||
)
|
||||
inf := FromValue(float32(math.Inf(-1)))
|
||||
gt := Greater(threshold, probs)
|
||||
mask := Where(gt, inf, logits)
|
||||
Free(probs, threshold, inf, gt)
|
||||
return mask
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue