69 lines
1.8 KiB
Go
69 lines
1.8 KiB
Go
//go:build darwin && arm64
|
|
|
|
package mlx_test
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"forge.lthn.ai/core/go-inference"
|
|
mlx "forge.lthn.ai/core/go-mlx"
|
|
)
|
|
|
|
func TestMetalAdapterImplementsAttentionInspector(t *testing.T) {
|
|
// Load a real model and verify the adapter implements AttentionInspector.
|
|
b, ok := inference.Get("metal")
|
|
if !ok {
|
|
t.Fatal("metal backend not registered")
|
|
}
|
|
|
|
modelPath := gemma3ModelPath(t)
|
|
m, err := b.LoadModel(modelPath)
|
|
if err != nil {
|
|
t.Fatalf("LoadModel: %v", err)
|
|
}
|
|
defer func() { m.Close(); mlx.ClearCache() }()
|
|
|
|
inspector, ok := m.(inference.AttentionInspector)
|
|
if !ok {
|
|
t.Fatal("metalAdapter does not implement AttentionInspector")
|
|
}
|
|
|
|
ctx := context.Background()
|
|
snap, err := inspector.InspectAttention(ctx, "What is kindness?")
|
|
if err != nil {
|
|
t.Fatalf("InspectAttention: %v", err)
|
|
}
|
|
|
|
if snap.NumLayers == 0 {
|
|
t.Error("NumLayers should be > 0")
|
|
}
|
|
if snap.NumHeads == 0 {
|
|
t.Error("NumHeads should be > 0")
|
|
}
|
|
if snap.SeqLen == 0 {
|
|
t.Error("SeqLen should be > 0")
|
|
}
|
|
if snap.HeadDim == 0 {
|
|
t.Error("HeadDim should be > 0")
|
|
}
|
|
if snap.Architecture == "" {
|
|
t.Error("Architecture should not be empty")
|
|
}
|
|
if len(snap.Keys) != snap.NumLayers {
|
|
t.Errorf("Keys len = %d, want %d (NumLayers)", len(snap.Keys), snap.NumLayers)
|
|
}
|
|
|
|
// Verify at least the first layer has data
|
|
if len(snap.Keys[0]) != snap.NumHeads {
|
|
t.Errorf("Keys[0] len = %d, want %d (NumHeads)", len(snap.Keys[0]), snap.NumHeads)
|
|
}
|
|
|
|
expectedLen := snap.SeqLen * snap.HeadDim
|
|
if len(snap.Keys[0][0]) != expectedLen {
|
|
t.Errorf("Keys[0][0] len = %d, want %d (SeqLen*HeadDim)", len(snap.Keys[0][0]), expectedLen)
|
|
}
|
|
|
|
t.Logf("AttentionSnapshot: arch=%s layers=%d heads=%d seq=%d dim=%d",
|
|
snap.Architecture, snap.NumLayers, snap.NumHeads, snap.SeqLen, snap.HeadDim)
|
|
}
|