test(inference): add comprehensive tests for all exported API

Cover options (GenerateConfig defaults, all With* options, ApplyGenerateOpts/
ApplyLoadOpts), backend registry (Register, Get, List, Default priority order
metal > rocm > llama_cpp), LoadModel routing (explicit/auto backend, error
paths), and Discover (model directory scanning, quantisation, edge cases).

69 tests, 100% statement coverage, race-clean.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-20 02:06:41 +00:00
parent 15ee86ec62
commit d76448d4a9
No known key found for this signature in database
GPG key ID: AF404715446AEB41
5 changed files with 1074 additions and 0 deletions

250
discover_test.go Normal file
View file

@ -0,0 +1,250 @@
package inference
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- test helpers for discover ---
// createModelDir creates a fake model directory with config.json and n safetensors files.
func createModelDir(t *testing.T, dir string, config map[string]any, numSafetensors int) {
t.Helper()
require.NoError(t, os.MkdirAll(dir, 0o755))
if config != nil {
data, err := json.Marshal(config)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), data, 0o644))
}
for i := range numSafetensors {
fname := filepath.Join(dir, fmt.Sprintf("model-%05d-of-%05d.safetensors", i+1, numSafetensors))
require.NoError(t, os.WriteFile(fname, []byte("fake"), 0o644))
}
}
// --- Discover ---
func TestDiscover_Good_SingleModel(t *testing.T) {
base := t.TempDir()
createModelDir(t, filepath.Join(base, "gemma3-1b"), map[string]any{
"model_type": "gemma3",
}, 1)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
m := models[0]
assert.Equal(t, "gemma3", m.ModelType)
assert.Equal(t, 1, m.NumFiles)
assert.Equal(t, 0, m.QuantBits)
assert.Equal(t, 0, m.QuantGroup)
assert.True(t, filepath.IsAbs(m.Path), "path should be absolute")
assert.Contains(t, m.Path, "gemma3-1b")
}
func TestDiscover_Good_MultipleModels(t *testing.T) {
base := t.TempDir()
createModelDir(t, filepath.Join(base, "gemma3-1b"), map[string]any{
"model_type": "gemma3",
}, 1)
createModelDir(t, filepath.Join(base, "qwen3-4b"), map[string]any{
"model_type": "qwen3",
}, 4)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 2)
// Collect model types for assertion (order may vary due to ReadDir).
types := make([]string, len(models))
for i, m := range models {
types[i] = m.ModelType
}
assert.ElementsMatch(t, []string{"gemma3", "qwen3"}, types)
}
func TestDiscover_Good_Quantised(t *testing.T) {
base := t.TempDir()
createModelDir(t, filepath.Join(base, "gemma3-1b-4bit"), map[string]any{
"model_type": "gemma3",
"quantization": map[string]any{
"bits": 4,
"group_size": 64,
},
}, 1)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
m := models[0]
assert.Equal(t, 4, m.QuantBits)
assert.Equal(t, 64, m.QuantGroup)
}
func TestDiscover_Good_BaseDirIsModel(t *testing.T) {
// The base directory itself is a model directory.
base := t.TempDir()
createModelDir(t, base, map[string]any{
"model_type": "llama",
}, 2)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
m := models[0]
assert.Equal(t, "llama", m.ModelType)
assert.Equal(t, 2, m.NumFiles)
}
func TestDiscover_Good_BaseDirPlusSubdir(t *testing.T) {
// Both base dir and a subdir are valid models. Base should appear first.
base := t.TempDir()
createModelDir(t, base, map[string]any{
"model_type": "parent_model",
}, 1)
createModelDir(t, filepath.Join(base, "child"), map[string]any{
"model_type": "child_model",
}, 1)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 2)
// Base dir should appear first (prepended).
assert.Equal(t, "parent_model", models[0].ModelType, "base dir model should be first")
assert.Equal(t, "child_model", models[1].ModelType)
}
func TestDiscover_Good_EmptyDir(t *testing.T) {
base := t.TempDir()
models, err := Discover(base)
require.NoError(t, err)
assert.Empty(t, models)
}
func TestDiscover_Bad_NonexistentDir(t *testing.T) {
_, err := Discover("/nonexistent/path/that/should/not/exist")
require.Error(t, err)
}
func TestDiscover_Bad_NoSafetensors(t *testing.T) {
// Directory with config.json but no .safetensors files.
base := t.TempDir()
dir := filepath.Join(base, "incomplete")
require.NoError(t, os.MkdirAll(dir, 0o755))
config := map[string]any{"model_type": "gemma3"}
data, err := json.Marshal(config)
require.NoError(t, err)
require.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), data, 0o644))
models, err := Discover(base)
require.NoError(t, err)
assert.Empty(t, models, "directory without safetensors should be skipped")
}
func TestDiscover_Bad_NoConfig(t *testing.T) {
// Directory with .safetensors but no config.json.
base := t.TempDir()
dir := filepath.Join(base, "no-config")
require.NoError(t, os.MkdirAll(dir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("fake"), 0o644))
models, err := Discover(base)
require.NoError(t, err)
assert.Empty(t, models, "directory without config.json should be skipped")
}
func TestDiscover_Bad_InvalidJSON(t *testing.T) {
// config.json exists but contains invalid JSON. Should still count safetensors files
// since json.Unmarshal failure is silently ignored.
base := t.TempDir()
dir := filepath.Join(base, "bad-json")
require.NoError(t, os.MkdirAll(dir, 0o755))
require.NoError(t, os.WriteFile(filepath.Join(dir, "config.json"), []byte("{invalid}"), 0o644))
require.NoError(t, os.WriteFile(filepath.Join(dir, "model.safetensors"), []byte("fake"), 0o644))
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
assert.Equal(t, "", models[0].ModelType, "invalid JSON should yield empty model_type")
assert.Equal(t, 1, models[0].NumFiles)
}
func TestDiscover_Ugly_SkipsRegularFiles(t *testing.T) {
// Regular files in base dir should be silently skipped.
base := t.TempDir()
require.NoError(t, os.WriteFile(filepath.Join(base, "README.md"), []byte("hello"), 0o644))
createModelDir(t, filepath.Join(base, "real-model"), map[string]any{
"model_type": "gemma3",
}, 1)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
assert.Equal(t, "gemma3", models[0].ModelType)
}
func TestDiscover_Ugly_MissingModelType(t *testing.T) {
// config.json without model_type field.
base := t.TempDir()
createModelDir(t, filepath.Join(base, "no-type"), map[string]any{
"vocab_size": 32000,
}, 1)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
assert.Equal(t, "", models[0].ModelType, "missing model_type should yield empty string")
}
func TestDiscover_Ugly_NoQuantisation(t *testing.T) {
// config.json without quantization key — QuantBits/QuantGroup should be 0.
base := t.TempDir()
createModelDir(t, filepath.Join(base, "fp16"), map[string]any{
"model_type": "gemma3",
}, 1)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
assert.Equal(t, 0, models[0].QuantBits)
assert.Equal(t, 0, models[0].QuantGroup)
}
func TestDiscover_Good_MultipleSafetensors(t *testing.T) {
base := t.TempDir()
createModelDir(t, filepath.Join(base, "large-model"), map[string]any{
"model_type": "llama",
}, 8)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
assert.Equal(t, 8, models[0].NumFiles)
}
func TestDiscover_Good_AbsolutePath(t *testing.T) {
base := t.TempDir()
createModelDir(t, filepath.Join(base, "test-model"), map[string]any{
"model_type": "gemma3",
}, 1)
models, err := Discover(base)
require.NoError(t, err)
require.Len(t, models, 1)
assert.True(t, filepath.IsAbs(models[0].Path), "discovered path must be absolute")
}

