feat: add AttentionInspector optional interface for Q/K Bone Orientation
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
bd09ec4997
commit
0f7263f731
2 changed files with 57 additions and 0 deletions
18
inference.go
18
inference.go
|
|
@ -126,6 +126,24 @@ type ModelInfo struct {
|
|||
QuantGroup int // Quantisation group size (0 if unquantised)
|
||||
}
|
||||
|
||||
// AttentionSnapshot holds K vectors extracted from the KV cache after prefill.
|
||||
// Keys is indexed [layer][head][position*head_dim] — flattened per head.
|
||||
type AttentionSnapshot struct {
|
||||
NumLayers int `json:"num_layers"`
|
||||
NumHeads int `json:"num_heads"` // num_kv_heads (may differ from query heads in GQA)
|
||||
SeqLen int `json:"seq_len"` // number of tokens in the prompt
|
||||
HeadDim int `json:"head_dim"`
|
||||
Keys [][][]float32 `json:"keys"` // [layer][head] → flat float32 of len seq_len*head_dim
|
||||
Architecture string `json:"architecture"`
|
||||
}
|
||||
|
||||
// AttentionInspector is an optional interface that backends may implement
|
||||
// to expose attention-level data for Q/K Bone Orientation analysis.
|
||||
// Use type assertion: if inspector, ok := model.(AttentionInspector); ok { ... }
|
||||
type AttentionInspector interface {
|
||||
InspectAttention(ctx context.Context, prompt string, opts ...GenerateOption) (*AttentionSnapshot, error)
|
||||
}
|
||||
|
||||
// TextModel generates text from a loaded model.
|
||||
type TextModel interface {
|
||||
// Generate streams tokens for the given prompt.
|
||||
|
|
|
|||
|
|
@ -394,6 +394,45 @@ func TestTypes_Good_InterfaceCompliance(t *testing.T) {
|
|||
var _ TextModel = (*stubTextModel)(nil)
|
||||
}
|
||||
|
||||
// --- AttentionSnapshot ---
|
||||
|
||||
func TestAttentionSnapshot_Good(t *testing.T) {
|
||||
snap := AttentionSnapshot{
|
||||
NumLayers: 28,
|
||||
NumHeads: 16,
|
||||
SeqLen: 42,
|
||||
HeadDim: 64,
|
||||
Keys: make([][][]float32, 28),
|
||||
Architecture: "gemma3",
|
||||
}
|
||||
assert.Equal(t, 28, snap.NumLayers)
|
||||
assert.Equal(t, 16, snap.NumHeads)
|
||||
assert.Equal(t, 42, snap.SeqLen)
|
||||
assert.Equal(t, 64, snap.HeadDim)
|
||||
assert.Len(t, snap.Keys, 28)
|
||||
assert.Equal(t, "gemma3", snap.Architecture)
|
||||
}
|
||||
|
||||
func TestAttentionInspector_Good_InterfaceCompliance(t *testing.T) {
|
||||
// Compile-time check: the interface exists and has the right signature.
|
||||
var _ AttentionInspector = (*mockInspector)(nil)
|
||||
}
|
||||
|
||||
type mockInspector struct{ stubTextModel }
|
||||
|
||||
func (m *mockInspector) InspectAttention(_ context.Context, _ string, _ ...GenerateOption) (*AttentionSnapshot, error) {
|
||||
return &AttentionSnapshot{NumLayers: 28, NumHeads: 8, SeqLen: 10, HeadDim: 64, Architecture: "qwen3"}, nil
|
||||
}
|
||||
|
||||
func TestAttentionInspector_Good_ReturnsSnapshot(t *testing.T) {
|
||||
var inspector AttentionInspector = &mockInspector{}
|
||||
snap, err := inspector.InspectAttention(context.Background(), "hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 28, snap.NumLayers)
|
||||
assert.Equal(t, 8, snap.NumHeads)
|
||||
assert.Equal(t, "qwen3", snap.Architecture)
|
||||
}
|
||||
|
||||
// --- Struct types ---
|
||||
|
||||
func TestToken_Good(t *testing.T) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue