test(inference): add comprehensive tests for all exported API
Cover options (GenerateConfig defaults, all With* options, ApplyGenerateOpts/ ApplyLoadOpts), backend registry (Register, Get, List, Default priority order metal > rocm > llama_cpp), LoadModel routing (explicit/auto backend, error paths), and Discover (model directory scanning, quantisation, edge cases). 69 tests, 100% statement coverage, race-clean. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
15ee86ec62
commit
d76448d4a9
5 changed files with 1074 additions and 0 deletions
250
discover_test.go
Normal file
250
discover_test.go
Normal file
|
|
@ -0,0 +1,250 @@
|
||||||
|
package inference
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- test helpers for discover ---
|
||||||
|
|
||||||
|
// createModelDir creates a fake model directory with config.json and n safetensors files.
|
||||||
|
func createModelDir(t *testing.T, dir string, config map[string]any, numSafetensors int) {
|
||||||
|
t.Helper()
|
||||||
|
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||||
|
|
||||||
|
if config != nil {
|
||||||
|
data, err := json.Marshal(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), data, 0o644))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range numSafetensors {
|
||||||
|
fname := filepath.Join(dir, fmt.Sprintf("model-%05d-of-%05d.safetensors", i+1, numSafetensors))
|
||||||
|
require.NoError(t, os.WriteFile(fname, []byte("fake"), 0o644))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Discover ---
|
||||||
|
|
||||||
|
func TestDiscover_Good_SingleModel(t *testing.T) {
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, filepath.Join(base, "gemma3-1b"), map[string]any{
|
||||||
|
"model_type": "gemma3",
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
|
||||||
|
m := models[0]
|
||||||
|
assert.Equal(t, "gemma3", m.ModelType)
|
||||||
|
assert.Equal(t, 1, m.NumFiles)
|
||||||
|
assert.Equal(t, 0, m.QuantBits)
|
||||||
|
assert.Equal(t, 0, m.QuantGroup)
|
||||||
|
assert.True(t, filepath.IsAbs(m.Path), "path should be absolute")
|
||||||
|
assert.Contains(t, m.Path, "gemma3-1b")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Good_MultipleModels(t *testing.T) {
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, filepath.Join(base, "gemma3-1b"), map[string]any{
|
||||||
|
"model_type": "gemma3",
|
||||||
|
}, 1)
|
||||||
|
createModelDir(t, filepath.Join(base, "qwen3-4b"), map[string]any{
|
||||||
|
"model_type": "qwen3",
|
||||||
|
}, 4)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 2)
|
||||||
|
|
||||||
|
// Collect model types for assertion (order may vary due to ReadDir).
|
||||||
|
types := make([]string, len(models))
|
||||||
|
for i, m := range models {
|
||||||
|
types[i] = m.ModelType
|
||||||
|
}
|
||||||
|
assert.ElementsMatch(t, []string{"gemma3", "qwen3"}, types)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Good_Quantised(t *testing.T) {
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, filepath.Join(base, "gemma3-1b-4bit"), map[string]any{
|
||||||
|
"model_type": "gemma3",
|
||||||
|
"quantization": map[string]any{
|
||||||
|
"bits": 4,
|
||||||
|
"group_size": 64,
|
||||||
|
},
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
|
||||||
|
m := models[0]
|
||||||
|
assert.Equal(t, 4, m.QuantBits)
|
||||||
|
assert.Equal(t, 64, m.QuantGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Good_BaseDirIsModel(t *testing.T) {
|
||||||
|
// The base directory itself is a model directory.
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, base, map[string]any{
|
||||||
|
"model_type": "llama",
|
||||||
|
}, 2)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
|
||||||
|
m := models[0]
|
||||||
|
assert.Equal(t, "llama", m.ModelType)
|
||||||
|
assert.Equal(t, 2, m.NumFiles)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Good_BaseDirPlusSubdir(t *testing.T) {
|
||||||
|
// Both base dir and a subdir are valid models. Base should appear first.
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, base, map[string]any{
|
||||||
|
"model_type": "parent_model",
|
||||||
|
}, 1)
|
||||||
|
createModelDir(t, filepath.Join(base, "child"), map[string]any{
|
||||||
|
"model_type": "child_model",
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 2)
|
||||||
|
|
||||||
|
// Base dir should appear first (prepended).
|
||||||
|
assert.Equal(t, "parent_model", models[0].ModelType, "base dir model should be first")
|
||||||
|
assert.Equal(t, "child_model", models[1].ModelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Good_EmptyDir(t *testing.T) {
|
||||||
|
base := t.TempDir()
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, models)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Bad_NonexistentDir(t *testing.T) {
|
||||||
|
_, err := Discover("/nonexistent/path/that/should/not/exist")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Bad_NoSafetensors(t *testing.T) {
|
||||||
|
// Directory with config.json but no .safetensors files.
|
||||||
|
base := t.TempDir()
|
||||||
|
dir := filepath.Join(base, "incomplete")
|
||||||
|
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||||
|
|
||||||
|
config := map[string]any{"model_type": "gemma3"}
|
||||||
|
data, err := json.Marshal(config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), data, 0o644))
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, models, "directory without safetensors should be skipped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Bad_NoConfig(t *testing.T) {
|
||||||
|
// Directory with .safetensors but no config.json.
|
||||||
|
base := t.TempDir()
|
||||||
|
dir := filepath.Join(base, "no-config")
|
||||||
|
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("fake"), 0o644))
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, models, "directory without config.json should be skipped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Bad_InvalidJSON(t *testing.T) {
|
||||||
|
// config.json exists but contains invalid JSON. Should still count safetensors files
|
||||||
|
// since json.Unmarshal failure is silently ignored.
|
||||||
|
base := t.TempDir()
|
||||||
|
dir := filepath.Join(base, "bad-json")
|
||||||
|
require.NoError(t, os.MkdirAll(dir, 0o755))
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte("{invalid}"), 0o644))
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("fake"), 0o644))
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
assert.Equal(t, "", models[0].ModelType, "invalid JSON should yield empty model_type")
|
||||||
|
assert.Equal(t, 1, models[0].NumFiles)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Ugly_SkipsRegularFiles(t *testing.T) {
|
||||||
|
// Regular files in base dir should be silently skipped.
|
||||||
|
base := t.TempDir()
|
||||||
|
require.NoError(t, os.WriteFile(filepath.Join(base, "README.md"), []byte("hello"), 0o644))
|
||||||
|
|
||||||
|
createModelDir(t, filepath.Join(base, "real-model"), map[string]any{
|
||||||
|
"model_type": "gemma3",
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
assert.Equal(t, "gemma3", models[0].ModelType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Ugly_MissingModelType(t *testing.T) {
|
||||||
|
// config.json without model_type field.
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, filepath.Join(base, "no-type"), map[string]any{
|
||||||
|
"vocab_size": 32000,
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
assert.Equal(t, "", models[0].ModelType, "missing model_type should yield empty string")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Ugly_NoQuantisation(t *testing.T) {
|
||||||
|
// config.json without quantization key — QuantBits/QuantGroup should be 0.
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, filepath.Join(base, "fp16"), map[string]any{
|
||||||
|
"model_type": "gemma3",
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
assert.Equal(t, 0, models[0].QuantBits)
|
||||||
|
assert.Equal(t, 0, models[0].QuantGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Good_MultipleSafetensors(t *testing.T) {
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, filepath.Join(base, "large-model"), map[string]any{
|
||||||
|
"model_type": "llama",
|
||||||
|
}, 8)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
assert.Equal(t, 8, models[0].NumFiles)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiscover_Good_AbsolutePath(t *testing.T) {
|
||||||
|
base := t.TempDir()
|
||||||
|
createModelDir(t, filepath.Join(base, "test-model"), map[string]any{
|
||||||
|
"model_type": "gemma3",
|
||||||
|
}, 1)
|
||||||
|
|
||||||
|
models, err := Discover(base)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, models, 1)
|
||||||
|
assert.True(t, filepath.IsAbs(models[0].Path), "discovered path must be absolute")
|
||||||
|
}
|
||||||
8
go.mod
8
go.mod
|
|
@ -1,3 +1,11 @@
|
||||||
module forge.lthn.ai/core/go-inference
|
module forge.lthn.ai/core/go-inference
|
||||||
|
|
||||||
go 1.25.5
|
go 1.25.5
|
||||||
|
|
||||||
|
require github.com/stretchr/testify v1.11.1
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
|
|
|
||||||
10
go.sum
Normal file
10
go.sum
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
456
inference_test.go
Normal file
456
inference_test.go
Normal file
|
|
@ -0,0 +1,456 @@
|
||||||
|
package inference
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- test helpers ---
|
||||||
|
|
||||||
|
// resetBackends clears the global registry for test isolation.
|
||||||
|
func resetBackends(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
backendsMu.Lock()
|
||||||
|
defer backendsMu.Unlock()
|
||||||
|
backends = map[string]Backend{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stubBackend is a minimal Backend implementation for testing.
|
||||||
|
type stubBackend struct {
|
||||||
|
name string
|
||||||
|
available bool
|
||||||
|
loadErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubBackend) Name() string { return s.name }
|
||||||
|
func (s *stubBackend) Available() bool { return s.available }
|
||||||
|
func (s *stubBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) {
|
||||||
|
if s.loadErr != nil {
|
||||||
|
return nil, s.loadErr
|
||||||
|
}
|
||||||
|
return &stubTextModel{backend: s.name, path: path}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stubTextModel is a minimal TextModel for testing LoadModel routing.
|
||||||
|
type stubTextModel struct {
|
||||||
|
backend string
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *stubTextModel) Generate(_ context.Context, _ string, _ ...GenerateOption) iter.Seq[Token] {
|
||||||
|
return func(yield func(Token) bool) {}
|
||||||
|
}
|
||||||
|
func (m *stubTextModel) Chat(_ context.Context, _ []Message, _ ...GenerateOption) iter.Seq[Token] {
|
||||||
|
return func(yield func(Token) bool) {}
|
||||||
|
}
|
||||||
|
func (m *stubTextModel) Classify(_ context.Context, _ []string, _ ...GenerateOption) ([]ClassifyResult, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *stubTextModel) BatchGenerate(_ context.Context, _ []string, _ ...GenerateOption) ([]BatchResult, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
func (m *stubTextModel) ModelType() string { return "stub" }
|
||||||
|
func (m *stubTextModel) Info() ModelInfo { return ModelInfo{} }
|
||||||
|
func (m *stubTextModel) Metrics() GenerateMetrics { return GenerateMetrics{} }
|
||||||
|
func (m *stubTextModel) Err() error { return nil }
|
||||||
|
func (m *stubTextModel) Close() error { return nil }
|
||||||
|
|
||||||
|
// --- Register ---
|
||||||
|
|
||||||
|
func TestRegister_Good(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
b := &stubBackend{name: "test_backend", available: true}
|
||||||
|
Register(b)
|
||||||
|
|
||||||
|
got, ok := Get("test_backend")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "test_backend", got.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegister_Good_Multiple(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "alpha", available: true})
|
||||||
|
Register(&stubBackend{name: "beta", available: true})
|
||||||
|
Register(&stubBackend{name: "gamma", available: true})
|
||||||
|
|
||||||
|
names := List()
|
||||||
|
sort.Strings(names)
|
||||||
|
assert.Equal(t, []string{"alpha", "beta", "gamma"}, names)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegister_Ugly_Overwrites(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
b1 := &stubBackend{name: "dup", available: false}
|
||||||
|
b2 := &stubBackend{name: "dup", available: true}
|
||||||
|
|
||||||
|
Register(b1)
|
||||||
|
Register(b2)
|
||||||
|
|
||||||
|
got, ok := Get("dup")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.True(t, got.Available(), "second registration should overwrite the first")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Get ---
|
||||||
|
|
||||||
|
func TestGet_Good(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "exists", available: true})
|
||||||
|
|
||||||
|
b, ok := Get("exists")
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Equal(t, "exists", b.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGet_Bad(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
b, ok := Get("nonexistent")
|
||||||
|
assert.False(t, ok)
|
||||||
|
assert.Nil(t, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List ---
|
||||||
|
|
||||||
|
func TestList_Good_Empty(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
names := List()
|
||||||
|
assert.Empty(t, names)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestList_Good_Populated(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "a", available: true})
|
||||||
|
Register(&stubBackend{name: "b", available: true})
|
||||||
|
|
||||||
|
names := List()
|
||||||
|
sort.Strings(names)
|
||||||
|
assert.Equal(t, []string{"a", "b"}, names)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Default ---
|
||||||
|
|
||||||
|
func TestDefault_Good_Metal(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "metal", available: true})
|
||||||
|
Register(&stubBackend{name: "rocm", available: true})
|
||||||
|
Register(&stubBackend{name: "llama_cpp", available: true})
|
||||||
|
|
||||||
|
b, err := Default()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "metal", b.Name(), "metal should be preferred when available")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault_Good_Rocm(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "rocm", available: true})
|
||||||
|
Register(&stubBackend{name: "llama_cpp", available: true})
|
||||||
|
|
||||||
|
b, err := Default()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "rocm", b.Name(), "rocm should be preferred when metal is absent")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault_Good_LlamaCpp(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "llama_cpp", available: true})
|
||||||
|
|
||||||
|
b, err := Default()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "llama_cpp", b.Name(), "llama_cpp should be used as fallback")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault_Good_PriorityOrder(t *testing.T) {
|
||||||
|
// Full priority test: metal > rocm > llama_cpp
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
backends []stubBackend
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "all_available_prefers_metal",
|
||||||
|
backends: []stubBackend{
|
||||||
|
{name: "llama_cpp", available: true},
|
||||||
|
{name: "rocm", available: true},
|
||||||
|
{name: "metal", available: true},
|
||||||
|
},
|
||||||
|
want: "metal",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "metal_unavailable_prefers_rocm",
|
||||||
|
backends: []stubBackend{
|
||||||
|
{name: "metal", available: false},
|
||||||
|
{name: "rocm", available: true},
|
||||||
|
{name: "llama_cpp", available: true},
|
||||||
|
},
|
||||||
|
want: "rocm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "metal_rocm_unavailable_prefers_llama_cpp",
|
||||||
|
backends: []stubBackend{
|
||||||
|
{name: "metal", available: false},
|
||||||
|
{name: "rocm", available: false},
|
||||||
|
{name: "llama_cpp", available: true},
|
||||||
|
},
|
||||||
|
want: "llama_cpp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
for i := range tt.backends {
|
||||||
|
Register(&tt.backends[i])
|
||||||
|
}
|
||||||
|
b, err := Default()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, b.Name())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault_Good_FallbackToAny(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
// A custom backend that's not in the priority list.
|
||||||
|
Register(&stubBackend{name: "custom_gpu", available: true})
|
||||||
|
|
||||||
|
b, err := Default()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "custom_gpu", b.Name(), "should fall back to any available backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault_Bad_NoBackends(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
_, err := Default()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no backends registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault_Bad_NoneAvailable(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "metal", available: false})
|
||||||
|
Register(&stubBackend{name: "rocm", available: false})
|
||||||
|
|
||||||
|
_, err := Default()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no backends registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefault_Ugly_SkipsUnavailablePreferred(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
// metal registered but not available; custom_gpu is available.
|
||||||
|
Register(&stubBackend{name: "metal", available: false})
|
||||||
|
Register(&stubBackend{name: "custom_gpu", available: true})
|
||||||
|
|
||||||
|
b, err := Default()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "custom_gpu", b.Name(), "should skip unavailable preferred and use any available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- LoadModel ---
|
||||||
|
|
||||||
|
func TestLoadModel_Good_DefaultBackend(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "metal", available: true})
|
||||||
|
|
||||||
|
m, err := LoadModel("/path/to/model")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, m)
|
||||||
|
|
||||||
|
sm := m.(*stubTextModel)
|
||||||
|
assert.Equal(t, "metal", sm.backend)
|
||||||
|
assert.Equal(t, "/path/to/model", sm.path)
|
||||||
|
require.NoError(t, m.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel_Good_ExplicitBackend(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "metal", available: true})
|
||||||
|
Register(&stubBackend{name: "rocm", available: true})
|
||||||
|
|
||||||
|
m, err := LoadModel("/path/to/model", WithBackend("rocm"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, m)
|
||||||
|
|
||||||
|
sm := m.(*stubTextModel)
|
||||||
|
assert.Equal(t, "rocm", sm.backend, "should use explicitly requested backend")
|
||||||
|
require.NoError(t, m.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel_Bad_NoBackends(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
_, err := LoadModel("/path/to/model")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no backends registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel_Bad_ExplicitBackendNotRegistered(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "metal", available: true})
|
||||||
|
|
||||||
|
_, err := LoadModel("/path/to/model", WithBackend("rocm"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "backend \"rocm\" not registered")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel_Bad_ExplicitBackendNotAvailable(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{name: "rocm", available: false})
|
||||||
|
|
||||||
|
_, err := LoadModel("/path/to/model", WithBackend("rocm"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "backend \"rocm\" not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel_Bad_BackendLoadError(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{
|
||||||
|
name: "broken",
|
||||||
|
available: true,
|
||||||
|
loadErr: fmt.Errorf("GPU out of memory"),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := LoadModel("/path/to/model", WithBackend("broken"))
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "GPU out of memory")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel_Good_PassesOptionsThrough(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
// Use a backend that captures the load path.
|
||||||
|
Register(&stubBackend{name: "metal", available: true})
|
||||||
|
|
||||||
|
m, err := LoadModel("/models/gemma3-1b",
|
||||||
|
WithContextLen(4096),
|
||||||
|
WithGPULayers(24),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
sm := m.(*stubTextModel)
|
||||||
|
assert.Equal(t, "/models/gemma3-1b", sm.path)
|
||||||
|
require.NoError(t, m.Close())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadModel_Ugly_DefaultBackendLoadError(t *testing.T) {
|
||||||
|
resetBackends(t)
|
||||||
|
|
||||||
|
Register(&stubBackend{
|
||||||
|
name: "metal",
|
||||||
|
available: true,
|
||||||
|
loadErr: fmt.Errorf("model not found"),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := LoadModel("/nonexistent/model")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "model not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Type assertions (compile-time checks) ---
|
||||||
|
|
||||||
|
func TestTypes_Good_InterfaceCompliance(t *testing.T) {
|
||||||
|
// Verify stubBackend implements Backend.
|
||||||
|
var _ Backend = (*stubBackend)(nil)
|
||||||
|
// Verify stubTextModel implements TextModel.
|
||||||
|
var _ TextModel = (*stubTextModel)(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Struct types ---
|
||||||
|
|
||||||
|
func TestToken_Good(t *testing.T) {
|
||||||
|
tok := Token{ID: 42, Text: "hello"}
|
||||||
|
assert.Equal(t, int32(42), tok.ID)
|
||||||
|
assert.Equal(t, "hello", tok.Text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMessage_Good(t *testing.T) {
|
||||||
|
msg := Message{Role: "user", Content: "Hi there"}
|
||||||
|
assert.Equal(t, "user", msg.Role)
|
||||||
|
assert.Equal(t, "Hi there", msg.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClassifyResult_Good(t *testing.T) {
|
||||||
|
cr := ClassifyResult{
|
||||||
|
Token: Token{ID: 1, Text: "yes"},
|
||||||
|
Logits: []float32{0.1, 0.9},
|
||||||
|
}
|
||||||
|
assert.Equal(t, int32(1), cr.Token.ID)
|
||||||
|
assert.Equal(t, "yes", cr.Token.Text)
|
||||||
|
assert.Len(t, cr.Logits, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchResult_Good(t *testing.T) {
|
||||||
|
br := BatchResult{
|
||||||
|
Tokens: []Token{{ID: 1, Text: "a"}, {ID: 2, Text: "b"}},
|
||||||
|
Err: nil,
|
||||||
|
}
|
||||||
|
assert.Len(t, br.Tokens, 2)
|
||||||
|
assert.NoError(t, br.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBatchResult_Bad(t *testing.T) {
|
||||||
|
br := BatchResult{
|
||||||
|
Tokens: nil,
|
||||||
|
Err: fmt.Errorf("OOM"),
|
||||||
|
}
|
||||||
|
assert.Nil(t, br.Tokens)
|
||||||
|
assert.Error(t, br.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelInfo_Good(t *testing.T) {
|
||||||
|
info := ModelInfo{
|
||||||
|
Architecture: "gemma3",
|
||||||
|
VocabSize: 256128,
|
||||||
|
NumLayers: 26,
|
||||||
|
HiddenSize: 1152,
|
||||||
|
QuantBits: 4,
|
||||||
|
QuantGroup: 64,
|
||||||
|
}
|
||||||
|
assert.Equal(t, "gemma3", info.Architecture)
|
||||||
|
assert.Equal(t, 256128, info.VocabSize)
|
||||||
|
assert.Equal(t, 26, info.NumLayers)
|
||||||
|
assert.Equal(t, 1152, info.HiddenSize)
|
||||||
|
assert.Equal(t, 4, info.QuantBits)
|
||||||
|
assert.Equal(t, 64, info.QuantGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateMetrics_Good(t *testing.T) {
|
||||||
|
m := GenerateMetrics{
|
||||||
|
PromptTokens: 100,
|
||||||
|
GeneratedTokens: 50,
|
||||||
|
PrefillTokensPerSec: 1000.0,
|
||||||
|
DecodeTokensPerSec: 25.0,
|
||||||
|
PeakMemoryBytes: 1 << 30, // 1 GiB
|
||||||
|
ActiveMemoryBytes: 512 << 20, // 512 MiB
|
||||||
|
}
|
||||||
|
assert.Equal(t, 100, m.PromptTokens)
|
||||||
|
assert.Equal(t, 50, m.GeneratedTokens)
|
||||||
|
assert.InDelta(t, 1000.0, m.PrefillTokensPerSec, 0.01)
|
||||||
|
assert.InDelta(t, 25.0, m.DecodeTokensPerSec, 0.01)
|
||||||
|
assert.Equal(t, uint64(1<<30), m.PeakMemoryBytes)
|
||||||
|
assert.Equal(t, uint64(512<<20), m.ActiveMemoryBytes)
|
||||||
|
}
|
||||||
350
options_test.go
Normal file
350
options_test.go
Normal file
|
|
@ -0,0 +1,350 @@
|
||||||
|
package inference
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- GenerateConfig defaults ---
|
||||||
|
|
||||||
|
func TestDefaultGenerateConfig_Good(t *testing.T) {
|
||||||
|
cfg := DefaultGenerateConfig()
|
||||||
|
assert.Equal(t, 256, cfg.MaxTokens, "default MaxTokens should be 256")
|
||||||
|
assert.Equal(t, float32(0.0), cfg.Temperature, "default Temperature should be 0.0 (greedy)")
|
||||||
|
assert.Equal(t, 0, cfg.TopK, "default TopK should be 0 (disabled)")
|
||||||
|
assert.Equal(t, float32(0.0), cfg.TopP, "default TopP should be 0.0 (disabled)")
|
||||||
|
assert.Nil(t, cfg.StopTokens, "default StopTokens should be nil")
|
||||||
|
assert.Equal(t, float32(0.0), cfg.RepeatPenalty, "default RepeatPenalty should be 0.0 (disabled)")
|
||||||
|
assert.False(t, cfg.ReturnLogits, "default ReturnLogits should be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithMaxTokens ---
|
||||||
|
|
||||||
|
func TestWithMaxTokens_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"small", 32, 32},
|
||||||
|
{"medium", 512, 512},
|
||||||
|
{"large", 4096, 4096},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(tt.val)})
|
||||||
|
assert.Equal(t, tt.want, cfg.MaxTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithMaxTokens_Bad(t *testing.T) {
|
||||||
|
// Zero and negative values are accepted (no validation in options layer)
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(0)})
|
||||||
|
assert.Equal(t, 0, cfg.MaxTokens)
|
||||||
|
|
||||||
|
cfg = ApplyGenerateOpts([]GenerateOption{WithMaxTokens(-1)})
|
||||||
|
assert.Equal(t, -1, cfg.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithTemperature ---
|
||||||
|
|
||||||
|
func TestWithTemperature_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val float32
|
||||||
|
want float32
|
||||||
|
}{
|
||||||
|
{"greedy", 0.0, 0.0},
|
||||||
|
{"low", 0.3, 0.3},
|
||||||
|
{"default_creative", 0.7, 0.7},
|
||||||
|
{"high", 1.5, 1.5},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithTemperature(tt.val)})
|
||||||
|
assert.InDelta(t, tt.want, cfg.Temperature, 0.0001)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithTopK ---
|
||||||
|
|
||||||
|
func TestWithTopK_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"disabled", 0, 0},
|
||||||
|
{"small", 10, 10},
|
||||||
|
{"typical", 40, 40},
|
||||||
|
{"large", 100, 100},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithTopK(tt.val)})
|
||||||
|
assert.Equal(t, tt.want, cfg.TopK)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithTopP ---
|
||||||
|
|
||||||
|
func TestWithTopP_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val float32
|
||||||
|
want float32
|
||||||
|
}{
|
||||||
|
{"disabled", 0.0, 0.0},
|
||||||
|
{"tight", 0.5, 0.5},
|
||||||
|
{"typical", 0.9, 0.9},
|
||||||
|
{"full", 1.0, 1.0},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithTopP(tt.val)})
|
||||||
|
assert.InDelta(t, tt.want, cfg.TopP, 0.0001)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithStopTokens ---
|
||||||
|
|
||||||
|
func TestWithStopTokens_Good(t *testing.T) {
|
||||||
|
t.Run("single", func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithStopTokens(1)})
|
||||||
|
assert.Equal(t, []int32{1}, cfg.StopTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple", func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithStopTokens(1, 2, 3)})
|
||||||
|
assert.Equal(t, []int32{1, 2, 3}, cfg.StopTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithStopTokens_Ugly(t *testing.T) {
|
||||||
|
// Last call wins — stop tokens are replaced, not merged.
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{
|
||||||
|
WithStopTokens(1, 2),
|
||||||
|
WithStopTokens(3, 4, 5),
|
||||||
|
})
|
||||||
|
assert.Equal(t, []int32{3, 4, 5}, cfg.StopTokens, "last WithStopTokens should win")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithRepeatPenalty ---
|
||||||
|
|
||||||
|
func TestWithRepeatPenalty_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val float32
|
||||||
|
want float32
|
||||||
|
}{
|
||||||
|
{"disabled", 0.0, 0.0},
|
||||||
|
{"no_penalty", 1.0, 1.0},
|
||||||
|
{"typical", 1.1, 1.1},
|
||||||
|
{"strong", 2.0, 2.0},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithRepeatPenalty(tt.val)})
|
||||||
|
assert.InDelta(t, tt.want, cfg.RepeatPenalty, 0.0001)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithLogits ---
|
||||||
|
|
||||||
|
func TestWithLogits_Good(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{WithLogits()})
|
||||||
|
assert.True(t, cfg.ReturnLogits)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ApplyGenerateOpts ---
|
||||||
|
|
||||||
|
func TestApplyGenerateOpts_Good(t *testing.T) {
|
||||||
|
t.Run("nil_opts_returns_defaults", func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts(nil)
|
||||||
|
def := DefaultGenerateConfig()
|
||||||
|
assert.Equal(t, def, cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty_opts_returns_defaults", func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{})
|
||||||
|
def := DefaultGenerateConfig()
|
||||||
|
assert.Equal(t, def, cfg)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all_options_combined", func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{
|
||||||
|
WithMaxTokens(128),
|
||||||
|
WithTemperature(0.7),
|
||||||
|
WithTopK(40),
|
||||||
|
WithTopP(0.9),
|
||||||
|
WithStopTokens(1, 2),
|
||||||
|
WithRepeatPenalty(1.1),
|
||||||
|
WithLogits(),
|
||||||
|
})
|
||||||
|
assert.Equal(t, 128, cfg.MaxTokens)
|
||||||
|
assert.InDelta(t, 0.7, cfg.Temperature, 0.0001)
|
||||||
|
assert.Equal(t, 40, cfg.TopK)
|
||||||
|
assert.InDelta(t, 0.9, cfg.TopP, 0.0001)
|
||||||
|
assert.Equal(t, []int32{1, 2}, cfg.StopTokens)
|
||||||
|
assert.InDelta(t, 1.1, cfg.RepeatPenalty, 0.0001)
|
||||||
|
assert.True(t, cfg.ReturnLogits)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyGenerateOpts_Ugly(t *testing.T) {
|
||||||
|
t.Run("last_option_wins", func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{
|
||||||
|
WithMaxTokens(100),
|
||||||
|
WithMaxTokens(200),
|
||||||
|
WithMaxTokens(300),
|
||||||
|
})
|
||||||
|
assert.Equal(t, 300, cfg.MaxTokens, "last WithMaxTokens should win")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("temperature_override", func(t *testing.T) {
|
||||||
|
cfg := ApplyGenerateOpts([]GenerateOption{
|
||||||
|
WithTemperature(0.5),
|
||||||
|
WithTemperature(1.0),
|
||||||
|
})
|
||||||
|
assert.InDelta(t, 1.0, cfg.Temperature, 0.0001, "last WithTemperature should win")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- LoadConfig defaults ---
|
||||||
|
|
||||||
|
func TestApplyLoadOpts_Good_Defaults(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts(nil)
|
||||||
|
assert.Equal(t, "", cfg.Backend, "default Backend should be empty (auto-detect)")
|
||||||
|
assert.Equal(t, 0, cfg.ContextLen, "default ContextLen should be 0 (model default)")
|
||||||
|
assert.Equal(t, -1, cfg.GPULayers, "default GPULayers should be -1 (all layers)")
|
||||||
|
assert.Equal(t, 0, cfg.ParallelSlots, "default ParallelSlots should be 0 (server default)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithBackend ---
|
||||||
|
|
||||||
|
func TestWithBackend_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
backend string
|
||||||
|
}{
|
||||||
|
{"metal", "metal"},
|
||||||
|
{"rocm", "rocm"},
|
||||||
|
{"llama_cpp", "llama_cpp"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{WithBackend(tt.backend)})
|
||||||
|
assert.Equal(t, tt.backend, cfg.Backend)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithBackend_Bad(t *testing.T) {
|
||||||
|
// Empty string is valid at the options layer (means auto-detect).
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{WithBackend("")})
|
||||||
|
assert.Equal(t, "", cfg.Backend)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithContextLen ---
|
||||||
|
|
||||||
|
func TestWithContextLen_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"small", 2048, 2048},
|
||||||
|
{"medium", 4096, 4096},
|
||||||
|
{"large", 32768, 32768},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{WithContextLen(tt.val)})
|
||||||
|
assert.Equal(t, tt.want, cfg.ContextLen)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithGPULayers ---
|
||||||
|
|
||||||
|
func TestWithGPULayers_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"all", -1, -1},
|
||||||
|
{"none", 0, 0},
|
||||||
|
{"partial", 24, 24},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{WithGPULayers(tt.val)})
|
||||||
|
assert.Equal(t, tt.want, cfg.GPULayers)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithGPULayers_Ugly(t *testing.T) {
|
||||||
|
// Override the default -1 with 0
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{WithGPULayers(0)})
|
||||||
|
assert.Equal(t, 0, cfg.GPULayers, "WithGPULayers(0) should override default -1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- WithParallelSlots ---
|
||||||
|
|
||||||
|
func TestWithParallelSlots_Good(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
val int
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{"default", 0, 0},
|
||||||
|
{"one", 1, 1},
|
||||||
|
{"four", 4, 4},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{WithParallelSlots(tt.val)})
|
||||||
|
assert.Equal(t, tt.want, cfg.ParallelSlots)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ApplyLoadOpts combined ---
|
||||||
|
|
||||||
|
func TestApplyLoadOpts_Good_Combined(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{
|
||||||
|
WithBackend("rocm"),
|
||||||
|
WithContextLen(8192),
|
||||||
|
WithGPULayers(32),
|
||||||
|
WithParallelSlots(2),
|
||||||
|
})
|
||||||
|
assert.Equal(t, "rocm", cfg.Backend)
|
||||||
|
assert.Equal(t, 8192, cfg.ContextLen)
|
||||||
|
assert.Equal(t, 32, cfg.GPULayers)
|
||||||
|
assert.Equal(t, 2, cfg.ParallelSlots)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyLoadOpts_Ugly(t *testing.T) {
|
||||||
|
t.Run("last_option_wins", func(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{
|
||||||
|
WithBackend("metal"),
|
||||||
|
WithBackend("rocm"),
|
||||||
|
})
|
||||||
|
assert.Equal(t, "rocm", cfg.Backend, "last WithBackend should win")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty_slice_returns_defaults", func(t *testing.T) {
|
||||||
|
cfg := ApplyLoadOpts([]LoadOption{})
|
||||||
|
require.Equal(t, -1, cfg.GPULayers, "empty opts should keep default GPULayers=-1")
|
||||||
|
assert.Equal(t, "", cfg.Backend)
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
Add table
Reference in a new issue