feat: implement AttentionInspector via KV cache extraction after prefill

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 00:25:09 +00:00
parent 9a48774538
commit c2177f754a
3 changed files with 166 additions and 0 deletions

69
attention_test.go Normal file
View file

@ -0,0 +1,69 @@
//go:build darwin && arm64
package mlx_test
import (
"context"
"testing"
"forge.lthn.ai/core/go-inference"
mlx "forge.lthn.ai/core/go-mlx"
)
func TestMetalAdapterImplementsAttentionInspector(t *testing.T) {
// Load a real model and verify the adapter implements AttentionInspector.
b, ok := inference.Get("metal")
if !ok {
t.Fatal("metal backend not registered")
}
modelPath := gemma3ModelPath(t)
m, err := b.LoadModel(modelPath)
if err != nil {
t.Fatalf("LoadModel: %v", err)
}
defer func() { m.Close(); mlx.ClearCache() }()
inspector, ok := m.(inference.AttentionInspector)
if !ok {
t.Fatal("metalAdapter does not implement AttentionInspector")
}
ctx := context.Background()
snap, err := inspector.InspectAttention(ctx, "What is kindness?")
if err != nil {
t.Fatalf("InspectAttention: %v", err)
}
if snap.NumLayers == 0 {
t.Error("NumLayers should be > 0")
}
if snap.NumHeads == 0 {
t.Error("NumHeads should be > 0")
}
if snap.SeqLen == 0 {
t.Error("SeqLen should be > 0")
}
if snap.HeadDim == 0 {
t.Error("HeadDim should be > 0")
}
if snap.Architecture == "" {
t.Error("Architecture should not be empty")
}
if len(snap.Keys) != snap.NumLayers {
t.Errorf("Keys len = %d, want %d (NumLayers)", len(snap.Keys), snap.NumLayers)
}
// Verify at least the first layer has data
if len(snap.Keys[0]) != snap.NumHeads {
t.Errorf("Keys[0] len = %d, want %d (NumHeads)", len(snap.Keys[0]), snap.NumHeads)
}
expectedLen := snap.SeqLen * snap.HeadDim
if len(snap.Keys[0][0]) != expectedLen {
t.Errorf("Keys[0][0] len = %d, want %d (SeqLen*HeadDim)", len(snap.Keys[0][0]), expectedLen)
}
t.Logf("AttentionSnapshot: arch=%s layers=%d heads=%d seq=%d dim=%d",
snap.Architecture, snap.NumLayers, snap.NumHeads, snap.SeqLen, snap.HeadDim)
}

View file

@ -229,6 +229,87 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
}
}
// InspectAttention runs a single prefill pass and extracts K vectors from each layer's KV cache.
// Returns the post-RoPE K tensors as CPU float32 slices indexed [layer][head][seq_len*head_dim].
func (m *Model) InspectAttention(ctx context.Context, prompt string) (*AttentionResult, error) {
tokens := m.tokenizer.Encode(prompt)
if len(tokens) == 0 {
return nil, fmt.Errorf("empty prompt after tokenisation")
}
caches := m.newCaches()
// Single prefill pass — populates KV caches for all layers.
input := FromValues(tokens, len(tokens))
input = Reshape(input, 1, int32(len(tokens)))
logits := m.model.Forward(input, caches)
if err := Eval(logits); err != nil {
return nil, fmt.Errorf("prefill: %w", err)
}
info := m.Info()
seqLen := len(tokens)
// Extract K vectors from each layer's cache.
keys := make([][][]float32, info.NumLayers)
var numHeads, headDim int
for i := 0; i < info.NumLayers && i < len(caches); i++ {
state := caches[i].State()
if len(state) < 1 {
continue
}
kArray := state[0] // K tensor from cache: [B, H, L_alloc, D]
shape := kArray.Shape()
if len(shape) != 4 {
continue
}
numHeads = int(shape[1])
headDim = int(shape[3])
// Slice to valid tokens only (cache may have pre-allocated padding).
validLen := min(caches[i].Len(), seqLen)
kSliced := Slice(kArray, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], int32(validLen), shape[3]})
if err := Eval(kSliced); err != nil {
continue
}
// Extract all floats then reshape per head.
flat := kSliced.Floats() // len = 1 * H * validLen * D
keys[i] = make([][]float32, numHeads)
stride := validLen * headDim
for h := 0; h < numHeads; h++ {
start := h * stride
end := start + stride
if end > len(flat) {
break
}
head := make([]float32, stride)
copy(head, flat[start:end])
keys[i][h] = head
}
}
return &AttentionResult{
NumLayers: info.NumLayers,
NumHeads: numHeads,
SeqLen: seqLen,
HeadDim: headDim,
Keys: keys,
Architecture: info.Architecture,
}, nil
}
// AttentionResult holds extracted K vectors from the KV cache.
type AttentionResult struct {
NumLayers int
NumHeads int
SeqLen int
HeadDim int
Keys [][][]float32 // [layer][head] → flat float32 of len seq_len*head_dim
Architecture string
}
// applyRepeatPenalty modifies logits to discourage repeated tokens.
// For each unique token ID in history: positive logits are divided by penalty,
// negative logits are multiplied by penalty. Both make the token less likely.

View file

@ -192,5 +192,21 @@ func (a *metalAdapter) Info() inference.ModelInfo {
QuantGroup: i.QuantGroup,
}
}
// InspectAttention implements inference.AttentionInspector.
func (a *metalAdapter) InspectAttention(ctx context.Context, prompt string, opts ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
result, err := a.m.InspectAttention(ctx, prompt)
if err != nil {
return nil, err
}
return &inference.AttentionSnapshot{
NumLayers: result.NumLayers,
NumHeads: result.NumHeads,
SeqLen: result.SeqLen,
HeadDim: result.HeadDim,
Keys: result.Keys,
Architecture: result.Architecture,
}, nil
}
func (a *metalAdapter) Err() error { return a.m.Err() }
func (a *metalAdapter) Close() error { return a.m.Close() }