diff --git a/attention_test.go b/attention_test.go
new file mode 100644
index 0000000..15faa4a
--- /dev/null
+++ b/attention_test.go
@@ -0,0 +1,69 @@
+//go:build darwin && arm64
+
+package mlx_test
+
+import (
+	"context"
+	"testing"
+
+	"forge.lthn.ai/core/go-inference"
+	mlx "forge.lthn.ai/core/go-mlx"
+)
+
+func TestMetalAdapterImplementsAttentionInspector(t *testing.T) {
+	// Load a real model and verify the adapter implements AttentionInspector.
+	b, ok := inference.Get("metal")
+	if !ok {
+		t.Fatal("metal backend not registered")
+	}
+
+	modelPath := gemma3ModelPath(t)
+	m, err := b.LoadModel(modelPath)
+	if err != nil {
+		t.Fatalf("LoadModel: %v", err)
+	}
+	defer func() { m.Close(); mlx.ClearCache() }()
+
+	inspector, ok := m.(inference.AttentionInspector)
+	if !ok {
+		t.Fatal("metalAdapter does not implement AttentionInspector")
+	}
+
+	ctx := context.Background()
+	snap, err := inspector.InspectAttention(ctx, "What is kindness?")
+	if err != nil {
+		t.Fatalf("InspectAttention: %v", err)
+	}
+
+	if snap.NumLayers == 0 {
+		t.Error("NumLayers should be > 0")
+	}
+	if snap.NumHeads == 0 {
+		t.Error("NumHeads should be > 0")
+	}
+	if snap.SeqLen == 0 {
+		t.Error("SeqLen should be > 0")
+	}
+	if snap.HeadDim == 0 {
+		t.Error("HeadDim should be > 0")
+	}
+	if snap.Architecture == "" {
+		t.Error("Architecture should not be empty")
+	}
+	if len(snap.Keys) != snap.NumLayers {
+		t.Errorf("Keys len = %d, want %d (NumLayers)", len(snap.Keys), snap.NumLayers)
+	}
+
+	// Verify at least the first layer has data
+	if len(snap.Keys[0]) != snap.NumHeads {
+		t.Errorf("Keys[0] len = %d, want %d (NumHeads)", len(snap.Keys[0]), snap.NumHeads)
+	}
+
+	expectedLen := snap.SeqLen * snap.HeadDim
+	if len(snap.Keys[0][0]) != expectedLen {
+		t.Errorf("Keys[0][0] len = %d, want %d (SeqLen*HeadDim)", len(snap.Keys[0][0]), expectedLen)
+	}
+
+	t.Logf("AttentionSnapshot: arch=%s layers=%d heads=%d seq=%d dim=%d",
+		snap.Architecture, snap.NumLayers, snap.NumHeads, snap.SeqLen, snap.HeadDim)
+}
diff --git a/internal/metal/generate.go b/internal/metal/generate.go
index 8ec2322..e55ae27 100644
--- a/internal/metal/generate.go
+++ b/internal/metal/generate.go
@@ -229,6 +229,92 @@ func (m *Model) Generate(ctx context.Context, prompt string, cfg GenerateConfig)
 	}
 }
 
+// InspectAttention runs a single prefill pass and extracts K vectors from each layer's KV cache.
+// Returns the post-RoPE K tensors as CPU float32 slices indexed [layer][head][seq_len*head_dim].
+func (m *Model) InspectAttention(ctx context.Context, prompt string) (*AttentionResult, error) {
+	// Honour caller cancellation before doing any device work.
+	if err := ctx.Err(); err != nil {
+		return nil, err
+	}
+
+	tokens := m.tokenizer.Encode(prompt)
+	if len(tokens) == 0 {
+		return nil, fmt.Errorf("empty prompt after tokenisation")
+	}
+
+	caches := m.newCaches()
+
+	// Single prefill pass — populates KV caches for all layers.
+	input := FromValues(tokens, len(tokens))
+	input = Reshape(input, 1, int32(len(tokens)))
+	logits := m.model.Forward(input, caches)
+	if err := Eval(logits); err != nil {
+		return nil, fmt.Errorf("prefill: %w", err)
+	}
+
+	info := m.Info()
+	seqLen := len(tokens)
+
+	// Extract K vectors from each layer's cache.
+	keys := make([][][]float32, info.NumLayers)
+	var numHeads, headDim int
+
+	for i := 0; i < info.NumLayers && i < len(caches); i++ {
+		state := caches[i].State()
+		if len(state) < 1 {
+			continue
+		}
+		kArray := state[0] // K tensor from cache: [B, H, L_alloc, D]
+		shape := kArray.Shape()
+		if len(shape) != 4 {
+			continue
+		}
+		numHeads = int(shape[1])
+		headDim = int(shape[3])
+
+		// Slice to valid tokens only (cache may have pre-allocated padding).
+		validLen := min(caches[i].Len(), seqLen)
+		kSliced := Slice(kArray, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], int32(validLen), shape[3]})
+		if err := Eval(kSliced); err != nil {
+			continue
+		}
+
+		// Extract all floats then reshape per head.
+		flat := kSliced.Floats() // len = 1 * H * validLen * D
+		keys[i] = make([][]float32, numHeads)
+		stride := validLen * headDim
+		for h := 0; h < numHeads; h++ {
+			start := h * stride
+			end := start + stride
+			if end > len(flat) {
+				break
+			}
+			head := make([]float32, stride)
+			copy(head, flat[start:end])
+			keys[i][h] = head
+		}
+	}
+
+	return &AttentionResult{
+		NumLayers:    info.NumLayers,
+		NumHeads:     numHeads,
+		SeqLen:       seqLen,
+		HeadDim:      headDim,
+		Keys:         keys,
+		Architecture: info.Architecture,
+	}, nil
+}
+
+// AttentionResult holds extracted K vectors from the KV cache.
+type AttentionResult struct {
+	NumLayers    int
+	NumHeads     int
+	SeqLen       int
+	HeadDim      int
+	Keys         [][][]float32 // [layer][head] → flat float32 of len seq_len*head_dim
+	Architecture string
+}
+
 // applyRepeatPenalty modifies logits to discourage repeated tokens.
 // For each unique token ID in history: positive logits are divided by penalty,
 // negative logits are multiplied by penalty. Both make the token less likely.
diff --git a/register_metal.go b/register_metal.go
index b742bb2..55accd4 100644
--- a/register_metal.go
+++ b/register_metal.go
@@ -192,5 +192,22 @@ func (a *metalAdapter) Info() inference.ModelInfo {
 		QuantGroup: i.QuantGroup,
 	}
 }
+
+// InspectAttention implements inference.AttentionInspector.
+func (a *metalAdapter) InspectAttention(ctx context.Context, prompt string, opts ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
+	result, err := a.m.InspectAttention(ctx, prompt)
+	if err != nil {
+		return nil, err
+	}
+	return &inference.AttentionSnapshot{
+		NumLayers:    result.NumLayers,
+		NumHeads:     result.NumHeads,
+		SeqLen:       result.SeqLen,
+		HeadDim:      result.HeadDim,
+		Keys:         result.Keys,
+		Architecture: result.Architecture,
+	}, nil
+}
+
 func (a *metalAdapter) Err() error   { return a.m.Err() }
 func (a *metalAdapter) Close() error { return a.m.Close() }