Cover options (GenerateConfig defaults, all With* options, ApplyGenerateOpts/ApplyLoadOpts), backend registry (Register, Get, List, Default priority order metal > rocm > llama_cpp), LoadModel routing (explicit/auto backend, error paths), and Discover (model directory scanning, quantisation, edge cases). 69 tests, 100% statement coverage, race-clean. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
456 lines
11 KiB
Go
456 lines
11 KiB
Go
package inference
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"iter"
|
|
"sort"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// --- test helpers ---
|
|
|
|
// resetBackends clears the global registry for test isolation.
|
|
func resetBackends(t *testing.T) {
|
|
t.Helper()
|
|
backendsMu.Lock()
|
|
defer backendsMu.Unlock()
|
|
backends = map[string]Backend{}
|
|
}
|
|
|
|
// stubBackend is a minimal Backend implementation for testing.
|
|
type stubBackend struct {
|
|
name string
|
|
available bool
|
|
loadErr error
|
|
}
|
|
|
|
func (s *stubBackend) Name() string { return s.name }
|
|
func (s *stubBackend) Available() bool { return s.available }
|
|
func (s *stubBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) {
|
|
if s.loadErr != nil {
|
|
return nil, s.loadErr
|
|
}
|
|
return &stubTextModel{backend: s.name, path: path}, nil
|
|
}
|
|
|
|
// stubTextModel is a minimal TextModel for testing LoadModel routing.
|
|
type stubTextModel struct {
|
|
backend string
|
|
path string
|
|
}
|
|
|
|
func (m *stubTextModel) Generate(_ context.Context, _ string, _ ...GenerateOption) iter.Seq[Token] {
|
|
return func(yield func(Token) bool) {}
|
|
}
|
|
func (m *stubTextModel) Chat(_ context.Context, _ []Message, _ ...GenerateOption) iter.Seq[Token] {
|
|
return func(yield func(Token) bool) {}
|
|
}
|
|
func (m *stubTextModel) Classify(_ context.Context, _ []string, _ ...GenerateOption) ([]ClassifyResult, error) {
|
|
return nil, nil
|
|
}
|
|
func (m *stubTextModel) BatchGenerate(_ context.Context, _ []string, _ ...GenerateOption) ([]BatchResult, error) {
|
|
return nil, nil
|
|
}
|
|
func (m *stubTextModel) ModelType() string { return "stub" }
|
|
func (m *stubTextModel) Info() ModelInfo { return ModelInfo{} }
|
|
func (m *stubTextModel) Metrics() GenerateMetrics { return GenerateMetrics{} }
|
|
func (m *stubTextModel) Err() error { return nil }
|
|
func (m *stubTextModel) Close() error { return nil }
|
|
|
|
// --- Register ---
|
|
|
|
func TestRegister_Good(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
b := &stubBackend{name: "test_backend", available: true}
|
|
Register(b)
|
|
|
|
got, ok := Get("test_backend")
|
|
require.True(t, ok)
|
|
assert.Equal(t, "test_backend", got.Name())
|
|
}
|
|
|
|
func TestRegister_Good_Multiple(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "alpha", available: true})
|
|
Register(&stubBackend{name: "beta", available: true})
|
|
Register(&stubBackend{name: "gamma", available: true})
|
|
|
|
names := List()
|
|
sort.Strings(names)
|
|
assert.Equal(t, []string{"alpha", "beta", "gamma"}, names)
|
|
}
|
|
|
|
func TestRegister_Ugly_Overwrites(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
b1 := &stubBackend{name: "dup", available: false}
|
|
b2 := &stubBackend{name: "dup", available: true}
|
|
|
|
Register(b1)
|
|
Register(b2)
|
|
|
|
got, ok := Get("dup")
|
|
require.True(t, ok)
|
|
assert.True(t, got.Available(), "second registration should overwrite the first")
|
|
}
|
|
|
|
// --- Get ---
|
|
|
|
func TestGet_Good(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "exists", available: true})
|
|
|
|
b, ok := Get("exists")
|
|
require.True(t, ok)
|
|
assert.Equal(t, "exists", b.Name())
|
|
}
|
|
|
|
func TestGet_Bad(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
b, ok := Get("nonexistent")
|
|
assert.False(t, ok)
|
|
assert.Nil(t, b)
|
|
}
|
|
|
|
// --- List ---
|
|
|
|
func TestList_Good_Empty(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
names := List()
|
|
assert.Empty(t, names)
|
|
}
|
|
|
|
func TestList_Good_Populated(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "a", available: true})
|
|
Register(&stubBackend{name: "b", available: true})
|
|
|
|
names := List()
|
|
sort.Strings(names)
|
|
assert.Equal(t, []string{"a", "b"}, names)
|
|
}
|
|
|
|
// --- Default ---
|
|
|
|
func TestDefault_Good_Metal(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "metal", available: true})
|
|
Register(&stubBackend{name: "rocm", available: true})
|
|
Register(&stubBackend{name: "llama_cpp", available: true})
|
|
|
|
b, err := Default()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "metal", b.Name(), "metal should be preferred when available")
|
|
}
|
|
|
|
func TestDefault_Good_Rocm(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "rocm", available: true})
|
|
Register(&stubBackend{name: "llama_cpp", available: true})
|
|
|
|
b, err := Default()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "rocm", b.Name(), "rocm should be preferred when metal is absent")
|
|
}
|
|
|
|
func TestDefault_Good_LlamaCpp(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "llama_cpp", available: true})
|
|
|
|
b, err := Default()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "llama_cpp", b.Name(), "llama_cpp should be used as fallback")
|
|
}
|
|
|
|
func TestDefault_Good_PriorityOrder(t *testing.T) {
|
|
// Full priority test: metal > rocm > llama_cpp
|
|
tests := []struct {
|
|
name string
|
|
backends []stubBackend
|
|
want string
|
|
}{
|
|
{
|
|
name: "all_available_prefers_metal",
|
|
backends: []stubBackend{
|
|
{name: "llama_cpp", available: true},
|
|
{name: "rocm", available: true},
|
|
{name: "metal", available: true},
|
|
},
|
|
want: "metal",
|
|
},
|
|
{
|
|
name: "metal_unavailable_prefers_rocm",
|
|
backends: []stubBackend{
|
|
{name: "metal", available: false},
|
|
{name: "rocm", available: true},
|
|
{name: "llama_cpp", available: true},
|
|
},
|
|
want: "rocm",
|
|
},
|
|
{
|
|
name: "metal_rocm_unavailable_prefers_llama_cpp",
|
|
backends: []stubBackend{
|
|
{name: "metal", available: false},
|
|
{name: "rocm", available: false},
|
|
{name: "llama_cpp", available: true},
|
|
},
|
|
want: "llama_cpp",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
resetBackends(t)
|
|
for i := range tt.backends {
|
|
Register(&tt.backends[i])
|
|
}
|
|
b, err := Default()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tt.want, b.Name())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestDefault_Good_FallbackToAny(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
// A custom backend that's not in the priority list.
|
|
Register(&stubBackend{name: "custom_gpu", available: true})
|
|
|
|
b, err := Default()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "custom_gpu", b.Name(), "should fall back to any available backend")
|
|
}
|
|
|
|
func TestDefault_Bad_NoBackends(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
_, err := Default()
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no backends registered")
|
|
}
|
|
|
|
func TestDefault_Bad_NoneAvailable(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "metal", available: false})
|
|
Register(&stubBackend{name: "rocm", available: false})
|
|
|
|
_, err := Default()
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no backends registered")
|
|
}
|
|
|
|
func TestDefault_Ugly_SkipsUnavailablePreferred(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
// metal registered but not available; custom_gpu is available.
|
|
Register(&stubBackend{name: "metal", available: false})
|
|
Register(&stubBackend{name: "custom_gpu", available: true})
|
|
|
|
b, err := Default()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "custom_gpu", b.Name(), "should skip unavailable preferred and use any available")
|
|
}
|
|
|
|
// --- LoadModel ---
|
|
|
|
func TestLoadModel_Good_DefaultBackend(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "metal", available: true})
|
|
|
|
m, err := LoadModel("/path/to/model")
|
|
require.NoError(t, err)
|
|
require.NotNil(t, m)
|
|
|
|
sm := m.(*stubTextModel)
|
|
assert.Equal(t, "metal", sm.backend)
|
|
assert.Equal(t, "/path/to/model", sm.path)
|
|
require.NoError(t, m.Close())
|
|
}
|
|
|
|
func TestLoadModel_Good_ExplicitBackend(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "metal", available: true})
|
|
Register(&stubBackend{name: "rocm", available: true})
|
|
|
|
m, err := LoadModel("/path/to/model", WithBackend("rocm"))
|
|
require.NoError(t, err)
|
|
require.NotNil(t, m)
|
|
|
|
sm := m.(*stubTextModel)
|
|
assert.Equal(t, "rocm", sm.backend, "should use explicitly requested backend")
|
|
require.NoError(t, m.Close())
|
|
}
|
|
|
|
func TestLoadModel_Bad_NoBackends(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
_, err := LoadModel("/path/to/model")
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "no backends registered")
|
|
}
|
|
|
|
func TestLoadModel_Bad_ExplicitBackendNotRegistered(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "metal", available: true})
|
|
|
|
_, err := LoadModel("/path/to/model", WithBackend("rocm"))
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "backend \"rocm\" not registered")
|
|
}
|
|
|
|
func TestLoadModel_Bad_ExplicitBackendNotAvailable(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{name: "rocm", available: false})
|
|
|
|
_, err := LoadModel("/path/to/model", WithBackend("rocm"))
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "backend \"rocm\" not available")
|
|
}
|
|
|
|
func TestLoadModel_Bad_BackendLoadError(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{
|
|
name: "broken",
|
|
available: true,
|
|
loadErr: fmt.Errorf("GPU out of memory"),
|
|
})
|
|
|
|
_, err := LoadModel("/path/to/model", WithBackend("broken"))
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "GPU out of memory")
|
|
}
|
|
|
|
func TestLoadModel_Good_PassesOptionsThrough(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
// Use a backend that captures the load path.
|
|
Register(&stubBackend{name: "metal", available: true})
|
|
|
|
m, err := LoadModel("/models/gemma3-1b",
|
|
WithContextLen(4096),
|
|
WithGPULayers(24),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
sm := m.(*stubTextModel)
|
|
assert.Equal(t, "/models/gemma3-1b", sm.path)
|
|
require.NoError(t, m.Close())
|
|
}
|
|
|
|
func TestLoadModel_Ugly_DefaultBackendLoadError(t *testing.T) {
|
|
resetBackends(t)
|
|
|
|
Register(&stubBackend{
|
|
name: "metal",
|
|
available: true,
|
|
loadErr: fmt.Errorf("model not found"),
|
|
})
|
|
|
|
_, err := LoadModel("/nonexistent/model")
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "model not found")
|
|
}
|
|
|
|
// --- Type assertions (compile-time checks) ---
|
|
|
|
func TestTypes_Good_InterfaceCompliance(t *testing.T) {
|
|
// Verify stubBackend implements Backend.
|
|
var _ Backend = (*stubBackend)(nil)
|
|
// Verify stubTextModel implements TextModel.
|
|
var _ TextModel = (*stubTextModel)(nil)
|
|
}
|
|
|
|
// --- Struct types ---
|
|
|
|
func TestToken_Good(t *testing.T) {
|
|
tok := Token{ID: 42, Text: "hello"}
|
|
assert.Equal(t, int32(42), tok.ID)
|
|
assert.Equal(t, "hello", tok.Text)
|
|
}
|
|
|
|
func TestMessage_Good(t *testing.T) {
|
|
msg := Message{Role: "user", Content: "Hi there"}
|
|
assert.Equal(t, "user", msg.Role)
|
|
assert.Equal(t, "Hi there", msg.Content)
|
|
}
|
|
|
|
func TestClassifyResult_Good(t *testing.T) {
|
|
cr := ClassifyResult{
|
|
Token: Token{ID: 1, Text: "yes"},
|
|
Logits: []float32{0.1, 0.9},
|
|
}
|
|
assert.Equal(t, int32(1), cr.Token.ID)
|
|
assert.Equal(t, "yes", cr.Token.Text)
|
|
assert.Len(t, cr.Logits, 2)
|
|
}
|
|
|
|
func TestBatchResult_Good(t *testing.T) {
|
|
br := BatchResult{
|
|
Tokens: []Token{{ID: 1, Text: "a"}, {ID: 2, Text: "b"}},
|
|
Err: nil,
|
|
}
|
|
assert.Len(t, br.Tokens, 2)
|
|
assert.NoError(t, br.Err)
|
|
}
|
|
|
|
func TestBatchResult_Bad(t *testing.T) {
|
|
br := BatchResult{
|
|
Tokens: nil,
|
|
Err: fmt.Errorf("OOM"),
|
|
}
|
|
assert.Nil(t, br.Tokens)
|
|
assert.Error(t, br.Err)
|
|
}
|
|
|
|
func TestModelInfo_Good(t *testing.T) {
|
|
info := ModelInfo{
|
|
Architecture: "gemma3",
|
|
VocabSize: 256128,
|
|
NumLayers: 26,
|
|
HiddenSize: 1152,
|
|
QuantBits: 4,
|
|
QuantGroup: 64,
|
|
}
|
|
assert.Equal(t, "gemma3", info.Architecture)
|
|
assert.Equal(t, 256128, info.VocabSize)
|
|
assert.Equal(t, 26, info.NumLayers)
|
|
assert.Equal(t, 1152, info.HiddenSize)
|
|
assert.Equal(t, 4, info.QuantBits)
|
|
assert.Equal(t, 64, info.QuantGroup)
|
|
}
|
|
|
|
func TestGenerateMetrics_Good(t *testing.T) {
|
|
m := GenerateMetrics{
|
|
PromptTokens: 100,
|
|
GeneratedTokens: 50,
|
|
PrefillTokensPerSec: 1000.0,
|
|
DecodeTokensPerSec: 25.0,
|
|
PeakMemoryBytes: 1 << 30, // 1 GiB
|
|
ActiveMemoryBytes: 512 << 20, // 512 MiB
|
|
}
|
|
assert.Equal(t, 100, m.PromptTokens)
|
|
assert.Equal(t, 50, m.GeneratedTokens)
|
|
assert.InDelta(t, 1000.0, m.PrefillTokensPerSec, 0.01)
|
|
assert.InDelta(t, 25.0, m.DecodeTokensPerSec, 0.01)
|
|
assert.Equal(t, uint64(1<<30), m.PeakMemoryBytes)
|
|
assert.Equal(t, uint64(512<<20), m.ActiveMemoryBytes)
|
|
}
|