From 0f7263f731e7d090b2dbc2a4b7af6cbdc0eeeff6 Mon Sep 17 00:00:00 2001
From: Snider
Date: Mon, 23 Feb 2026 00:22:51 +0000
Subject: [PATCH] feat: add AttentionInspector optional interface for Q/K Bone Orientation

Co-Authored-By: Virgil
---
 inference.go      | 18 ++++++++++++++++++
 inference_test.go | 39 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/inference.go b/inference.go
index d695f3d..ef0c78c 100644
--- a/inference.go
+++ b/inference.go
@@ -126,6 +126,24 @@ type ModelInfo struct {
 	QuantGroup int // Quantisation group size (0 if unquantised)
 }
 
+// AttentionSnapshot holds K vectors extracted from the KV cache after prefill.
+// Keys[layer][head] is one flat slice of SeqLen*HeadDim values; position p presumably starts at offset p*HeadDim (TODO confirm against the backend).
+type AttentionSnapshot struct {
+	NumLayers    int           `json:"num_layers"`   // number of layers; Keys is expected to have this many entries
+	NumHeads     int           `json:"num_heads"`    // num_kv_heads (may differ from query heads in GQA)
+	SeqLen       int           `json:"seq_len"`      // number of tokens in the prompt
+	HeadDim      int           `json:"head_dim"`     // values per position within one head's flat slice
+	Keys         [][][]float32 `json:"keys"`         // [layer][head] → flat float32 of len seq_len*head_dim
+	Architecture string        `json:"architecture"` // architecture identifier string, e.g. "gemma3" or "qwen3"
+}
+
+// AttentionInspector is an optional interface that backends may implement to expose
+// attention-level data for Q/K Bone Orientation analysis; InspectAttention returns an AttentionSnapshot for prompt.
+// Detect support with a type assertion: if inspector, ok := model.(AttentionInspector); ok { ... }
+type AttentionInspector interface {
+	InspectAttention(ctx context.Context, prompt string, opts ...GenerateOption) (*AttentionSnapshot, error)
+}
+
 // TextModel generates text from a loaded model.
 type TextModel interface {
 	// Generate streams tokens for the given prompt.
diff --git a/inference_test.go b/inference_test.go
index 31c2168..829ef01 100644
--- a/inference_test.go
+++ b/inference_test.go
@@ -394,6 +394,45 @@ func TestTypes_Good_InterfaceCompliance(t *testing.T) {
 	var _ TextModel = (*stubTextModel)(nil)
 }
 
+// --- AttentionSnapshot / AttentionInspector ---
+
+func TestAttentionSnapshot_Good(t *testing.T) {
+	snap := AttentionSnapshot{
+		NumLayers:    28,
+		NumHeads:     16,
+		SeqLen:       42,
+		HeadDim:      64,
+		Keys:         make([][][]float32, 28), // 28 nil per-layer entries; no head data is filled in
+		Architecture: "gemma3",
+	}
+	assert.Equal(t, 28, snap.NumLayers)
+	assert.Equal(t, 16, snap.NumHeads)
+	assert.Equal(t, 42, snap.SeqLen)
+	assert.Equal(t, 64, snap.HeadDim)
+	assert.Len(t, snap.Keys, 28)
+	assert.Equal(t, "gemma3", snap.Architecture)
+}
+
+func TestAttentionInspector_Good_InterfaceCompliance(t *testing.T) {
+	// Compile-time check: *mockInspector must satisfy AttentionInspector.
+	var _ AttentionInspector = (*mockInspector)(nil)
+}
+
+type mockInspector struct{ stubTextModel } // stubTextModel plus an InspectAttention method
+
+func (m *mockInspector) InspectAttention(_ context.Context, _ string, _ ...GenerateOption) (*AttentionSnapshot, error) {
+	return &AttentionSnapshot{NumLayers: 28, NumHeads: 8, SeqLen: 10, HeadDim: 64, Architecture: "qwen3"}, nil // fixed snapshot; Keys is left nil
+}
+
+func TestAttentionInspector_Good_ReturnsSnapshot(t *testing.T) {
+	var inspector AttentionInspector = &mockInspector{}
+	snap, err := inspector.InspectAttention(context.Background(), "hello") // called through the interface, with no options
+	require.NoError(t, err)
+	assert.Equal(t, 28, snap.NumLayers)
+	assert.Equal(t, 8, snap.NumHeads)
+	assert.Equal(t, "qwen3", snap.Architecture)
+}
+
 // --- Struct types ---
 
 func TestToken_Good(t *testing.T) {