8
go.mod
View file

@ -1,3 +1,11 @@
module forge.lthn.ai/core/go-inference
go 1.25.5
require github.com/stretchr/testify v1.11.1
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum Normal file
View file

@ -0,0 +1,10 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

456
inference_test.go Normal file
View file

@ -0,0 +1,456 @@
package inference
import (
"context"
"fmt"
"iter"
"sort"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- test helpers ---
// resetBackends clears the global registry for test isolation.
func resetBackends(t *testing.T) {
t.Helper()
backendsMu.Lock()
defer backendsMu.Unlock()
backends = map[string]Backend{}
}
// stubBackend is a minimal Backend implementation for testing.
type stubBackend struct {
name string
available bool
loadErr error
}
func (s *stubBackend) Name() string { return s.name }
func (s *stubBackend) Available() bool { return s.available }
func (s *stubBackend) LoadModel(path string, opts ...LoadOption) (TextModel, error) {
if s.loadErr != nil {
return nil, s.loadErr
}
return &stubTextModel{backend: s.name, path: path}, nil
}
// stubTextModel is a minimal TextModel for testing LoadModel routing.
type stubTextModel struct {
backend string
path string
}
func (m *stubTextModel) Generate(_ context.Context, _ string, _ ...GenerateOption) iter.Seq[Token] {
return func(yield func(Token) bool) {}
}
func (m *stubTextModel) Chat(_ context.Context, _ []Message, _ ...GenerateOption) iter.Seq[Token] {
return func(yield func(Token) bool) {}
}
func (m *stubTextModel) Classify(_ context.Context, _ []string, _ ...GenerateOption) ([]ClassifyResult, error) {
return nil, nil
}
func (m *stubTextModel) BatchGenerate(_ context.Context, _ []string, _ ...GenerateOption) ([]BatchResult, error) {
return nil, nil
}
func (m *stubTextModel) ModelType() string { return "stub" }
func (m *stubTextModel) Info() ModelInfo { return ModelInfo{} }
func (m *stubTextModel) Metrics() GenerateMetrics { return GenerateMetrics{} }
func (m *stubTextModel) Err() error { return nil }
func (m *stubTextModel) Close() error { return nil }
// --- Register ---
func TestRegister_Good(t *testing.T) {
resetBackends(t)
b := &stubBackend{name: "test_backend", available: true}
Register(b)
got, ok := Get("test_backend")
require.True(t, ok)
assert.Equal(t, "test_backend", got.Name())
}
func TestRegister_Good_Multiple(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "alpha", available: true})
Register(&stubBackend{name: "beta", available: true})
Register(&stubBackend{name: "gamma", available: true})
names := List()
sort.Strings(names)
assert.Equal(t, []string{"alpha", "beta", "gamma"}, names)
}
func TestRegister_Ugly_Overwrites(t *testing.T) {
resetBackends(t)
b1 := &stubBackend{name: "dup", available: false}
b2 := &stubBackend{name: "dup", available: true}
Register(b1)
Register(b2)
got, ok := Get("dup")
require.True(t, ok)
assert.True(t, got.Available(), "second registration should overwrite the first")
}
// --- Get ---
func TestGet_Good(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "exists", available: true})
b, ok := Get("exists")
require.True(t, ok)
assert.Equal(t, "exists", b.Name())
}
func TestGet_Bad(t *testing.T) {
resetBackends(t)
b, ok := Get("nonexistent")
assert.False(t, ok)
assert.Nil(t, b)
}
// --- List ---
func TestList_Good_Empty(t *testing.T) {
resetBackends(t)
names := List()
assert.Empty(t, names)
}
func TestList_Good_Populated(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "a", available: true})
Register(&stubBackend{name: "b", available: true})
names := List()
sort.Strings(names)
assert.Equal(t, []string{"a", "b"}, names)
}
// --- Default ---
func TestDefault_Good_Metal(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "metal", available: true})
Register(&stubBackend{name: "rocm", available: true})
Register(&stubBackend{name: "llama_cpp", available: true})
b, err := Default()
require.NoError(t, err)
assert.Equal(t, "metal", b.Name(), "metal should be preferred when available")
}
func TestDefault_Good_Rocm(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "rocm", available: true})
Register(&stubBackend{name: "llama_cpp", available: true})
b, err := Default()
require.NoError(t, err)
assert.Equal(t, "rocm", b.Name(), "rocm should be preferred when metal is absent")
}
func TestDefault_Good_LlamaCpp(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "llama_cpp", available: true})
b, err := Default()
require.NoError(t, err)
assert.Equal(t, "llama_cpp", b.Name(), "llama_cpp should be used as fallback")
}
func TestDefault_Good_PriorityOrder(t *testing.T) {
// Full priority test: metal > rocm > llama_cpp
tests := []struct {
name string
backends []stubBackend
want string
}{
{
name: "all_available_prefers_metal",
backends: []stubBackend{
{name: "llama_cpp", available: true},
{name: "rocm", available: true},
{name: "metal", available: true},
},
want: "metal",
},
{
name: "metal_unavailable_prefers_rocm",
backends: []stubBackend{
{name: "metal", available: false},
{name: "rocm", available: true},
{name: "llama_cpp", available: true},
},
want: "rocm",
},
{
name: "metal_rocm_unavailable_prefers_llama_cpp",
backends: []stubBackend{
{name: "metal", available: false},
{name: "rocm", available: false},
{name: "llama_cpp", available: true},
},
want: "llama_cpp",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resetBackends(t)
for i := range tt.backends {
Register(&tt.backends[i])
}
b, err := Default()
require.NoError(t, err)
assert.Equal(t, tt.want, b.Name())
})
}
}
func TestDefault_Good_FallbackToAny(t *testing.T) {
resetBackends(t)
// A custom backend that's not in the priority list.
Register(&stubBackend{name: "custom_gpu", available: true})
b, err := Default()
require.NoError(t, err)
assert.Equal(t, "custom_gpu", b.Name(), "should fall back to any available backend")
}
func TestDefault_Bad_NoBackends(t *testing.T) {
resetBackends(t)
_, err := Default()
require.Error(t, err)
assert.Contains(t, err.Error(), "no backends registered")
}
func TestDefault_Bad_NoneAvailable(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "metal", available: false})
Register(&stubBackend{name: "rocm", available: false})
_, err := Default()
require.Error(t, err)
assert.Contains(t, err.Error(), "no backends registered")
}
func TestDefault_Ugly_SkipsUnavailablePreferred(t *testing.T) {
resetBackends(t)
// metal registered but not available; custom_gpu is available.
Register(&stubBackend{name: "metal", available: false})
Register(&stubBackend{name: "custom_gpu", available: true})
b, err := Default()
require.NoError(t, err)
assert.Equal(t, "custom_gpu", b.Name(), "should skip unavailable preferred and use any available")
}
// --- LoadModel ---
func TestLoadModel_Good_DefaultBackend(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "metal", available: true})
m, err := LoadModel("/path/to/model")
require.NoError(t, err)
require.NotNil(t, m)
sm := m.(*stubTextModel)
assert.Equal(t, "metal", sm.backend)
assert.Equal(t, "/path/to/model", sm.path)
require.NoError(t, m.Close())
}
func TestLoadModel_Good_ExplicitBackend(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "metal", available: true})
Register(&stubBackend{name: "rocm", available: true})
m, err := LoadModel("/path/to/model", WithBackend("rocm"))
require.NoError(t, err)
require.NotNil(t, m)
sm := m.(*stubTextModel)
assert.Equal(t, "rocm", sm.backend, "should use explicitly requested backend")
require.NoError(t, m.Close())
}
func TestLoadModel_Bad_NoBackends(t *testing.T) {
resetBackends(t)
_, err := LoadModel("/path/to/model")
require.Error(t, err)
assert.Contains(t, err.Error(), "no backends registered")
}
func TestLoadModel_Bad_ExplicitBackendNotRegistered(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "metal", available: true})
_, err := LoadModel("/path/to/model", WithBackend("rocm"))
require.Error(t, err)
assert.Contains(t, err.Error(), "backend \"rocm\" not registered")
}
func TestLoadModel_Bad_ExplicitBackendNotAvailable(t *testing.T) {
resetBackends(t)
Register(&stubBackend{name: "rocm", available: false})
_, err := LoadModel("/path/to/model", WithBackend("rocm"))
require.Error(t, err)
assert.Contains(t, err.Error(), "backend \"rocm\" not available")
}
func TestLoadModel_Bad_BackendLoadError(t *testing.T) {
resetBackends(t)
Register(&stubBackend{
name: "broken",
available: true,
loadErr: fmt.Errorf("GPU out of memory"),
})
_, err := LoadModel("/path/to/model", WithBackend("broken"))
require.Error(t, err)
assert.Contains(t, err.Error(), "GPU out of memory")
}
func TestLoadModel_Good_PassesOptionsThrough(t *testing.T) {
resetBackends(t)
// Use a backend that captures the load path.
Register(&stubBackend{name: "metal", available: true})
m, err := LoadModel("/models/gemma3-1b",
WithContextLen(4096),
WithGPULayers(24),
)
require.NoError(t, err)
sm := m.(*stubTextModel)
assert.Equal(t, "/models/gemma3-1b", sm.path)
require.NoError(t, m.Close())
}
func TestLoadModel_Ugly_DefaultBackendLoadError(t *testing.T) {
resetBackends(t)
Register(&stubBackend{
name: "metal",
available: true,
loadErr: fmt.Errorf("model not found"),
})
_, err := LoadModel("/nonexistent/model")
require.Error(t, err)
assert.Contains(t, err.Error(), "model not found")
}
// --- Type assertions (compile-time checks) ---
func TestTypes_Good_InterfaceCompliance(t *testing.T) {
// Verify stubBackend implements Backend.
var _ Backend = (*stubBackend)(nil)
// Verify stubTextModel implements TextModel.
var _ TextModel = (*stubTextModel)(nil)
}
// --- Struct types ---
func TestToken_Good(t *testing.T) {
tok := Token{ID: 42, Text: "hello"}
assert.Equal(t, int32(42), tok.ID)
assert.Equal(t, "hello", tok.Text)
}
func TestMessage_Good(t *testing.T) {
msg := Message{Role: "user", Content: "Hi there"}
assert.Equal(t, "user", msg.Role)
assert.Equal(t, "Hi there", msg.Content)
}
func TestClassifyResult_Good(t *testing.T) {
cr := ClassifyResult{
Token: Token{ID: 1, Text: "yes"},
Logits: []float32{0.1, 0.9},
}
assert.Equal(t, int32(1), cr.Token.ID)
assert.Equal(t, "yes", cr.Token.Text)
assert.Len(t, cr.Logits, 2)
}
func TestBatchResult_Good(t *testing.T) {
br := BatchResult{
Tokens: []Token{{ID: 1, Text: "a"}, {ID: 2, Text: "b"}},
Err: nil,
}
assert.Len(t, br.Tokens, 2)
assert.NoError(t, br.Err)
}
func TestBatchResult_Bad(t *testing.T) {
br := BatchResult{
Tokens: nil,
Err: fmt.Errorf("OOM"),
}
assert.Nil(t, br.Tokens)
assert.Error(t, br.Err)
}
func TestModelInfo_Good(t *testing.T) {
info := ModelInfo{
Architecture: "gemma3",
VocabSize: 256128,
NumLayers: 26,
HiddenSize: 1152,
QuantBits: 4,
QuantGroup: 64,
}
assert.Equal(t, "gemma3", info.Architecture)
assert.Equal(t, 256128, info.VocabSize)
assert.Equal(t, 26, info.NumLayers)
assert.Equal(t, 1152, info.HiddenSize)
assert.Equal(t, 4, info.QuantBits)
assert.Equal(t, 64, info.QuantGroup)
}
func TestGenerateMetrics_Good(t *testing.T) {
m := GenerateMetrics{
PromptTokens: 100,
GeneratedTokens: 50,
PrefillTokensPerSec: 1000.0,
DecodeTokensPerSec: 25.0,
PeakMemoryBytes: 1 << 30, // 1 GiB
ActiveMemoryBytes: 512 << 20, // 512 MiB
}
assert.Equal(t, 100, m.PromptTokens)
assert.Equal(t, 50, m.GeneratedTokens)
assert.InDelta(t, 1000.0, m.PrefillTokensPerSec, 0.01)
assert.InDelta(t, 25.0, m.DecodeTokensPerSec, 0.01)
assert.Equal(t, uint64(1<<30), m.PeakMemoryBytes)
assert.Equal(t, uint64(512<<20), m.ActiveMemoryBytes)
}

