test: add agent subsystem tests (25 tests, 528 LOC)

Tests for AdapterMeta (model family → tag/prefix/stem mapping),
FindUnscored (checkpoint filtering), BufferInfluxResult/ReplayInfluxBuffer
(JSONL persistence round-trip), and DiscoverCheckpoints (SSH output parsing
via fakeTransport mock implementing RemoteTransport interface).

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-20 03:23:13 +00:00
parent 987831bd16
commit 3e227610fe

528
agent_test.go Normal file
View file

@ -0,0 +1,528 @@
// SPDX-Licence-Identifier: EUPL-1.2
package ml
import (
"context"
"encoding/json"
"errors"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// fakeTransport — in-memory RemoteTransport for testing
// ---------------------------------------------------------------------------
// fakeTransport implements RemoteTransport using canned responses keyed on
// a substring of the command string. Commands are matched in insertion order
// so the first matching key wins.
type fakeTransport struct {
commands []fakeCmd
}
type fakeCmd struct {
pattern string
stdout string
err error
}
func newFakeTransport() *fakeTransport { return &fakeTransport{} }
func (f *fakeTransport) On(pattern, stdout string, err error) {
f.commands = append(f.commands, fakeCmd{pattern: pattern, stdout: stdout, err: err})
}
func (f *fakeTransport) Run(_ context.Context, cmd string) (string, error) {
for _, fc := range f.commands {
if contains(cmd, fc.pattern) {
return fc.stdout, fc.err
}
}
return "", errors.New("fakeTransport: no match for command: " + cmd)
}
func (f *fakeTransport) CopyFrom(_ context.Context, _, _ string) error { return nil }
func (f *fakeTransport) CopyTo(_ context.Context, _, _ string) error { return nil }
// contains is a small helper to avoid importing strings just for this.
func contains(s, substr string) bool {
return len(substr) == 0 || len(s) >= len(substr) && searchSubstr(s, substr)
}
func searchSubstr(s, sub string) bool {
for i := 0; i <= len(s)-len(sub); i++ {
if s[i:i+len(sub)] == sub {
return true
}
}
return false
}
// =========================================================================
// 1. AdapterMeta tests
// =========================================================================
func TestAdapterMeta_KnownFamilies_Good(t *testing.T) {
tests := []struct {
dirname string
wantTag string
wantPfx string
wantStem string
}{
// gemma-3-1b via "1b" prefix
{"adapters-1b", "gemma-3-1b", "G1", "1b"},
// gemma-3-27b via "27b" prefix
{"adapters-27b", "gemma-3-27b", "G27", "27b"},
// deepseek-r1-7b
{"adapters-deepseek-r1-7b", "deepseek-r1-7b", "R1", "deepseek-r1-7b"},
// gpt-oss
{"adapters-gpt-oss", "gpt-oss-20b", "GPT", "gpt-oss"},
// gemma-3-12b via "12b" prefix
{"adapters-12b", "gemma-3-12b", "G12", "12b"},
// gemma-3-4b via "4b" prefix
{"adapters-4b", "gemma-3-4b", "G4", "4b"},
// bench-1b
{"adapters-bench-1b", "gemma-3-1b", "B1", "bench-1b"},
// book
{"adapters-book", "gemma-3-27b", "Book", "book"},
// cross
{"adapters-cross", "gemma-3-12b", "Cross", "cross"},
// vi → gemma-3-1b
{"adapters-vi", "gemma-3-1b", "Vi1", "vi"},
// vi-12b → gemma-3-12b
{"adapters-vi-12b", "gemma-3-12b", "Vi12", "vi-12b"},
// lem-gpt-oss
{"adapters-lem-gpt-oss", "gpt-oss-20b", "LGPT", "lem-gpt-oss"},
}
for _, tt := range tests {
t.Run(tt.dirname, func(t *testing.T) {
tag, pfx, stem := AdapterMeta(tt.dirname)
assert.Equal(t, tt.wantTag, tag, "model tag")
assert.Equal(t, tt.wantPfx, pfx, "label prefix")
assert.Equal(t, tt.wantStem, stem, "run ID stem")
})
}
}
func TestAdapterMeta_WithVariant_Good(t *testing.T) {
// "adapters-27b-reasoning" → 27b prefix matches, variant = "reasoning"
tag, pfx, stem := AdapterMeta("adapters-27b-reasoning")
assert.Equal(t, "gemma-3-27b", tag)
assert.Equal(t, "G27-reasoning", pfx)
assert.Equal(t, "27b-reasoning", stem)
}
func TestAdapterMeta_WithoutVariant_Good(t *testing.T) {
// "adapters-12b" → variant is empty → "base"
tag, pfx, stem := AdapterMeta("adapters-12b")
assert.Equal(t, "gemma-3-12b", tag)
assert.Equal(t, "G12", pfx) // variant="base" produces short without suffix
assert.Equal(t, "12b", stem)
}
func TestAdapterMeta_SubdirectoryPattern_Good(t *testing.T) {
// "adapters-15k/gemma-3-27b" → matches "15k/gemma-3-27b" prefix
tag, pfx, stem := AdapterMeta("adapters-15k/gemma-3-27b")
assert.Equal(t, "gemma-3-27b", tag)
assert.Equal(t, "G27", pfx)
// stem should replace "/" with "-"
assert.Equal(t, "15k-gemma-3-27b", stem)
}
func TestAdapterMeta_SubdirectoryWithVariant_Good(t *testing.T) {
// "adapters-15k/gemma-3-1b-creative" → variant = "creative"
tag, pfx, stem := AdapterMeta("adapters-15k/gemma-3-1b-creative")
assert.Equal(t, "gemma-3-1b", tag)
assert.Equal(t, "G1-creative", pfx)
assert.Equal(t, "15k-gemma-3-1b-creative", stem)
}
func TestAdapterMeta_Unknown_Bad(t *testing.T) {
// Unknown dirname falls back: tag=name, short=name[:10], stem=name
tag, pfx, stem := AdapterMeta("adapters-completelynewmodel42")
assert.Equal(t, "completelynewmodel42", tag)
assert.Equal(t, "completely", pfx) // truncated to 10 chars
assert.Equal(t, "completelynewmodel42", stem)
}
func TestAdapterMeta_UnknownShort_Good(t *testing.T) {
// Short unknown name (< 10 chars) is not truncated.
tag, pfx, stem := AdapterMeta("adapters-xyz")
assert.Equal(t, "xyz", tag)
assert.Equal(t, "xyz", pfx)
assert.Equal(t, "xyz", stem)
}
func TestAdapterMeta_NoPrefix_Good(t *testing.T) {
// dirname without "adapters-" prefix — TrimPrefix does nothing useful,
// but the function should still handle it gracefully.
tag, pfx, stem := AdapterMeta("27b-fancy")
assert.Equal(t, "gemma-3-27b", tag)
assert.Equal(t, "G27-fancy", pfx)
assert.Equal(t, "27b-fancy", stem)
}
// =========================================================================
// 2. FindUnscored tests
// =========================================================================
func TestFindUnscored_AllUnscored_Good(t *testing.T) {
checkpoints := []Checkpoint{
{Dirname: "b-dir", Iteration: 200, RunID: "run-b", Label: "B @200"},
{Dirname: "a-dir", Iteration: 100, RunID: "run-a", Label: "A @100"},
{Dirname: "a-dir", Iteration: 50, RunID: "run-a", Label: "A @50"},
}
scored := map[[2]string]bool{}
result := FindUnscored(checkpoints, scored)
require.Len(t, result, 3)
// Should be sorted by (dirname, iteration)
assert.Equal(t, "a-dir", result[0].Dirname)
assert.Equal(t, 50, result[0].Iteration)
assert.Equal(t, "a-dir", result[1].Dirname)
assert.Equal(t, 100, result[1].Iteration)
assert.Equal(t, "b-dir", result[2].Dirname)
assert.Equal(t, 200, result[2].Iteration)
}
func TestFindUnscored_SomeScored_Good(t *testing.T) {
checkpoints := []Checkpoint{
{Dirname: "dir", Iteration: 100, RunID: "run-1", Label: "L @100"},
{Dirname: "dir", Iteration: 200, RunID: "run-1", Label: "L @200"},
{Dirname: "dir", Iteration: 300, RunID: "run-1", Label: "L @300"},
}
scored := map[[2]string]bool{
{"run-1", "L @100"}: true,
{"run-1", "L @300"}: true,
}
result := FindUnscored(checkpoints, scored)
require.Len(t, result, 1)
assert.Equal(t, 200, result[0].Iteration)
assert.Equal(t, "L @200", result[0].Label)
}
func TestFindUnscored_AllScored_Good(t *testing.T) {
checkpoints := []Checkpoint{
{Dirname: "dir", Iteration: 100, RunID: "run-1", Label: "L @100"},
{Dirname: "dir", Iteration: 200, RunID: "run-1", Label: "L @200"},
}
scored := map[[2]string]bool{
{"run-1", "L @100"}: true,
{"run-1", "L @200"}: true,
}
result := FindUnscored(checkpoints, scored)
assert.Empty(t, result)
}
func TestFindUnscored_EmptyInput_Good(t *testing.T) {
result := FindUnscored(nil, nil)
assert.Empty(t, result)
result = FindUnscored([]Checkpoint{}, map[[2]string]bool{})
assert.Empty(t, result)
}
func TestFindUnscored_NilScored_Good(t *testing.T) {
// nil scored map should treat everything as unscored
checkpoints := []Checkpoint{
{Dirname: "a", Iteration: 1, RunID: "r", Label: "L @1"},
}
result := FindUnscored(checkpoints, nil)
require.Len(t, result, 1)
}
// =========================================================================
// 3. BufferInfluxResult / ReplayInfluxBuffer round-trip tests
// =========================================================================
func TestBufferInfluxResult_RoundTrip_Good(t *testing.T) {
workDir := t.TempDir()
cp := Checkpoint{
RemoteDir: "/data/adapters-27b",
Filename: "0001000_adapters.safetensors",
Dirname: "adapters-27b",
Iteration: 1000,
ModelTag: "gemma-3-27b",
Label: "G27 @1000",
RunID: "27b-capability-auto",
}
results := ProbeResult{
Accuracy: 75.0,
Correct: 3,
Total: 4,
ByCategory: map[string]CategoryResult{
"math": {Correct: 2, Total: 2},
"lang": {Correct: 1, Total: 2},
},
Probes: map[string]SingleProbeResult{
"p1": {Passed: true, Response: "ok"},
"p2": {Passed: false, Response: "wrong"},
},
}
BufferInfluxResult(workDir, cp, results)
// Verify the buffer file exists and contains valid JSONL
bufPath := filepath.Join(workDir, InfluxBufferFile)
data, err := os.ReadFile(bufPath)
require.NoError(t, err)
assert.NotEmpty(t, data)
// Parse the JSONL entry and verify fields
var entry bufferEntry
err = json.Unmarshal(data[:len(data)-1], &entry) // trim trailing newline
require.NoError(t, err)
assert.Equal(t, cp.Label, entry.Checkpoint.Label)
assert.Equal(t, cp.ModelTag, entry.Checkpoint.ModelTag)
assert.Equal(t, cp.RunID, entry.Checkpoint.RunID)
assert.Equal(t, results.Accuracy, entry.Results.Accuracy)
assert.Equal(t, results.Correct, entry.Results.Correct)
assert.Equal(t, results.Total, entry.Results.Total)
assert.NotEmpty(t, entry.Timestamp)
}
func TestBufferInfluxResult_MultipleEntries_Good(t *testing.T) {
workDir := t.TempDir()
for i := 0; i < 3; i++ {
cp := Checkpoint{
Dirname: "dir",
Iteration: i * 100,
Label: "L",
RunID: "run",
ModelTag: "tag",
}
results := ProbeResult{
Accuracy: float64(i) * 25.0,
Correct: i,
Total: 4,
Probes: map[string]SingleProbeResult{},
}
BufferInfluxResult(workDir, cp, results)
}
bufPath := filepath.Join(workDir, InfluxBufferFile)
data, err := os.ReadFile(bufPath)
require.NoError(t, err)
// Count newlines — should be 3 JSONL lines
lines := 0
for _, b := range data {
if b == '\n' {
lines++
}
}
assert.Equal(t, 3, lines)
}
func TestReplayInfluxBuffer_EmptyFile_Good(t *testing.T) {
workDir := t.TempDir()
// No buffer file exists — ReplayInfluxBuffer should be a no-op
ReplayInfluxBuffer(workDir, nil)
// Buffer file still shouldn't exist
_, err := os.Stat(filepath.Join(workDir, InfluxBufferFile))
assert.True(t, os.IsNotExist(err))
}
func TestReplayInfluxBuffer_MissingFile_Good(t *testing.T) {
// Calling with a nonexistent directory should not panic
ReplayInfluxBuffer("/nonexistent/path/that/does/not/exist", nil)
}
// =========================================================================
// 4. DiscoverCheckpoints tests (using fakeTransport)
// =========================================================================
func TestDiscoverCheckpoints_HappyPath_Good(t *testing.T) {
ft := newFakeTransport()
base := "/data/training"
// Command 1: list adapter directories (exact command from DiscoverCheckpoints)
ft.On("ls -d "+base+"/adapters-* 2>/dev/null",
base+"/adapters-27b\n"+base+"/adapters-1b\n", nil)
// Command 2a: sub-directory check for adapters-27b — no gemma-3-* subdirs
ft.On("ls -d "+base+"/adapters-27b/gemma-3-* 2>/dev/null", "", errors.New("no match"))
// Command 2b: sub-directory check for adapters-1b — no gemma-3-* subdirs
ft.On("ls -d "+base+"/adapters-1b/gemma-3-* 2>/dev/null", "", errors.New("no match"))
// Command 3a: list safetensors in adapters-27b
ft.On("ls "+base+"/adapters-27b/*_adapters.safetensors 2>/dev/null",
base+"/adapters-27b/0001000_adapters.safetensors\n"+base+"/adapters-27b/0002000_adapters.safetensors\n", nil)
// Command 3b: list safetensors in adapters-1b
ft.On("ls "+base+"/adapters-1b/*_adapters.safetensors 2>/dev/null",
base+"/adapters-1b/0000500_adapters.safetensors\n", nil)
cfg := &AgentConfig{
M3AdapterBase: base,
Transport: ft,
}
checkpoints, err := DiscoverCheckpoints(cfg)
require.NoError(t, err)
require.Len(t, checkpoints, 3)
// Verify parsed checkpoint details
found1000 := false
found2000 := false
found500 := false
for _, cp := range checkpoints {
switch {
case cp.Dirname == "adapters-27b" && cp.Iteration == 1000:
found1000 = true
assert.Equal(t, "gemma-3-27b", cp.ModelTag)
assert.Equal(t, "0001000_adapters.safetensors", cp.Filename)
assert.Contains(t, cp.Label, "@0001000")
assert.Contains(t, cp.RunID, "27b")
case cp.Dirname == "adapters-27b" && cp.Iteration == 2000:
found2000 = true
case cp.Dirname == "adapters-1b" && cp.Iteration == 500:
found500 = true
assert.Equal(t, "gemma-3-1b", cp.ModelTag)
}
}
assert.True(t, found1000, "should find iteration 1000")
assert.True(t, found2000, "should find iteration 2000")
assert.True(t, found500, "should find iteration 500")
}
func TestDiscoverCheckpoints_WithSubDirs_Good(t *testing.T) {
ft := newFakeTransport()
base := "/data/training"
// Command 1: list adapter directories
ft.On("ls -d "+base+"/adapters-* 2>/dev/null",
base+"/adapters-15k\n", nil)
// Command 2: sub-directory check finds gemma-3-* subdirs
ft.On("ls -d "+base+"/adapters-15k/gemma-3-* 2>/dev/null",
base+"/adapters-15k/gemma-3-27b\n"+base+"/adapters-15k/gemma-3-1b\n", nil)
// Command 3a: list safetensors in gemma-3-27b subdir
ft.On("ls "+base+"/adapters-15k/gemma-3-27b/*_adapters.safetensors 2>/dev/null",
base+"/adapters-15k/gemma-3-27b/0003000_adapters.safetensors\n", nil)
// Command 3b: list safetensors in gemma-3-1b subdir
ft.On("ls "+base+"/adapters-15k/gemma-3-1b/*_adapters.safetensors 2>/dev/null",
base+"/adapters-15k/gemma-3-1b/0001500_adapters.safetensors\n", nil)
cfg := &AgentConfig{
M3AdapterBase: base,
Transport: ft,
}
checkpoints, err := DiscoverCheckpoints(cfg)
require.NoError(t, err)
require.Len(t, checkpoints, 2)
// The dirname should include the subdirectory path relative to base
for _, cp := range checkpoints {
switch {
case cp.Iteration == 3000:
assert.Equal(t, "adapters-15k/gemma-3-27b", cp.Dirname)
assert.Equal(t, "gemma-3-27b", cp.ModelTag)
case cp.Iteration == 1500:
assert.Equal(t, "adapters-15k/gemma-3-1b", cp.Dirname)
assert.Equal(t, "gemma-3-1b", cp.ModelTag)
default:
t.Errorf("unexpected iteration %d", cp.Iteration)
}
}
}
func TestDiscoverCheckpoints_NoAdapters_Good(t *testing.T) {
ft := newFakeTransport()
base := "/data/training"
// ls -d returns empty output
ft.On("ls -d "+base+"/adapters-* 2>/dev/null", "", nil)
cfg := &AgentConfig{
M3AdapterBase: base,
Transport: ft,
}
checkpoints, err := DiscoverCheckpoints(cfg)
require.NoError(t, err)
assert.Empty(t, checkpoints)
}
func TestDiscoverCheckpoints_SSHError_Bad(t *testing.T) {
ft := newFakeTransport()
base := "/data/training"
ft.On("ls -d "+base+"/adapters-* 2>/dev/null", "", errors.New("ssh: connection refused"))
cfg := &AgentConfig{
M3AdapterBase: base,
Transport: ft,
}
_, err := DiscoverCheckpoints(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "list adapter dirs")
}
func TestDiscoverCheckpoints_FilterPattern_Good(t *testing.T) {
ft := newFakeTransport()
base := "/data/training"
// When Filter is set, the ls pattern changes to adapters-27b*
ft.On("ls -d "+base+"/adapters-27b* 2>/dev/null",
base+"/adapters-27b\n", nil)
// No gemma-3-* subdirs
ft.On("ls -d "+base+"/adapters-27b/gemma-3-* 2>/dev/null", "", errors.New("no match"))
ft.On("ls "+base+"/adapters-27b/*_adapters.safetensors 2>/dev/null",
base+"/adapters-27b/0001000_adapters.safetensors\n", nil)
cfg := &AgentConfig{
M3AdapterBase: base,
Transport: ft,
Filter: "27b",
}
checkpoints, err := DiscoverCheckpoints(cfg)
require.NoError(t, err)
require.Len(t, checkpoints, 1)
assert.Equal(t, 1000, checkpoints[0].Iteration)
}
func TestDiscoverCheckpoints_NoSafetensors_Good(t *testing.T) {
ft := newFakeTransport()
base := "/data/training"
ft.On("ls -d "+base+"/adapters-* 2>/dev/null",
base+"/adapters-27b\n", nil)
ft.On("ls -d "+base+"/adapters-27b/gemma-3-* 2>/dev/null", "", errors.New("no match"))
// safetensors listing fails (no checkpoint files yet)
ft.On("ls "+base+"/adapters-27b/*_adapters.safetensors 2>/dev/null", "", errors.New("no match"))
cfg := &AgentConfig{
M3AdapterBase: base,
Transport: ft,
}
checkpoints, err := DiscoverCheckpoints(cfg)
require.NoError(t, err)
assert.Empty(t, checkpoints, "no safetensors means no checkpoints")
}