feat: add InspectAttention pass-through on InferenceAdapter
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
1c5fd160f2
commit
45e9fed280
2 changed files with 59 additions and 0 deletions
12
adapter.go
12
adapter.go
|
|
@ -4,6 +4,7 @@ package ml
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
|
|
@ -97,6 +98,17 @@ func (a *InferenceAdapter) Close() error { return a.model.Close() }
|
|||
// to Classify, BatchGenerate, Metrics, Info, etc.
|
||||
func (a *InferenceAdapter) Model() inference.TextModel { return a.model }
|
||||
|
||||
// InspectAttention delegates to the underlying TextModel if it implements
|
||||
// inference.AttentionInspector. Returns an error if the backend does not support
|
||||
// attention inspection.
|
||||
func (a *InferenceAdapter) InspectAttention(ctx context.Context, prompt string, opts ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
|
||||
inspector, ok := a.model.(inference.AttentionInspector)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("backend %q does not support attention inspection", a.name)
|
||||
}
|
||||
return inspector.InspectAttention(ctx, prompt, opts...)
|
||||
}
|
||||
|
||||
// convertOpts maps ml.GenOpts to go-inference functional options.
|
||||
func convertOpts(opts GenOpts) []inference.GenerateOption {
|
||||
var out []inference.GenerateOption
|
||||
|
|
|
|||
47
adapter_attention_test.go
Normal file
47
adapter_attention_test.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
// SPDX-License-Identifier: EUPL-1.2
|
||||
|
||||
package ml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"forge.lthn.ai/core/go-inference"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockAttentionModel extends mockTextModel with AttentionInspector support.
|
||||
type mockAttentionModel struct {
|
||||
mockTextModel
|
||||
}
|
||||
|
||||
func (m *mockAttentionModel) InspectAttention(_ context.Context, _ string, _ ...inference.GenerateOption) (*inference.AttentionSnapshot, error) {
|
||||
return &inference.AttentionSnapshot{
|
||||
NumLayers: 28,
|
||||
NumHeads: 8,
|
||||
SeqLen: 10,
|
||||
HeadDim: 64,
|
||||
Architecture: "qwen3",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func TestInferenceAdapter_InspectAttention_Good(t *testing.T) {
|
||||
adapter := NewInferenceAdapter(&mockAttentionModel{}, "test")
|
||||
snap, err := adapter.InspectAttention(context.Background(), "hello")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 28, snap.NumLayers)
|
||||
assert.Equal(t, 8, snap.NumHeads)
|
||||
assert.Equal(t, 10, snap.SeqLen)
|
||||
assert.Equal(t, 64, snap.HeadDim)
|
||||
assert.Equal(t, "qwen3", snap.Architecture)
|
||||
}
|
||||
|
||||
func TestInferenceAdapter_InspectAttention_Unsupported_Bad(t *testing.T) {
|
||||
// Plain mockTextModel does not implement AttentionInspector.
|
||||
adapter := NewInferenceAdapter(&mockTextModel{}, "test")
|
||||
_, err := adapter.InspectAttention(context.Background(), "hello")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "does not support attention inspection")
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue