go-inference/inference_test.go
Claude e0ec07e667
test(inference): complete Phase 1 foundation tests
Add comprehensive tests for all three Phase 1 items:

- Option application: DefaultGenerateConfig idempotency, field isolation
  (WithMaxTokens leaves others at defaults), bad-input acceptance
  (negative temperature, negative TopK), empty variadic StopTokens,
  WithLogits default-is-false, partial-options preserve defaults,
  last-wins overrides for all GenerateOption and LoadOption types.

- Backend registry: concurrent read/write safety (70 goroutines with
  -race), overwrite-keeps-count, capturingBackend verifies LoadModel
  forwards all options to both explicit and default backends, Get
  after overwrite returns latest, List returns independent slices.

- Default() platform preference: registration order is irrelevant
  (metal wins regardless), all-preferred-unavailable falls back to
  custom, multiple custom backends finds the available one, empty
  path forwarding.

85 tests, 100% statement coverage, -race clean.

Co-Authored-By: Charon <developers@lethean.io>
2026-02-20 11:45:59 +00:00

696 lines
18 KiB
Go

package inference
import (
"context"
"fmt"
"iter"
"sort"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- test helpers ---

// resetBackends replaces the global backend registry with an empty map so the
// calling test starts from a clean slate. The registry that was in place before
// the call is put back via t.Cleanup when the test (or subtest) finishes, so
// backends registered by one test do not leak into whatever runs next.
//
// Calling it more than once in the same test is fine: cleanups run in
// last-in-first-out order, so the registry that existed before the first call
// is the one that ends up restored.
func resetBackends(t *testing.T) {
	t.Helper()

	backendsMu.Lock()
	previous := backends
	backends = map[string]Backend{}
	backendsMu.Unlock()

	t.Cleanup(func() {
		backendsMu.Lock()
		defer backendsMu.Unlock()
		backends = previous
	})
}
// stubBackend is the simplest Backend the tests need: every answer it gives
// comes straight from its fields.
type stubBackend struct {
	name      string // returned by Name
	available bool   // returned by Available
	loadErr   error  // when non-nil, LoadModel fails with this error
}

// Name returns the configured backend name.
func (b *stubBackend) Name() string {
	return b.name
}

// Available returns the configured availability flag.
func (b *stubBackend) Available() bool {
	return b.available
}

// LoadModel returns loadErr if one is configured; otherwise it returns a
// stubTextModel tagged with this backend's name and the requested path.
// The load options are ignored.
func (b *stubBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) {
	if err := b.loadErr; err != nil {
		return nil, err
	}
	model := &stubTextModel{
		backend: b.name,
		path:    path,
	}
	return model, nil
}
// capturingBackend is a Backend that remembers the LoadOption values handed to
// its most recent LoadModel call so tests can inspect them afterwards.
type capturingBackend struct {
	name         string       // returned by Name
	available    bool         // returned by Available
	capturedOpts []LoadOption // options seen by the last LoadModel call
}

// Name returns the configured backend name.
func (b *capturingBackend) Name() string {
	return b.name
}

// Available returns the configured availability flag.
func (b *capturingBackend) Available() bool {
	return b.available
}

// LoadModel stores opts in capturedOpts (overwriting any earlier capture) and
// always succeeds, returning a stubTextModel tagged with this backend's name
// and the requested path.
func (b *capturingBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) {
	b.capturedOpts = opts
	model := &stubTextModel{
		backend: b.name,
		path:    path,
	}
	return model, nil
}
// stubTextModel is a do-nothing TextModel. It only records which backend
// produced it and which path was requested, so tests can check LoadModel
// routing.
type stubTextModel struct {
	backend string // name of the backend that created this model
	path    string // model path passed to LoadModel
}

// Generate returns a sequence that yields no tokens.
func (s *stubTextModel) Generate(_ context.Context, _ string, _ ...GenerateOption) iter.Seq[Token] {
	return func(func(Token) bool) {}
}

// Chat returns a sequence that yields no tokens.
func (s *stubTextModel) Chat(_ context.Context, _ []Message, _ ...GenerateOption) iter.Seq[Token] {
	return func(func(Token) bool) {}
}

// Classify returns no results and no error.
func (s *stubTextModel) Classify(_ context.Context, _ []string, _ ...GenerateOption) ([]ClassifyResult, error) {
	return nil, nil
}

// BatchGenerate returns no results and no error.
func (s *stubTextModel) BatchGenerate(_ context.Context, _ []string, _ ...GenerateOption) ([]BatchResult, error) {
	return nil, nil
}

// ModelType always reports "stub".
func (s *stubTextModel) ModelType() string {
	return "stub"
}

// Info returns a zero-valued ModelInfo.
func (s *stubTextModel) Info() ModelInfo {
	return ModelInfo{}
}

// Metrics returns a zero-valued GenerateMetrics.
func (s *stubTextModel) Metrics() GenerateMetrics {
	return GenerateMetrics{}
}

// Err always returns nil.
func (s *stubTextModel) Err() error {
	return nil
}

// Close always returns nil.
func (s *stubTextModel) Close() error {
	return nil
}
// --- Register ---

// TestRegister_Good checks that a backend handed to Register can be looked up
// again under its own name.
func TestRegister_Good(t *testing.T) {
	resetBackends(t)

	const name = "test_backend"
	Register(&stubBackend{name: name, available: true})

	found, ok := Get(name)
	require.True(t, ok)
	assert.Equal(t, name, found.Name())
}
// TestRegister_Good_Multiple registers three distinct backends and expects
// List to report all three names.
func TestRegister_Good_Multiple(t *testing.T) {
	resetBackends(t)

	for _, name := range []string{"alpha", "beta", "gamma"} {
		Register(&stubBackend{name: name, available: true})
	}

	got := List()
	sort.Strings(got) // compare in sorted order; the test does not rely on List's ordering
	assert.Equal(t, []string{"alpha", "beta", "gamma"}, got)
}
// TestRegister_Ugly_Overwrites registers two backends under the same name and
// expects the later one to be the one Get returns.
func TestRegister_Ugly_Overwrites(t *testing.T) {
	resetBackends(t)

	// Same name, different availability: the flag tells us which one survived.
	Register(&stubBackend{name: "dup", available: false})
	Register(&stubBackend{name: "dup", available: true})

	winner, found := Get("dup")
	require.True(t, found)
	assert.True(t, winner.Available(), "second registration should overwrite the first")
}
// --- Get ---

// TestGet_Good checks that Get finds a registered backend by name.
func TestGet_Good(t *testing.T) {
	resetBackends(t)

	const name = "exists"
	Register(&stubBackend{name: name, available: true})

	found, ok := Get(name)
	require.True(t, ok)
	assert.Equal(t, name, found.Name())
}
// TestGet_Bad checks that looking up an unknown name on an empty registry
// reports "not found" and yields a nil backend.
func TestGet_Bad(t *testing.T) {
	resetBackends(t)

	missing, found := Get("nonexistent")

	assert.False(t, found)
	assert.Nil(t, missing)
}
// --- List ---

// TestList_Good_Empty checks that List reports nothing for an empty registry.
func TestList_Good_Empty(t *testing.T) {
	resetBackends(t)

	assert.Empty(t, List())
}
// TestList_Good_Populated registers two backends and expects List to report
// exactly those two names.
func TestList_Good_Populated(t *testing.T) {
	resetBackends(t)

	for _, name := range []string{"a", "b"} {
		Register(&stubBackend{name: name, available: true})
	}

	listed := List()
	sort.Strings(listed) // compare in sorted order; the test does not rely on List's ordering
	assert.Equal(t, []string{"a", "b"}, listed)
}
// --- Default ---

// TestDefault_Good_Metal registers metal, rocm and llama_cpp, all available,
// and expects Default to pick metal.
func TestDefault_Good_Metal(t *testing.T) {
	resetBackends(t)

	for _, name := range []string{"metal", "rocm", "llama_cpp"} {
		Register(&stubBackend{name: name, available: true})
	}

	chosen, err := Default()
	require.NoError(t, err)
	assert.Equal(t, "metal", chosen.Name(), "metal should be preferred when available")
}
// TestDefault_Good_Rocm registers rocm and llama_cpp (no metal) and expects
// Default to pick rocm.
func TestDefault_Good_Rocm(t *testing.T) {
	resetBackends(t)

	for _, name := range []string{"rocm", "llama_cpp"} {
		Register(&stubBackend{name: name, available: true})
	}

	chosen, err := Default()
	require.NoError(t, err)
	assert.Equal(t, "rocm", chosen.Name(), "rocm should be preferred when metal is absent")
}
// TestDefault_Good_LlamaCpp registers only llama_cpp and expects Default to
// return it.
func TestDefault_Good_LlamaCpp(t *testing.T) {
	resetBackends(t)

	only := &stubBackend{name: "llama_cpp", available: true}
	Register(only)

	chosen, err := Default()
	require.NoError(t, err)
	assert.Equal(t, "llama_cpp", chosen.Name(), "llama_cpp should be used as fallback")
}
// TestDefault_Good_PriorityOrder is a table-driven check of the preference
// order metal > rocm > llama_cpp: in each scenario the expected winner is the
// highest-priority backend that reports itself available.
func TestDefault_Good_PriorityOrder(t *testing.T) {
	type scenario struct {
		label      string
		registered []stubBackend
		expected   string
	}

	scenarios := []scenario{
		{
			label: "all_available_prefers_metal",
			registered: []stubBackend{
				{name: "llama_cpp", available: true},
				{name: "rocm", available: true},
				{name: "metal", available: true},
			},
			expected: "metal",
		},
		{
			label: "metal_unavailable_prefers_rocm",
			registered: []stubBackend{
				{name: "metal", available: false},
				{name: "rocm", available: true},
				{name: "llama_cpp", available: true},
			},
			expected: "rocm",
		},
		{
			label: "metal_rocm_unavailable_prefers_llama_cpp",
			registered: []stubBackend{
				{name: "metal", available: false},
				{name: "rocm", available: false},
				{name: "llama_cpp", available: true},
			},
			expected: "llama_cpp",
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.label, func(t *testing.T) {
			resetBackends(t)

			// Register a pointer to each table entry; stubBackend's methods
			// have pointer receivers.
			for idx := range sc.registered {
				Register(&sc.registered[idx])
			}

			chosen, err := Default()
			require.NoError(t, err)
			assert.Equal(t, sc.expected, chosen.Name())
		})
	}
}
// TestDefault_Good_FallbackToAny registers a single available backend whose
// name is not one of metal/rocm/llama_cpp and expects Default to return it.
func TestDefault_Good_FallbackToAny(t *testing.T) {
	resetBackends(t)

	custom := &stubBackend{name: "custom_gpu", available: true}
	Register(custom)

	chosen, err := Default()
	require.NoError(t, err)
	assert.Equal(t, "custom_gpu", chosen.Name(), "should fall back to any available backend")
}
// TestDefault_Bad_NoBackends expects Default to fail with a
// "no backends registered" error when the registry is empty.
func TestDefault_Bad_NoBackends(t *testing.T) {
	resetBackends(t)

	chosen, err := Default()
	_ = chosen // only the error matters here

	require.Error(t, err)
	assert.Contains(t, err.Error(), "no backends registered")
}
// TestDefault_Bad_NoneAvailable registers backends that all report themselves
// unavailable and expects Default to fail. The asserted message text is the
// same "no backends registered" substring checked for an empty registry.
func TestDefault_Bad_NoneAvailable(t *testing.T) {
	resetBackends(t)

	for _, name := range []string{"metal", "rocm"} {
		Register(&stubBackend{name: name, available: false})
	}

	chosen, err := Default()
	_ = chosen // only the error matters here

	require.Error(t, err)
	assert.Contains(t, err.Error(), "no backends registered")
}
// TestDefault_Ugly_SkipsUnavailablePreferred registers an unavailable metal
// backend next to an available custom one and expects Default to return the
// custom backend.
func TestDefault_Ugly_SkipsUnavailablePreferred(t *testing.T) {
	resetBackends(t)

	preferredButDown := &stubBackend{name: "metal", available: false}
	customAndUp := &stubBackend{name: "custom_gpu", available: true}
	Register(preferredButDown)
	Register(customAndUp)

	chosen, err := Default()
	require.NoError(t, err)
	assert.Equal(t, "custom_gpu", chosen.Name(), "should skip unavailable preferred and use any available")
}
// --- LoadModel ---

// TestLoadModel_Good_DefaultBackend checks that LoadModel, called without an
// explicit backend, routes to the registered metal backend and forwards the
// model path unchanged.
func TestLoadModel_Good_DefaultBackend(t *testing.T) {
	resetBackends(t)
	Register(&stubBackend{name: "metal", available: true})

	m, err := LoadModel("/path/to/model")
	require.NoError(t, err)
	require.NotNil(t, m)

	// Checked assertion: an unexpected concrete type fails this test instead
	// of panicking and taking the rest of the test binary down with it.
	sm, ok := m.(*stubTextModel)
	require.True(t, ok, "expected *stubTextModel, got %T", m)

	assert.Equal(t, "metal", sm.backend)
	assert.Equal(t, "/path/to/model", sm.path)
	require.NoError(t, m.Close())
}
// TestLoadModel_Good_ExplicitBackend checks that WithBackend("rocm") makes
// LoadModel use rocm even though metal is also registered and available.
func TestLoadModel_Good_ExplicitBackend(t *testing.T) {
	resetBackends(t)
	Register(&stubBackend{name: "metal", available: true})
	Register(&stubBackend{name: "rocm", available: true})

	m, err := LoadModel("/path/to/model", WithBackend("rocm"))
	require.NoError(t, err)
	require.NotNil(t, m)

	// Checked assertion: an unexpected concrete type fails this test instead
	// of panicking and taking the rest of the test binary down with it.
	sm, ok := m.(*stubTextModel)
	require.True(t, ok, "expected *stubTextModel, got %T", m)

	assert.Equal(t, "rocm", sm.backend, "should use explicitly requested backend")
	require.NoError(t, m.Close())
}
// TestLoadModel_Bad_NoBackends expects LoadModel to fail with a
// "no backends registered" error when the registry is empty.
func TestLoadModel_Bad_NoBackends(t *testing.T) {
	resetBackends(t)

	model, err := LoadModel("/path/to/model")
	_ = model // only the error matters here

	require.Error(t, err)
	assert.Contains(t, err.Error(), "no backends registered")
}
// TestLoadModel_Bad_ExplicitBackendNotRegistered asks for the rocm backend
// while only metal is registered and expects a "not registered" error naming
// rocm.
func TestLoadModel_Bad_ExplicitBackendNotRegistered(t *testing.T) {
	resetBackends(t)

	onlyMetal := &stubBackend{name: "metal", available: true}
	Register(onlyMetal)

	model, err := LoadModel("/path/to/model", WithBackend("rocm"))
	_ = model // only the error matters here

	require.Error(t, err)
	assert.Contains(t, err.Error(), "backend \"rocm\" not registered")
}
// TestLoadModel_Bad_ExplicitBackendNotAvailable asks for a backend that is
// registered but reports itself unavailable and expects a "not available"
// error naming it.
func TestLoadModel_Bad_ExplicitBackendNotAvailable(t *testing.T) {
	resetBackends(t)

	downRocm := &stubBackend{name: "rocm", available: false}
	Register(downRocm)

	model, err := LoadModel("/path/to/model", WithBackend("rocm"))
	_ = model // only the error matters here

	require.Error(t, err)
	assert.Contains(t, err.Error(), "backend \"rocm\" not available")
}
// TestLoadModel_Bad_BackendLoadError checks that an error returned by an
// explicitly selected backend's LoadModel is visible in the error LoadModel
// returns.
func TestLoadModel_Bad_BackendLoadError(t *testing.T) {
	resetBackends(t)

	broken := &stubBackend{name: "broken", available: true}
	broken.loadErr = fmt.Errorf("GPU out of memory")
	Register(broken)

	model, err := LoadModel("/path/to/model", WithBackend("broken"))
	_ = model // only the error matters here

	require.Error(t, err)
	assert.Contains(t, err.Error(), "GPU out of memory")
}
// TestLoadModel_Good_PassesOptionsThrough checks that supplying load options
// (context length, GPU layers) does not disturb default-backend routing: the
// call succeeds and the model path reaches the backend unchanged.
//
// stubBackend ignores the options it is given, so this test does not inspect
// them; the option values themselves are verified by the *ForwardsOptions
// tests, which use capturingBackend.
func TestLoadModel_Good_PassesOptionsThrough(t *testing.T) {
	resetBackends(t)
	Register(&stubBackend{name: "metal", available: true})

	m, err := LoadModel("/models/gemma3-1b",
		WithContextLen(4096),
		WithGPULayers(24),
	)
	require.NoError(t, err)
	require.NotNil(t, m)

	// Checked assertion: an unexpected concrete type fails this test instead
	// of panicking and taking the rest of the test binary down with it.
	sm, ok := m.(*stubTextModel)
	require.True(t, ok, "expected *stubTextModel, got %T", m)

	assert.Equal(t, "/models/gemma3-1b", sm.path)
	require.NoError(t, m.Close())
}
// TestLoadModel_Ugly_DefaultBackendLoadError checks that an error returned by
// the default backend's LoadModel is visible in the error LoadModel returns.
func TestLoadModel_Ugly_DefaultBackendLoadError(t *testing.T) {
	resetBackends(t)

	failing := &stubBackend{name: "metal", available: true}
	failing.loadErr = fmt.Errorf("model not found")
	Register(failing)

	model, err := LoadModel("/nonexistent/model")
	_ = model // only the error matters here

	require.Error(t, err)
	assert.Contains(t, err.Error(), "model not found")
}
// --- Type assertions (compile-time checks) ---

// TestTypes_Good_InterfaceCompliance has no runtime assertions: it passes as
// long as the file compiles, which requires the test doubles to satisfy the
// package interfaces.
func TestTypes_Good_InterfaceCompliance(t *testing.T) {
	var (
		_ Backend   = (*stubBackend)(nil)   // stubBackend implements Backend
		_ TextModel = (*stubTextModel)(nil) // stubTextModel implements TextModel
	)
}
// --- Struct types ---

// TestToken_Good checks that Token's ID and Text fields hold the values they
// were given.
func TestToken_Good(t *testing.T) {
	var token Token
	token.ID = 42
	token.Text = "hello"

	assert.Equal(t, int32(42), token.ID)
	assert.Equal(t, "hello", token.Text)
}
// TestMessage_Good checks that Message's Role and Content fields hold the
// values they were given.
func TestMessage_Good(t *testing.T) {
	var message Message
	message.Role = "user"
	message.Content = "Hi there"

	assert.Equal(t, "user", message.Role)
	assert.Equal(t, "Hi there", message.Content)
}
func TestClassifyResult_Good(t *testing.T) {
cr := ClassifyResult{
Token: Token{ID: 1, Text: "yes"},
Logits: []float32{0.1, 0.9},
}
assert.Equal(t, int32(1), cr.Token.ID)
assert.Equal(t, "yes", cr.Token.Text)
assert.Len(t, cr.Logits, 2)
}
func TestBatchResult_Good(t *testing.T) {
br := BatchResult{
Tokens: []Token{{ID: 1, Text: "a"}, {ID: 2, Text: "b"}},
Err: nil,
}
assert.Len(t, br.Tokens, 2)
assert.NoError(t, br.Err)
}
// TestBatchResult_Bad checks the failure shape of BatchResult: no tokens and
// a non-nil Err.
func TestBatchResult_Bad(t *testing.T) {
	// Tokens is left at its zero value, nil.
	result := BatchResult{Err: fmt.Errorf("OOM")}

	assert.Nil(t, result.Tokens)
	assert.Error(t, result.Err)
}
// TestModelInfo_Good checks that every ModelInfo field holds the value it was
// given.
func TestModelInfo_Good(t *testing.T) {
	var mi ModelInfo
	mi.Architecture = "gemma3"
	mi.VocabSize = 256128
	mi.NumLayers = 26
	mi.HiddenSize = 1152
	mi.QuantBits = 4
	mi.QuantGroup = 64

	assert.Equal(t, "gemma3", mi.Architecture)
	assert.Equal(t, 256128, mi.VocabSize)
	assert.Equal(t, 26, mi.NumLayers)
	assert.Equal(t, 1152, mi.HiddenSize)
	assert.Equal(t, 4, mi.QuantBits)
	assert.Equal(t, 64, mi.QuantGroup)
}
// TestGenerateMetrics_Good checks that every GenerateMetrics field holds the
// value it was given. The throughput fields are compared with a small
// tolerance.
func TestGenerateMetrics_Good(t *testing.T) {
	var metrics GenerateMetrics
	metrics.PromptTokens = 100
	metrics.GeneratedTokens = 50
	metrics.PrefillTokensPerSec = 1000.0
	metrics.DecodeTokensPerSec = 25.0
	metrics.PeakMemoryBytes = 1 << 30     // 1 GiB
	metrics.ActiveMemoryBytes = 512 << 20 // 512 MiB

	assert.Equal(t, 100, metrics.PromptTokens)
	assert.Equal(t, 50, metrics.GeneratedTokens)
	assert.InDelta(t, 1000.0, metrics.PrefillTokensPerSec, 0.01)
	assert.InDelta(t, 25.0, metrics.DecodeTokensPerSec, 0.01)
	assert.Equal(t, uint64(1<<30), metrics.PeakMemoryBytes)
	assert.Equal(t, uint64(512<<20), metrics.ActiveMemoryBytes)
}
// --- Concurrent registry access ---

// TestRegistry_Good_ConcurrentAccess hammers the registry from 70 goroutines
// at once: 20 calling Register with distinct names, 20 calling List, 20
// calling Get and 10 calling Default. It is meant to be run with -race, which
// reports unsynchronised access; the only explicit assertion is that all 20
// registrations are present once every goroutine has finished.
func TestRegistry_Good_ConcurrentAccess(t *testing.T) {
	resetBackends(t)

	const writers = 20

	var wg sync.WaitGroup

	// launch starts count goroutines, each running work with its own index,
	// and tracks them in wg.
	launch := func(count int, work func(id int)) {
		for i := range count {
			wg.Add(1)
			go func(id int) {
				defer wg.Done()
				work(id)
			}(i)
		}
	}

	// Writers: one uniquely named backend each.
	launch(writers, func(id int) {
		Register(&stubBackend{
			name:      fmt.Sprintf("backend_%d", id),
			available: true,
		})
	})

	// Readers, started while the writers may still be running. Their results
	// are discarded; they exist only to exercise the read paths.
	launch(20, func(int) { _ = List() })
	launch(20, func(int) { _, _ = Get("backend_0") })
	launch(10, func(int) { _, _ = Default() })

	wg.Wait()

	assert.Len(t, List(), writers, "all 20 backends should be registered after concurrent writes")
}
// --- Register overwrite count ---

// TestRegister_Ugly_OverwriteKeepsCount re-registers an existing name and
// expects List to still report two entries, not three.
func TestRegister_Ugly_OverwriteKeepsCount(t *testing.T) {
	resetBackends(t)

	registrations := []*stubBackend{
		{name: "alpha", available: true},
		{name: "beta", available: true},
		{name: "alpha", available: false}, // same name as the first entry
	}
	for _, backend := range registrations {
		Register(backend)
	}

	assert.Len(t, List(), 2, "overwriting a backend should not increase the count")
}
// --- Default with all preferred unavailable and custom available ---

// TestDefault_Ugly_AllPreferredUnavailableCustomAvailable registers metal,
// rocm and llama_cpp as unavailable plus one available custom backend, and
// expects Default to return the custom one.
func TestDefault_Ugly_AllPreferredUnavailableCustomAvailable(t *testing.T) {
	resetBackends(t)

	for _, name := range []string{"metal", "rocm", "llama_cpp"} {
		Register(&stubBackend{name: name, available: false})
	}
	Register(&stubBackend{name: "custom_vulkan", available: true})

	chosen, err := Default()
	require.NoError(t, err)
	assert.Equal(t, "custom_vulkan", chosen.Name(),
		"should fall back to custom backend when all preferred backends are unavailable")
}
// --- Default with multiple custom backends ---

// TestDefault_Ugly_MultipleCustomBackends registers two backends outside the
// metal/rocm/llama_cpp set, only one of which is available, and expects
// Default to return the available one.
func TestDefault_Ugly_MultipleCustomBackends(t *testing.T) {
	resetBackends(t)

	down := &stubBackend{name: "custom_a", available: false}
	up := &stubBackend{name: "custom_b", available: true}
	Register(down)
	Register(up)

	chosen, err := Default()
	require.NoError(t, err)
	assert.Equal(t, "custom_b", chosen.Name(),
		"should find the available custom backend in the fallback loop")
}
// --- LoadModel option forwarding ---

// TestLoadModel_Good_ExplicitBackendForwardsOptions selects a capturing
// backend via WithBackend and checks that it receives every option passed to
// LoadModel, and that applying the captured options yields the expected
// backend name, context length and GPU layer count.
func TestLoadModel_Good_ExplicitBackendForwardsOptions(t *testing.T) {
	resetBackends(t)

	recorder := &capturingBackend{name: "cap", available: true}
	Register(recorder)

	requested := []LoadOption{
		WithBackend("cap"),
		WithContextLen(4096),
		WithGPULayers(16),
	}

	model, err := LoadModel("/path/to/model", requested...)
	require.NoError(t, err)
	require.NotNil(t, model)

	// Nothing may be dropped on the way to the backend.
	assert.Len(t, recorder.capturedOpts, len(requested),
		"all LoadOptions should be forwarded to the backend")

	// Applying what the backend saw must give the requested configuration.
	applied := ApplyLoadOpts(recorder.capturedOpts)
	assert.Equal(t, "cap", applied.Backend)
	assert.Equal(t, 4096, applied.ContextLen)
	assert.Equal(t, 16, applied.GPULayers)

	require.NoError(t, model.Close())
}
// TestLoadModel_Good_DefaultBackendForwardsOptions registers a capturing
// backend named metal, calls LoadModel without WithBackend, and checks that
// the backend receives every option and that applying the captured options
// yields the expected context length, GPU layer count and parallel slots.
func TestLoadModel_Good_DefaultBackendForwardsOptions(t *testing.T) {
	resetBackends(t)

	recorder := &capturingBackend{name: "metal", available: true}
	Register(recorder)

	requested := []LoadOption{
		WithContextLen(8192),
		WithGPULayers(-1),
		WithParallelSlots(2),
	}

	model, err := LoadModel("/path/to/model", requested...)
	require.NoError(t, err)
	require.NotNil(t, model)

	// Nothing may be dropped on the way to the backend.
	assert.Len(t, recorder.capturedOpts, len(requested),
		"all LoadOptions should be forwarded to the default backend")

	applied := ApplyLoadOpts(recorder.capturedOpts)
	assert.Equal(t, 8192, applied.ContextLen)
	assert.Equal(t, -1, applied.GPULayers)
	assert.Equal(t, 2, applied.ParallelSlots)

	require.NoError(t, model.Close())
}
// --- Default preference order does not depend on registration order ---

// TestDefault_Good_RegistrationOrderIrrelevant registers the same three
// available backends in two different orders and expects Default to pick
// metal both times.
func TestDefault_Good_RegistrationOrderIrrelevant(t *testing.T) {
	orderings := [][]string{
		{"llama_cpp", "rocm", "metal"}, // reverse priority order
		{"rocm", "metal", "llama_cpp"}, // yet another order
	}

	for _, ordering := range orderings {
		resetBackends(t)
		for _, name := range ordering {
			Register(&stubBackend{name: name, available: true})
		}

		chosen, err := Default()
		require.NoError(t, err)
		assert.Equal(t, "metal", chosen.Name(),
			"metal should win regardless of registration order")
	}
}
// --- LoadModel with empty path ---

// TestLoadModel_Ugly_EmptyPath checks that LoadModel does not reject an empty
// path itself: the call succeeds and the backend receives "" unchanged, so
// what to do with it is the backend's decision.
func TestLoadModel_Ugly_EmptyPath(t *testing.T) {
	resetBackends(t)
	Register(&stubBackend{name: "metal", available: true})

	m, err := LoadModel("")
	require.NoError(t, err)
	require.NotNil(t, m)

	// Checked assertion: an unexpected concrete type fails this test instead
	// of panicking and taking the rest of the test binary down with it.
	sm, ok := m.(*stubTextModel)
	require.True(t, ok, "expected *stubTextModel, got %T", m)

	assert.Equal(t, "", sm.path, "empty path should be forwarded to the backend as-is")
	require.NoError(t, m.Close())
}
// --- Get after register and overwrite ---

// TestGet_Good_AfterOverwrite registers the same name twice with different
// availability and expects Get to return the later registration.
func TestGet_Good_AfterOverwrite(t *testing.T) {
	resetBackends(t)

	// The availability flag tells the two registrations apart.
	for _, available := range []bool{false, true} {
		Register(&stubBackend{name: "gpu", available: available})
	}

	latest, found := Get("gpu")
	require.True(t, found)
	assert.True(t, latest.Available(), "Get should return the most recently registered backend")
}
// --- List returns new slice each call ---

// TestList_Good_IndependentSlices checks that each List call hands back its
// own slice: writing into one result must not show up in a result obtained
// earlier or in one obtained afterwards.
func TestList_Good_IndependentSlices(t *testing.T) {
	resetBackends(t)
	Register(&stubBackend{name: "alpha", available: true})

	a := List()
	b := List()

	// Guard the index expressions below; without this an unexpectedly empty
	// result would panic rather than fail the test.
	require.Len(t, a, 1)
	require.Len(t, b, 1)
	assert.Equal(t, a, b, "both calls should return the same names")

	// Mutating one slice should not affect the other.
	a[0] = "mutated"
	assert.Equal(t, "alpha", b[0], "List should return independent slices")

	c := List()
	require.Len(t, c, 1)
	assert.NotEqual(t, a[0], c[0], "List should return independent slices")
}