diff --git a/adapter.go b/adapter.go
index 8f0ba61..08ac13b 100644
--- a/adapter.go
+++ b/adapter.go
@@ -4,6 +4,7 @@ package ml
 
 import (
 	"context"
+	"fmt"
 	"strings"
 
 	"forge.lthn.ai/core/go-inference"
@@ -97,6 +98,17 @@ func (a *InferenceAdapter) Close() error { return a.model.Close() }
 // to Classify, BatchGenerate, Metrics, Info, etc.
 func (a *InferenceAdapter) Model() inference.TextModel { return a.model }
 
+// InspectAttention delegates to the underlying TextModel if it implements
+// inference.AttentionInspector. Returns an error if the backend does not support
+// attention inspection.
+func (a *InferenceAdapter) InspectAttention(ctx context.Context, prompt string, opts ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
+	inspector, ok := a.model.(inference.AttentionInspector)
+	if !ok {
+		return nil, fmt.Errorf("backend %q does not support attention inspection", a.name)
+	}
+	return inspector.InspectAttention(ctx, prompt, opts...)
+}
+
 // convertOpts maps ml.GenOpts to go-inference functional options.
 func convertOpts(opts GenOpts) []inference.GenerateOption {
 	var out []inference.GenerateOption
diff --git a/adapter_attention_test.go b/adapter_attention_test.go
new file mode 100644
index 0000000..9970752
--- /dev/null
+++ b/adapter_attention_test.go
@@ -0,0 +1,47 @@
+// SPDX-License-Identifier: EUPL-1.2
+
+package ml
+
+import (
+	"context"
+	"testing"
+
+	"forge.lthn.ai/core/go-inference"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// mockAttentionModel extends mockTextModel with AttentionInspector support.
+type mockAttentionModel struct {
+	mockTextModel
+}
+
+func (m *mockAttentionModel) InspectAttention(_ context.Context, _ string, _ ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
+	return &inference.AttentionSnapshot{
+		NumLayers:    28,
+		NumHeads:     8,
+		SeqLen:       10,
+		HeadDim:      64,
+		Architecture: "qwen3",
+	}, nil
+}
+
+func TestInferenceAdapter_InspectAttention_Good(t *testing.T) {
+	adapter := NewInferenceAdapter(&mockAttentionModel{}, "test")
+	snap, err := adapter.InspectAttention(context.Background(), "hello")
+	require.NoError(t, err)
+	assert.Equal(t, 28, snap.NumLayers)
+	assert.Equal(t, 8, snap.NumHeads)
+	assert.Equal(t, 10, snap.SeqLen)
+	assert.Equal(t, 64, snap.HeadDim)
+	assert.Equal(t, "qwen3", snap.Architecture)
+}
+
+func TestInferenceAdapter_InspectAttention_Unsupported_Bad(t *testing.T) {
+	// Plain mockTextModel does not implement AttentionInspector.
+	adapter := NewInferenceAdapter(&mockTextModel{}, "test")
+	_, err := adapter.InspectAttention(context.Background(), "hello")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "does not support attention inspection")
+}