350
options_test.go Normal file
View file

@ -0,0 +1,350 @@
package inference
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- GenerateConfig defaults ---
func TestDefaultGenerateConfig_Good(t *testing.T) {
cfg := DefaultGenerateConfig()
assert.Equal(t, 256, cfg.MaxTokens, "default MaxTokens should be 256")
assert.Equal(t, float32(0.0), cfg.Temperature, "default Temperature should be 0.0 (greedy)")
assert.Equal(t, 0, cfg.TopK, "default TopK should be 0 (disabled)")
assert.Equal(t, float32(0.0), cfg.TopP, "default TopP should be 0.0 (disabled)")
assert.Nil(t, cfg.StopTokens, "default StopTokens should be nil")
assert.Equal(t, float32(0.0), cfg.RepeatPenalty, "default RepeatPenalty should be 0.0 (disabled)")
assert.False(t, cfg.ReturnLogits, "default ReturnLogits should be false")
}
// --- WithMaxTokens ---
func TestWithMaxTokens_Good(t *testing.T) {
tests := []struct {
name string
val int
want int
}{
{"small", 32, 32},
{"medium", 512, 512},
{"large", 4096, 4096},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(tt.val)})
assert.Equal(t, tt.want, cfg.MaxTokens)
})
}
}
func TestWithMaxTokens_Bad(t *testing.T) {
// Zero and negative values are accepted (no validation in options layer)
cfg := ApplyGenerateOpts([]GenerateOption{WithMaxTokens(0)})
assert.Equal(t, 0, cfg.MaxTokens)
cfg = ApplyGenerateOpts([]GenerateOption{WithMaxTokens(-1)})
assert.Equal(t, -1, cfg.MaxTokens)
}
// --- WithTemperature ---
func TestWithTemperature_Good(t *testing.T) {
tests := []struct {
name string
val float32
want float32
}{
{"greedy", 0.0, 0.0},
{"low", 0.3, 0.3},
{"default_creative", 0.7, 0.7},
{"high", 1.5, 1.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithTemperature(tt.val)})
assert.InDelta(t, tt.want, cfg.Temperature, 0.0001)
})
}
}
// --- WithTopK ---
func TestWithTopK_Good(t *testing.T) {
tests := []struct {
name string
val int
want int
}{
{"disabled", 0, 0},
{"small", 10, 10},
{"typical", 40, 40},
{"large", 100, 100},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithTopK(tt.val)})
assert.Equal(t, tt.want, cfg.TopK)
})
}
}
// --- WithTopP ---
func TestWithTopP_Good(t *testing.T) {
tests := []struct {
name string
val float32
want float32
}{
{"disabled", 0.0, 0.0},
{"tight", 0.5, 0.5},
{"typical", 0.9, 0.9},
{"full", 1.0, 1.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithTopP(tt.val)})
assert.InDelta(t, tt.want, cfg.TopP, 0.0001)
})
}
}
// --- WithStopTokens ---
func TestWithStopTokens_Good(t *testing.T) {
t.Run("single", func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithStopTokens(1)})
assert.Equal(t, []int32{1}, cfg.StopTokens)
})
t.Run("multiple", func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithStopTokens(1, 2, 3)})
assert.Equal(t, []int32{1, 2, 3}, cfg.StopTokens)
})
}
func TestWithStopTokens_Ugly(t *testing.T) {
// Last call wins — stop tokens are replaced, not merged.
cfg := ApplyGenerateOpts([]GenerateOption{
WithStopTokens(1, 2),
WithStopTokens(3, 4, 5),
})
assert.Equal(t, []int32{3, 4, 5}, cfg.StopTokens, "last WithStopTokens should win")
}
// --- WithRepeatPenalty ---
func TestWithRepeatPenalty_Good(t *testing.T) {
tests := []struct {
name string
val float32
want float32
}{
{"disabled", 0.0, 0.0},
{"no_penalty", 1.0, 1.0},
{"typical", 1.1, 1.1},
{"strong", 2.0, 2.0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithRepeatPenalty(tt.val)})
assert.InDelta(t, tt.want, cfg.RepeatPenalty, 0.0001)
})
}
}
// --- WithLogits ---
func TestWithLogits_Good(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{WithLogits()})
assert.True(t, cfg.ReturnLogits)
}
// --- ApplyGenerateOpts ---
func TestApplyGenerateOpts_Good(t *testing.T) {
t.Run("nil_opts_returns_defaults", func(t *testing.T) {
cfg := ApplyGenerateOpts(nil)
def := DefaultGenerateConfig()
assert.Equal(t, def, cfg)
})
t.Run("empty_opts_returns_defaults", func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{})
def := DefaultGenerateConfig()
assert.Equal(t, def, cfg)
})
t.Run("all_options_combined", func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{
WithMaxTokens(128),
WithTemperature(0.7),
WithTopK(40),
WithTopP(0.9),
WithStopTokens(1, 2),
WithRepeatPenalty(1.1),
WithLogits(),
})
assert.Equal(t, 128, cfg.MaxTokens)
assert.InDelta(t, 0.7, cfg.Temperature, 0.0001)
assert.Equal(t, 40, cfg.TopK)
assert.InDelta(t, 0.9, cfg.TopP, 0.0001)
assert.Equal(t, []int32{1, 2}, cfg.StopTokens)
assert.InDelta(t, 1.1, cfg.RepeatPenalty, 0.0001)
assert.True(t, cfg.ReturnLogits)
})
}
func TestApplyGenerateOpts_Ugly(t *testing.T) {
t.Run("last_option_wins", func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{
WithMaxTokens(100),
WithMaxTokens(200),
WithMaxTokens(300),
})
assert.Equal(t, 300, cfg.MaxTokens, "last WithMaxTokens should win")
})
t.Run("temperature_override", func(t *testing.T) {
cfg := ApplyGenerateOpts([]GenerateOption{
WithTemperature(0.5),
WithTemperature(1.0),
})
assert.InDelta(t, 1.0, cfg.Temperature, 0.0001, "last WithTemperature should win")
})
}
// --- LoadConfig defaults ---
func TestApplyLoadOpts_Good_Defaults(t *testing.T) {
cfg := ApplyLoadOpts(nil)
assert.Equal(t, "", cfg.Backend, "default Backend should be empty (auto-detect)")
assert.Equal(t, 0, cfg.ContextLen, "default ContextLen should be 0 (model default)")
assert.Equal(t, -1, cfg.GPULayers, "default GPULayers should be -1 (all layers)")
assert.Equal(t, 0, cfg.ParallelSlots, "default ParallelSlots should be 0 (server default)")
}
// --- WithBackend ---
func TestWithBackend_Good(t *testing.T) {
tests := []struct {
name string
backend string
}{
{"metal", "metal"},
{"rocm", "rocm"},
{"llama_cpp", "llama_cpp"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{WithBackend(tt.backend)})
assert.Equal(t, tt.backend, cfg.Backend)
})
}
}
func TestWithBackend_Bad(t *testing.T) {
// Empty string is valid at the options layer (means auto-detect).
cfg := ApplyLoadOpts([]LoadOption{WithBackend("")})
assert.Equal(t, "", cfg.Backend)
}
// --- WithContextLen ---
func TestWithContextLen_Good(t *testing.T) {
tests := []struct {
name string
val int
want int
}{
{"small", 2048, 2048},
{"medium", 4096, 4096},
{"large", 32768, 32768},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{WithContextLen(tt.val)})
assert.Equal(t, tt.want, cfg.ContextLen)
})
}
}
// --- WithGPULayers ---
func TestWithGPULayers_Good(t *testing.T) {
tests := []struct {
name string
val int
want int
}{
{"all", -1, -1},
{"none", 0, 0},
{"partial", 24, 24},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{WithGPULayers(tt.val)})
assert.Equal(t, tt.want, cfg.GPULayers)
})
}
}
func TestWithGPULayers_Ugly(t *testing.T) {
// Override the default -1 with 0
cfg := ApplyLoadOpts([]LoadOption{WithGPULayers(0)})
assert.Equal(t, 0, cfg.GPULayers, "WithGPULayers(0) should override default -1")
}
// --- WithParallelSlots ---
func TestWithParallelSlots_Good(t *testing.T) {
tests := []struct {
name string
val int
want int
}{
{"default", 0, 0},
{"one", 1, 1},
{"four", 4, 4},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{WithParallelSlots(tt.val)})
assert.Equal(t, tt.want, cfg.ParallelSlots)
})
}
}
// --- ApplyLoadOpts combined ---
func TestApplyLoadOpts_Good_Combined(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{
WithBackend("rocm"),
WithContextLen(8192),
WithGPULayers(32),
WithParallelSlots(2),
})
assert.Equal(t, "rocm", cfg.Backend)
assert.Equal(t, 8192, cfg.ContextLen)
assert.Equal(t, 32, cfg.GPULayers)
assert.Equal(t, 2, cfg.ParallelSlots)
}
func TestApplyLoadOpts_Ugly(t *testing.T) {
t.Run("last_option_wins", func(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{
WithBackend("metal"),
WithBackend("rocm"),
})
assert.Equal(t, "rocm", cfg.Backend, "last WithBackend should win")
})
t.Run("empty_slice_returns_defaults", func(t *testing.T) {
cfg := ApplyLoadOpts([]LoadOption{})
require.Equal(t, -1, cfg.GPULayers, "empty opts should keep default GPULayers=-1")
assert.Equal(t, "", cfg.Backend)
})
}