feat: add InspectAttention pass-through on InferenceAdapter

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-23 00:26:26 +00:00
parent 1c5fd160f2
commit 45e9fed280
2 changed files with 59 additions and 0 deletions

View file

@ -4,6 +4,7 @@ package ml
import (
"context"
"fmt"
"strings"
"forge.lthn.ai/core/go-inference"
@ -97,6 +98,17 @@ func (a *InferenceAdapter) Close() error { return a.model.Close() }
// to Classify, BatchGenerate, Metrics, Info, etc.
func (a *InferenceAdapter) Model() inference.TextModel { return a.model }
// InspectAttention delegates to the underlying TextModel if it implements
// inference.AttentionInspector. Returns an error if the backend does not support
// attention inspection.
func (a *InferenceAdapter) InspectAttention(ctx context.Context, prompt string, opts ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
inspector, ok := a.model.(inference.AttentionInspector)
if !ok {
return nil, fmt.Errorf("backend %q does not support attention inspection", a.name)
}
return inspector.InspectAttention(ctx, prompt, opts...)
}
// convertOpts maps ml.GenOpts to go-inference functional options.
func convertOpts(opts GenOpts) []inference.GenerateOption {
var out []inference.GenerateOption

47
adapter_attention_test.go Normal file
View file

@ -0,0 +1,47 @@
// SPDX-Licence-Identifier: EUPL-1.2
package ml
import (
"context"
"testing"
"forge.lthn.ai/core/go-inference"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockAttentionModel extends mockTextModel with AttentionInspector support.
type mockAttentionModel struct {
mockTextModel
}
func (m *mockAttentionModel) InspectAttention(_ context.Context, _ string, _ ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
return &inference.AttentionSnapshot{
NumLayers: 28,
NumHeads: 8,
SeqLen: 10,
HeadDim: 64,
Architecture: "qwen3",
}, nil
}
func TestInferenceAdapter_InspectAttention_Good(t *testing.T) {
adapter := NewInferenceAdapter(&mockAttentionModel{}, "test")
snap, err := adapter.InspectAttention(context.Background(), "hello")
require.NoError(t, err)
assert.Equal(t, 28, snap.NumLayers)
assert.Equal(t, 8, snap.NumHeads)
assert.Equal(t, 10, snap.SeqLen)
assert.Equal(t, 64, snap.HeadDim)
assert.Equal(t, "qwen3", snap.Architecture)
}
func TestInferenceAdapter_InspectAttention_Unsupported_Bad(t *testing.T) {
// Plain mockTextModel does not implement AttentionInspector.
adapter := NewInferenceAdapter(&mockTextModel{}, "test")
_, err := adapter.InspectAttention(context.Background(), "hello")
require.Error(t, err)
assert.Contains(t, err.Error(), "does not support attention inspection")
}