Implements inference.Backend via a Python subprocess communicating over JSON Lines (stdin/stdout). No CGO required — pure Go + os/exec. - bridge.py: embedded Python script wrapping mlx_lm.load() and mlx_lm.stream_generate() with load/generate/chat/info/cancel/quit commands. Flushes stdout after every JSON line for streaming. - backend.go: Go subprocess manager. Extracts bridge.py from go:embed to temp file, spawns python3, pipes JSON requests. mlxlmModel implements full TextModel interface with mutex- serialised Generate/Chat, context cancellation with drain, and 2-second graceful Close with kill fallback. Auto-registers as "mlx_lm" via init(). Build tag: !nomlxlm. - backend_test.go: 15 tests using mock_bridge.py (no mlx_lm needed): name, load, generate, cancel, chat, close, error propagation, invalid path, auto-register, concurrent serialisation, classify/ batch unsupported, info, metrics, max_tokens limiting. All tests pass with -race. go vet clean. Co-Authored-By: Virgil <virgil@lethean.io> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
284 lines
7.5 KiB
Go
284 lines
7.5 KiB
Go
// SPDX-Licence-Identifier: EUPL-1.2
|
|
|
|
//go:build !nomlxlm
|
|
|
|
package mlxlm
|
|
|
|
import (
|
|
"context"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
|
|
"forge.lthn.ai/core/go-inference"
|
|
)
|
|
|
|
// mockScript returns the absolute path to testdata/mock_bridge.py.
|
|
func mockScript(t *testing.T) string {
|
|
t.Helper()
|
|
_, file, _, ok := runtime.Caller(0)
|
|
if !ok {
|
|
t.Fatal("cannot determine test file path")
|
|
}
|
|
return filepath.Join(filepath.Dir(file), "testdata", "mock_bridge.py")
|
|
}
|
|
|
|
// loadMock spawns a model backed by the mock Python script.
|
|
func loadMock(t *testing.T, modelPath string) inference.TextModel {
|
|
t.Helper()
|
|
m, err := loadModel(context.Background(), modelPath, mockScript(t))
|
|
if err != nil {
|
|
t.Fatalf("loadModel: %v", err)
|
|
}
|
|
t.Cleanup(func() { m.Close() })
|
|
return m
|
|
}
|
|
|
|
// (a) Name returns "mlx_lm".
|
|
func TestBackendName(t *testing.T) {
|
|
b := &mlxlmBackend{}
|
|
if got := b.Name(); got != "mlx_lm" {
|
|
t.Errorf("Name() = %q, want %q", got, "mlx_lm")
|
|
}
|
|
}
|
|
|
|
// (b) LoadModel spawns subprocess, sends load command, gets response.
|
|
func TestLoadModel_Good(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
if m.ModelType() != "mock_model" {
|
|
t.Errorf("ModelType() = %q, want %q", m.ModelType(), "mock_model")
|
|
}
|
|
}
|
|
|
|
// (c) Generate streams tokens from subprocess, all tokens received.
|
|
func TestGenerate_Good(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
|
|
ctx := context.Background()
|
|
var tokens []inference.Token
|
|
for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
|
|
tokens = append(tokens, tok)
|
|
}
|
|
if err := m.Err(); err != nil {
|
|
t.Fatalf("Err() = %v", err)
|
|
}
|
|
if len(tokens) != 5 {
|
|
t.Fatalf("got %d tokens, want 5", len(tokens))
|
|
}
|
|
|
|
// Verify token content matches mock.
|
|
expected := []string{"Hello", " ", "world", "!", "\n"}
|
|
for i, tok := range tokens {
|
|
if tok.Text != expected[i] {
|
|
t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i])
|
|
}
|
|
wantID := int32(100 + i)
|
|
if tok.ID != wantID {
|
|
t.Errorf("token[%d].ID = %d, want %d", i, tok.ID, wantID)
|
|
}
|
|
}
|
|
}
|
|
|
|
// (d) Generate with context cancellation stops early.
|
|
func TestGenerate_Cancel(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
var count int
|
|
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
|
|
count++
|
|
if count >= 2 {
|
|
cancel()
|
|
}
|
|
}
|
|
// We should have received at most a few tokens before cancellation took effect.
|
|
if count > 5 {
|
|
t.Errorf("expected early stop, got %d tokens", count)
|
|
}
|
|
if err := m.Err(); err != context.Canceled {
|
|
t.Logf("Err() = %v (expected context.Canceled)", err)
|
|
}
|
|
}
|
|
|
|
// (e) Chat formats messages correctly and streams tokens.
|
|
func TestChat_Good(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
|
|
ctx := context.Background()
|
|
var tokens []inference.Token
|
|
for tok := range m.Chat(ctx, []inference.Message{
|
|
{Role: "user", Content: "Hi there"},
|
|
}, inference.WithMaxTokens(5)) {
|
|
tokens = append(tokens, tok)
|
|
}
|
|
if err := m.Err(); err != nil {
|
|
t.Fatalf("Err() = %v", err)
|
|
}
|
|
if len(tokens) != 5 {
|
|
t.Fatalf("got %d tokens, want 5", len(tokens))
|
|
}
|
|
|
|
// Mock chat returns "I heard you".
|
|
expected := []string{"I", " ", "heard", " ", "you"}
|
|
for i, tok := range tokens {
|
|
if tok.Text != expected[i] {
|
|
t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
// (f) Close kills subprocess cleanly.
|
|
func TestClose_Good(t *testing.T) {
|
|
m, err := loadModel(context.Background(), "/fake/model/path", mockScript(t))
|
|
if err != nil {
|
|
t.Fatalf("loadModel: %v", err)
|
|
}
|
|
// Close should not error.
|
|
if err := m.Close(); err != nil {
|
|
t.Errorf("Close() = %v", err)
|
|
}
|
|
}
|
|
|
|
// (g) Err returns error on subprocess failure.
|
|
func TestGenerate_Error(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
|
|
ctx := context.Background()
|
|
var count int
|
|
for range m.Generate(ctx, "ERROR trigger", inference.WithMaxTokens(5)) {
|
|
count++
|
|
}
|
|
if count != 0 {
|
|
t.Errorf("expected 0 tokens on error, got %d", count)
|
|
}
|
|
if err := m.Err(); err == nil {
|
|
t.Fatal("expected non-nil Err()")
|
|
} else if !strings.Contains(err.Error(), "simulated model error") {
|
|
t.Errorf("Err() = %q, want to contain %q", err.Error(), "simulated model error")
|
|
}
|
|
}
|
|
|
|
// (h) LoadModel with invalid path returns error.
|
|
func TestLoadModel_Bad(t *testing.T) {
|
|
_, err := loadModel(context.Background(), "/path/with/FAIL/in/it", mockScript(t))
|
|
if err == nil {
|
|
t.Fatal("expected error for FAIL path")
|
|
}
|
|
if !strings.Contains(err.Error(), "cannot open model") {
|
|
t.Errorf("error = %q, want to contain %q", err.Error(), "cannot open model")
|
|
}
|
|
}
|
|
|
|
// (i) Backend auto-registers (check inference.Get("mlx_lm")).
|
|
func TestAutoRegister(t *testing.T) {
|
|
b, ok := inference.Get("mlx_lm")
|
|
if !ok {
|
|
t.Fatal("mlx_lm backend not registered")
|
|
}
|
|
if b.Name() != "mlx_lm" {
|
|
t.Errorf("Name() = %q, want %q", b.Name(), "mlx_lm")
|
|
}
|
|
}
|
|
|
|
// (j) Concurrent Generate calls are serialised (mu lock).
|
|
func TestGenerate_Concurrent(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
|
|
ctx := context.Background()
|
|
const goroutines = 3
|
|
var wg sync.WaitGroup
|
|
wg.Add(goroutines)
|
|
|
|
results := make([]int, goroutines)
|
|
for i := range goroutines {
|
|
go func(idx int) {
|
|
defer wg.Done()
|
|
var count int
|
|
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
|
|
count++
|
|
}
|
|
results[idx] = count
|
|
}(i)
|
|
}
|
|
wg.Wait()
|
|
|
|
// Each goroutine should have received all 5 tokens (serialised execution).
|
|
for i, count := range results {
|
|
if count != 5 {
|
|
t.Errorf("goroutine %d got %d tokens, want 5", i, count)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Additional: Classify returns unsupported error.
|
|
func TestClassify_Unsupported(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
_, err := m.Classify(context.Background(), []string{"test"})
|
|
if err == nil {
|
|
t.Fatal("expected error from Classify")
|
|
}
|
|
if !strings.Contains(err.Error(), "not supported") {
|
|
t.Errorf("error = %q, want to contain %q", err.Error(), "not supported")
|
|
}
|
|
}
|
|
|
|
// Additional: BatchGenerate returns unsupported error.
|
|
func TestBatchGenerate_Unsupported(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
_, err := m.BatchGenerate(context.Background(), []string{"test"})
|
|
if err == nil {
|
|
t.Fatal("expected error from BatchGenerate")
|
|
}
|
|
if !strings.Contains(err.Error(), "not supported") {
|
|
t.Errorf("error = %q, want to contain %q", err.Error(), "not supported")
|
|
}
|
|
}
|
|
|
|
// Additional: Info returns model metadata.
|
|
func TestInfo_Good(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
info := m.Info()
|
|
if info.Architecture != "mock_model" {
|
|
t.Errorf("Architecture = %q, want %q", info.Architecture, "mock_model")
|
|
}
|
|
if info.VocabSize != 32000 {
|
|
t.Errorf("VocabSize = %d, want %d", info.VocabSize, 32000)
|
|
}
|
|
if info.NumLayers != 24 {
|
|
t.Errorf("NumLayers = %d, want %d", info.NumLayers, 24)
|
|
}
|
|
if info.HiddenSize != 2048 {
|
|
t.Errorf("HiddenSize = %d, want %d", info.HiddenSize, 2048)
|
|
}
|
|
}
|
|
|
|
// Additional: Metrics returns zero values (not tracked by subprocess).
|
|
func TestMetrics_Zero(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
met := m.Metrics()
|
|
if met.PromptTokens != 0 || met.GeneratedTokens != 0 {
|
|
t.Errorf("expected zero metrics, got prompt=%d generated=%d",
|
|
met.PromptTokens, met.GeneratedTokens)
|
|
}
|
|
}
|
|
|
|
// Additional: Generate with fewer max_tokens than available tokens.
|
|
func TestGenerate_MaxTokens(t *testing.T) {
|
|
m := loadMock(t, "/fake/model/path")
|
|
|
|
ctx := context.Background()
|
|
var count int
|
|
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(3)) {
|
|
count++
|
|
}
|
|
if err := m.Err(); err != nil {
|
|
t.Fatalf("Err() = %v", err)
|
|
}
|
|
if count != 3 {
|
|
t.Errorf("got %d tokens, want 3", count)
|
|
}
|
|
}
|