diff --git a/internal/metal/cache.go b/internal/metal/cache.go index c2630d8..9f5f581 100644 --- a/internal/metal/cache.go +++ b/internal/metal/cache.go @@ -50,20 +50,25 @@ func (c *KVCache) Update(k, v *Array, seqLen int) (*Array, *Array) { newV := Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype()) if c.keys != nil { + oldK, oldV := c.keys, c.values if prev%c.step != 0 { - c.keys = Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk}) - c.values = Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv}) + oldK = Slice(oldK, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk}) + oldV = Slice(oldV, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv}) + Free(c.keys, c.values) } - c.keys = Concatenate([]*Array{c.keys, newK}, 2) - c.values = Concatenate([]*Array{c.values, newV}, 2) + c.keys = Concatenate([]*Array{oldK, newK}, 2) + c.values = Concatenate([]*Array{oldV, newV}, 2) + Free(oldK, oldV, newK, newV) } else { c.keys, c.values = newK, newV } } c.offset += seqLen + oldK, oldV := c.keys, c.values c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk}) c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv}) + Free(oldK, oldV) return Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}), Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv}) @@ -127,8 +132,10 @@ func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) { newK := Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype()) newV := Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype()) if c.keys != nil { - c.keys = Concatenate([]*Array{c.keys, newK}, 2) - c.values = Concatenate([]*Array{c.values, newV}, 2) + oldK, oldV := c.keys, c.values + c.keys = Concatenate([]*Array{oldK, newK}, 2) + c.values = Concatenate([]*Array{oldV, newV}, 2) + Free(oldK, oldV, newK, newV) } else { c.keys, c.values = newK, newV } @@ -138,8 +145,10 @@ func (c *RotatingKVCache) updateInPlace(k, v *Array) (*Array, *Array) { c.idx = 0 } + oldK, oldV := c.keys, c.values c.keys = SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk}) c.values = SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv}) + Free(oldK, oldV) c.offset++ c.idx++ @@ -163,17 +172,21 @@ func (c *RotatingKVCache) updateConcat(k, v *Array, seqLen int) (*Array, *Array) Dv := v.Shape()[3] if c.keys == nil { - c.keys, c.values = k, v + c.keys, c.values = k.Clone(), v.Clone() } else { - c.keys = Concatenate([]*Array{c.keys, k}, 2) - c.values = Concatenate([]*Array{c.values, v}, 2) + oldK, oldV := c.keys, c.values + c.keys = Concatenate([]*Array{oldK, k}, 2) + c.values = Concatenate([]*Array{oldV, v}, 2) + Free(oldK, oldV) } c.offset += seqLen cap := int(c.keys.Shape()[2]) if trim := cap - c.maxSize; trim > 0 { + oldK, oldV := c.keys, c.values c.keys = Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk}) c.values = Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv}) + Free(oldK, oldV) } c.idx = int(c.keys.Shape()[2]) diff --git a/internal/metal/generate.go b/internal/metal/generate.go index e55ae27..342d095 100644 --- a/internal/metal/generate.go +++ b/internal/metal/generate.go @@ -139,6 +139,8 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) tokens := m.tokenizer.Encode(prompt) promptLen := len(tokens) caches := m.newCaches() + defer freeCaches(caches) + sampler := newSampler(cfg.Temperature, cfg.TopP, 0, cfg.TopK) var genCount int var prefillDur time.Duration @@ -165,9 +167,11 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) // Prefill: process entire prompt prefillStart := time.Now() - input := FromValues(tokens, len(tokens)) - input = Reshape(input, 1, int32(len(tokens))) + vInput := FromValues(tokens, len(tokens)) + input := Reshape(vInput, 1, int32(len(tokens))) logits := m.model.Forward(input, caches) + Free(vInput, input) + if err := Eval(logits); err != nil { m.lastErr = fmt.Errorf("prefill: %w", err) return @@ -177,6 +181,10 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) // Track generated token IDs for repeat penalty. var history []int32 + defer func() { + Free(logits) + }() + for i := 0; i < cfg.MaxTokens; i++ { select { case <-ctx.Done(): @@ -186,41 +194,55 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig) } // Sample from last position logits - lastPos := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1))) - lastPos = Reshape(lastPos, 1, int32(lastPos.Dim(2))) + l1 := SliceAxis(logits, 1, int32(logits.Dim(1)-1), int32(logits.Dim(1))) + lastPos := Reshape(l1, 1, int32(l1.Dim(2))) + Free(l1) // Apply repeat penalty before sampling. if cfg.RepeatPenalty > 1.0 && len(history) > 0 { + oldLastPos := lastPos lastPos = applyRepeatPenalty(lastPos, history, cfg.RepeatPenalty) + Free(oldLastPos) } next := sampler.Sample(lastPos) if err := Eval(next); err != nil { m.lastErr = fmt.Errorf("sample step %d: %w", i, err) + Free(lastPos, next) return } id := int32(next.Int()) history = append(history, id) + Free(lastPos) // Check stop conditions if id == m.tokenizer.EOSToken() { + Free(next) return } if slices.Contains(cfg.StopTokens, id) { + Free(next) return } genCount++ text := m.tokenizer.DecodeToken(id) if !yield(Token{ID: id, Text: text}) { + Free(next) return } + Free(next) // Next step input - nextInput := FromValues([]int32{id}, 1) - nextInput = Reshape(nextInput, 1, 1) + vNextInput := FromValues([]int32{id}, 1) + nextInput := Reshape(vNextInput, 1, 1) + Free(vNextInput) + + oldLogits := logits logits = m.model.Forward(nextInput, caches) + Free(nextInput, oldLogits) + if err := Eval(logits); err != nil { m.lastErr = fmt.Errorf("decode step %d: %w", i, err) return @@ -238,6 +260,7 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention } caches := m.newCaches() + defer freeCaches(caches) // Single prefill pass — populates KV caches for all layers. input := FromValues(tokens, len(tokens)) @@ -271,11 +294,14 @@ func (m *Model) InspectAttention(ctx context.Context, prompt string) (*Attention validLen := min(caches[i].Len(), seqLen) kSliced := Slice(kArray, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], int32(validLen), shape[3]}) if err := Eval(kSliced); err != nil { + Free(kSliced) continue } // Extract all floats then reshape per head. flat := kSliced.Floats() // len = 1 * H * validLen * D + Free(kSliced) + keys[i] = make([][]float32, numHeads) stride := validLen * headDim for h := 0; h < numHeads; h++ { @@ -332,13 +358,15 @@ func applyRepeatPenalty(logits *Array, history []int32, penalty float32) *Array penaltyVal := FromValue(penalty) // Positive logits: divide by penalty. Negative logits: multiply by penalty. - penalised := Where( - Greater(gathered, zero), - Mul(gathered, invPenalty), - Mul(gathered, penaltyVal), - ) + gt := Greater(gathered, zero) + m1 := Mul(gathered, invPenalty) + m2 := Mul(gathered, penaltyVal) + penalised := Where(gt, m1, m2) + Free(gt, m1, m2) - return PutAlongAxis(logits, idx, penalised, -1) + res := PutAlongAxis(logits, idx, penalised, -1) + Free(idx, gathered, zero, invPenalty, penaltyVal, penalised) + return res } // newCaches creates per-layer KV caches. If contextLen is set, all unbounded diff --git a/internal/metal/ops.go b/internal/metal/ops.go index 70470b5..53795d2 100644 --- a/internal/metal/ops.go +++ b/internal/metal/ops.go @@ -22,7 +22,9 @@ func Add(a, b *Array) *Array { // AddScalar returns a + scalar (broadcast). func AddScalar(a *Array, s float32) *Array { scalar := FromValue(s) - return Add(a, scalar) + res := Add(a, scalar) + Free(scalar) + return res } // Mul returns element-wise a * b. @@ -35,7 +37,9 @@ func Mul(a, b *Array) *Array { // MulScalar returns a * scalar (broadcast). func MulScalar(a *Array, s float32) *Array { scalar := FromValue(s) - return Mul(a, scalar) + res := Mul(a, scalar) + Free(scalar) + return res } // Divide returns element-wise a / b. diff --git a/internal/metal/sample.go b/internal/metal/sample.go index 1d29160..6420d28 100644 --- a/internal/metal/sample.go +++ b/internal/metal/sample.go @@ -36,11 +36,20 @@ func newSampler(temp, topP, minP float32, topK int) Sampler { type chain []Sampler func (c chain) Sample(logits *Array) *Array { + curr := logits for _, s := range c { - logits = s.Sample(logits) + next := s.Sample(curr) + if curr != logits { + Free(curr) + } + curr = next } // Final categorical sample from log-probabilities - return RandomCategorical(logits) + res := RandomCategorical(curr) + if curr != logits { + Free(curr) + } + return res } // greedy returns the argmax token. @@ -62,10 +71,15 @@ type TopKSampler int func (k TopKSampler) Sample(logits *Array) *Array { neg := Negative(logits) - mask := Argpartition(neg, int(k)-1, -1) + maskIdx := Argpartition(neg, int(k)-1, -1) + Free(neg) // Slice the indices beyond top-k - mask = SliceAxis(mask, -1, int32(k), int32(logits.Dim(-1))) - return PutAlongAxis(logits, mask, FromValue(float32(math.Inf(-1))), -1) + mask := SliceAxis(maskIdx, -1, int32(k), int32(logits.Dim(-1))) + Free(maskIdx) + inf := FromValue(float32(math.Inf(-1))) + res := PutAlongAxis(logits, mask, inf, -1) + Free(mask, inf) + return res } // TopP implements nucleus (top-p) sampling. @@ -79,6 +93,7 @@ func (p TopP) Sample(logits *Array) *Array { // Sort descending via argsort of negated probs neg := Negative(probs) sortIdx := Argsort(neg, -1) + Free(neg) sortedProbs := TakeAlongAxis(probs, sortIdx, -1) // Cumulative sum of sorted probabilities @@ -87,26 +102,26 @@ func (p TopP) Sample(logits *Array) *Array { // Mask in sorted space: keep tokens where cumprob (excluding current) <= threshold shiftedCum := Subtract(cumProbs, sortedProbs) threshold := FromValue(float32(p)) - sortedMask := Where( - Greater(shiftedCum, threshold), - FromValue(float32(math.Inf(-1))), - FromValue(float32(0)), - ) + inf := FromValue(float32(math.Inf(-1))) + zero := FromValue(float32(0)) + + gt := Greater(shiftedCum, threshold) + sortedMask := Where(gt, inf, zero) + Free(gt, inf, zero, threshold, shiftedCum, cumProbs, sortedProbs) // Scatter mask back to original positions - mask := PutAlongAxis( - Zeros(logits.Shape(), DTypeFloat32), - sortIdx, - sortedMask, - -1, - ) + emptyMask := Zeros(logits.Shape(), DTypeFloat32) + mask := PutAlongAxis(emptyMask, sortIdx, sortedMask, -1) + Free(emptyMask, sortIdx, sortedMask) // Apply mask: -inf where excluded, original logit where kept - return Where( - Greater(FromValue(float32(0)), mask), - FromValue(float32(math.Inf(-1))), - logits, - ) + zeroArr := FromValue(float32(0)) + gt0 := Greater(zeroArr, mask) + inf2 := FromValue(float32(math.Inf(-1))) + res := Where(gt0, inf2, logits) + Free(zeroArr, gt0, inf2, mask, probs) + + return res } // MinPSampler masks tokens whose probability is below min_p * max_prob. @@ -121,12 +136,12 @@ func (p MinPSampler) Sample(logits *Array) *Array { // Threshold = min_p * max_prob threshold := MulScalar(maxProb, float32(p)) + Free(maxProb) // Mask tokens below threshold - mask := Where( - Greater(threshold, probs), - FromValue(float32(math.Inf(-1))), - logits, - ) + inf := FromValue(float32(math.Inf(-1))) + gt := Greater(threshold, probs) + mask := Where(gt, inf, logits) + Free(probs, threshold, inf, gt) return mask }