feat: implement AttentionInspector via KV cache extraction after prefill
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
9a48774538
commit
c2177f754a
3 changed files with 166 additions and 0 deletions
69
attention_test.go
Normal file
69
attention_test.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
//go:build darwin && arm64
|
||||
|
||||
package mlx_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
mlx "forge.lthn.ai/core/go-mlx"
|
||||
)
|
||||
|
||||
func TestMetalAdapterImplementsAttentionInspector(t *testing.T) {
|
||||
// Load a real model and verify the adapter implements AttentionInspector.
|
||||
b, ok := inference.Get("metal")
|
||||
if !ok {
|
||||
t.Fatal("metal backend not registered")
|
||||
}
|
||||
|
||||
modelPath := gemma3ModelPath(t)
|
||||
m, err := b.LoadModel(modelPath)
|
||||
if err != nil {
|
||||
t.Fatalf("LoadModel: %v", err)
|
||||
}
|
||||
defer func() { m.Close(); mlx.ClearCache() }()
|
||||
|
||||
inspector, ok := m.(inference.AttentionInspector)
|
||||
if !ok {
|
||||
t.Fatal("metalAdapter does not implement AttentionInspector")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
snap, err := inspector.InspectAttention(ctx, "What is kindness?")
|
||||
if err != nil {
|
||||
t.Fatalf("InspectAttention: %v", err)
|
||||
}
|
||||
|
||||
if snap.NumLayers == 0 {
|
||||
t.Error("NumLayers should be > 0")
|
||||
}
|
||||
if snap.NumHeads == 0 {
|
||||
t.Error("NumHeads should be > 0")
|
||||
}
|
||||
if snap.SeqLen == 0 {
|
||||
t.Error("SeqLen should be > 0")
|
||||
}
|
||||
if snap.HeadDim == 0 {
|
||||
t.Error("HeadDim should be > 0")
|
||||
}
|
||||
if snap.Architecture == "" {
|
||||
t.Error("Architecture should not be empty")
|
||||
}
|
||||
if len(snap.Keys) != snap.NumLayers {
|
||||
t.Errorf("Keys len = %d, want %d (NumLayers)", len(snap.Keys), snap.NumLayers)
|
||||
}
|
||||
|
||||
// Verify at least the first layer has data
|
||||
if len(snap.Keys[0]) != snap.NumHeads {
|
||||
t.Errorf("Keys[0] len = %d, want %d (NumHeads)", len(snap.Keys[0]), snap.NumHeads)
|
||||
}
|
||||
|
||||
expectedLen := snap.SeqLen * snap.HeadDim
|
||||
if len(snap.Keys[0][0]) != expectedLen {
|
||||
t.Errorf("Keys[0][0] len = %d, want %d (SeqLen*HeadDim)", len(snap.Keys[0][0]), expectedLen)
|
||||
}
|
||||
|
||||
t.Logf("AttentionSnapshot: arch=%s layers=%d heads=%d seq=%d dim=%d",
|
||||
snap.Architecture, snap.NumLayers, snap.NumHeads, snap.SeqLen, snap.HeadDim)
|
||||
}
|
||||
|
|
@ -229,6 +229,87 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
|
|||
}
|
||||
}
|
||||
|
||||
// InspectAttention runs a single prefill pass and extracts K vectors from each layer's KV cache.
|
||||
// Returns the post-RoPE K tensors as CPU float32 slices indexed [layer][head][seq_len*head_dim].
|
||||
func (m *Model) InspectAttention(ctx context.Context, prompt string) (*AttentionResult, error) {
|
||||
tokens := m.tokenizer.Encode(prompt)
|
||||
if len(tokens) == 0 {
|
||||
return nil, fmt.Errorf("empty prompt after tokenisation")
|
||||
}
|
||||
|
||||
caches := m.newCaches()
|
||||
|
||||
// Single prefill pass — populates KV caches for all layers.
|
||||
input := FromValues(tokens, len(tokens))
|
||||
input = Reshape(input, 1, int32(len(tokens)))
|
||||
logits := m.model.Forward(input, caches)
|
||||
if err := Eval(logits); err != nil {
|
||||
return nil, fmt.Errorf("prefill: %w", err)
|
||||
}
|
||||
|
||||
info := m.Info()
|
||||
seqLen := len(tokens)
|
||||
|
||||
// Extract K vectors from each layer's cache.
|
||||
keys := make([][][]float32, info.NumLayers)
|
||||
var numHeads, headDim int
|
||||
|
||||
for i := 0; i < info.NumLayers && i < len(caches); i++ {
|
||||
state := caches[i].State()
|
||||
if len(state) < 1 {
|
||||
continue
|
||||
}
|
||||
kArray := state[0] // K tensor from cache: [B, H, L_alloc, D]
|
||||
shape := kArray.Shape()
|
||||
if len(shape) != 4 {
|
||||
continue
|
||||
}
|
||||
numHeads = int(shape[1])
|
||||
headDim = int(shape[3])
|
||||
|
||||
// Slice to valid tokens only (cache may have pre-allocated padding).
|
||||
validLen := min(caches[i].Len(), seqLen)
|
||||
kSliced := Slice(kArray, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], int32(validLen), shape[3]})
|
||||
if err := Eval(kSliced); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract all floats then reshape per head.
|
||||
flat := kSliced.Floats() // len = 1 * H * validLen * D
|
||||
keys[i] = make([][]float32, numHeads)
|
||||
stride := validLen * headDim
|
||||
for h := 0; h < numHeads; h++ {
|
||||
start := h * stride
|
||||
end := start + stride
|
||||
if end > len(flat) {
|
||||
break
|
||||
}
|
||||
head := make([]float32, stride)
|
||||
copy(head, flat[start:end])
|
||||
keys[i][h] = head
|
||||
}
|
||||
}
|
||||
|
||||
return &AttentionResult{
|
||||
NumLayers: info.NumLayers,
|
||||
NumHeads: numHeads,
|
||||
SeqLen: seqLen,
|
||||
HeadDim: headDim,
|
||||
Keys: keys,
|
||||
Architecture: info.Architecture,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AttentionResult holds extracted K vectors from the KV cache after a single
// prefill pass over a prompt (see InspectAttention).
type AttentionResult struct {
	// NumLayers is the layer count reported by the model's Info().
	NumLayers int
	// NumHeads is taken from the cached K tensor's shape (dimension 1).
	NumHeads int
	// SeqLen is the prompt length in tokens.
	SeqLen int
	// HeadDim is taken from the cached K tensor's shape (dimension 3).
	HeadDim int
	// Keys is indexed [layer][head] → flat float32 of len seq_len*head_dim.
	// A layer's entry may be nil if its cache state could not be read.
	Keys [][][]float32
	// Architecture is the model architecture string from Info().
	Architecture string
}
|
||||
|
||||
// applyRepeatPenalty modifies logits to discourage repeated tokens.
|
||||
// For each unique token ID in history: positive logits are divided by penalty,
|
||||
// negative logits are multiplied by penalty. Both make the token less likely.
|
||||
|
|
|
|||
|
|
@ -192,5 +192,21 @@ func (a *metalAdapter) Info() inference.ModelInfo {
|
|||
QuantGroup: i.QuantGroup,
|
||||
}
|
||||
}
|
||||
// InspectAttention implements inference.AttentionInspector.
|
||||
func (a *metalAdapter) InspectAttention(ctx context.Context, prompt string, opts ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
|
||||
result, err := a.m.InspectAttention(ctx, prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &inference.AttentionSnapshot{
|
||||
NumLayers: result.NumLayers,
|
||||
NumHeads: result.NumHeads,
|
||||
SeqLen: result.SeqLen,
|
||||
HeadDim: result.HeadDim,
|
||||
Keys: result.Keys,
|
||||
Architecture: result.Architecture,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Err reports the underlying model's error state.
func (a *metalAdapter) Err() error {
	return a.m.Err()
}
|
||||
// Close releases the underlying model's resources.
func (a *metalAdapter) Close() error {
	return a.m.Close()
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue