From 757a241f593f1f00d7b1693bc0a97a54c87787f6 Mon Sep 17 00:00:00 2001 From: Snider Date: Fri, 20 Feb 2026 09:02:30 +0000 Subject: [PATCH] =?UTF-8?q?feat(mlxlm):=20Phase=205.5=20=E2=80=94=20subpro?= =?UTF-8?q?cess=20backend=20using=20Python=20mlx-lm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements inference.Backend via a Python subprocess communicating over JSON Lines (stdin/stdout). No CGO required — pure Go + os/exec. - bridge.py: embedded Python script wrapping mlx_lm.load() and mlx_lm.stream_generate() with load/generate/chat/info/cancel/quit commands. Flushes stdout after every JSON line for streaming. - backend.go: Go subprocess manager. Extracts bridge.py from go:embed to temp file, spawns python3, pipes JSON requests. mlxlmModel implements full TextModel interface with mutex- serialised Generate/Chat, context cancellation with drain, and 2-second graceful Close with kill fallback. Auto-registers as "mlx_lm" via init(). Build tag: !nomlxlm. - backend_test.go: 15 tests using mock_bridge.py (no mlx_lm needed): name, load, generate, cancel, chat, close, error propagation, invalid path, auto-register, concurrent serialisation, classify/ batch unsupported, info, metrics, max_tokens limiting. All tests pass with -race. go vet clean. Co-Authored-By: Virgil Co-Authored-By: Claude Opus 4.6 --- TODO.md | 8 +- mlxlm/backend.go | 459 ++++++++++++++++++++++++++++++++++ mlxlm/backend_test.go | 284 +++++++++++++++++++++ mlxlm/bridge.py | 224 +++++++++++++++++ mlxlm/testdata/mock_bridge.py | 118 +++++++++ 5 files changed, 1089 insertions(+), 4 deletions(-) create mode 100644 mlxlm/backend.go create mode 100644 mlxlm/backend_test.go create mode 100644 mlxlm/bridge.py create mode 100644 mlxlm/testdata/mock_bridge.py diff --git a/TODO.md b/TODO.md index 4f005ff..92d2777 100644 --- a/TODO.md +++ b/TODO.md @@ -55,7 +55,7 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no #### 5.5.1 Python Helper Script (`mlxlm/bridge.py`) -- [ ] **Create `mlxlm/bridge.py`** — Thin Python script embedded in Go binary via `//go:embed`: +- [x] **Create `mlxlm/bridge.py`** — Thin Python script embedded in Go binary via `//go:embed`: - Reads JSON Lines from stdin, writes JSON Lines to stdout - **`load` command**: `{"cmd":"load","path":"/path/to/model","max_tokens":128}` → loads model + tokenizer, responds `{"ok":true,"model_type":"gemma3","vocab_size":262144}` - **`generate` command**: `{"cmd":"generate","prompt":"Hello","max_tokens":128,"temperature":0.7,"top_k":40,"top_p":0.9}` → streams `{"token":"Hello","token_id":123}` per token, then `{"done":true,"tokens_generated":42}` @@ -68,7 +68,7 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no #### 5.5.2 Go Backend (`mlxlm/backend.go`) -- [ ] **Create `mlxlm/backend.go`** — Go subprocess wrapper: +- [x] **Create `mlxlm/backend.go`** — Go subprocess wrapper: - `type mlxlmBackend struct{}` — implements `mlx.Backend` - `func init()` — auto-registers via `mlx.Register(&mlxlmBackend{})`. Use build tag `//go:build !nomxlm` so consumers can opt out - `func (b *mlxlmBackend) Name() string` — returns `"mlx_lm"` @@ -95,13 +95,13 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no #### 5.5.3 Tests (`mlxlm/backend_test.go`) -- [ ] **Create mock Python script for testing** — A minimal Python script (in `testdata/mock_bridge.py` or inline) that: +- [x] **Create mock Python script for testing** — A minimal Python script (in `testdata/mock_bridge.py` or inline) that: - Responds to `load` with fake model info - Responds to `generate` with 5 fixed tokens then `done` - Responds to `info` with fake metadata - Responds to `chat` same as generate - Tests don't need `mlx_lm` installed — pure mock -- [ ] **Tests** using mock script: +- [x] **Tests** using mock script: - (a) `Name()` returns `"mlx_lm"` - (b) `LoadModel` spawns subprocess, sends load command, gets response - (c) `Generate` streams tokens from subprocess, all tokens received diff --git a/mlxlm/backend.go b/mlxlm/backend.go new file mode 100644 index 0000000..a8de7ed --- /dev/null +++ b/mlxlm/backend.go @@ -0,0 +1,459 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +//go:build !nomlxlm + +// Package mlxlm provides a subprocess-based inference backend using Python's mlx-lm. +// +// It implements the [inference.Backend] interface by spawning a Python process +// that communicates over JSON Lines (one JSON object per line on stdin/stdout). +// This allows using mlx-lm models without CGO or native Metal bindings. +// +// The backend auto-registers as "mlx_lm" via init(). Consumers can opt out +// with the build tag "nomlxlm". +// +// # Usage +// +// import _ "forge.lthn.ai/core/go-mlx/mlxlm" +// +// m, err := inference.LoadModel("/path/to/model", inference.WithBackend("mlx_lm")) +// defer m.Close() +// +// for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(64)) { +// fmt.Print(tok.Text) +// } +package mlxlm + +import ( + "bufio" + "context" + "embed" + "encoding/json" + "errors" + "fmt" + "io" + "iter" + "os" + "os/exec" + "path/filepath" + "sync" + "time" + + "forge.lthn.ai/core/go-inference" +) + +//go:embed bridge.py +var bridgeFS embed.FS + +// scriptPath holds the extracted bridge.py temp path. Created once per process. +var ( + scriptOnce sync.Once + scriptPath string + scriptErr error +) + +// extractScript writes the embedded bridge.py to a temp file and returns its path. +func extractScript() (string, error) { + scriptOnce.Do(func() { + data, err := bridgeFS.ReadFile("bridge.py") + if err != nil { + scriptErr = fmt.Errorf("mlxlm: read embedded bridge.py: %w", err) + return + } + dir, err := os.MkdirTemp("", "mlxlm-*") + if err != nil { + scriptErr = fmt.Errorf("mlxlm: create temp dir: %w", err) + return + } + p := filepath.Join(dir, "bridge.py") + if err := os.WriteFile(p, data, 0o644); err != nil { + scriptErr = fmt.Errorf("mlxlm: write bridge.py: %w", err) + return + } + scriptPath = p + }) + return scriptPath, scriptErr +} + +func init() { + inference.Register(&mlxlmBackend{}) +} + +// mlxlmBackend implements inference.Backend for mlx-lm via subprocess. +type mlxlmBackend struct{} + +func (b *mlxlmBackend) Name() string { return "mlx_lm" } + +// Available reports whether python3 is found on PATH. +func (b *mlxlmBackend) Available() bool { + _, err := exec.LookPath("python3") + return err == nil +} + +// LoadModel spawns a Python subprocess running bridge.py, loads the model, +// and returns a TextModel backed by the subprocess. +func (b *mlxlmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) { + return loadModel(context.Background(), path, "", opts...) +} + +// loadModel is the internal implementation, accepting an optional scriptOverride +// for testing (uses the mock bridge script instead of the embedded one). +func loadModel(ctx context.Context, modelPath, scriptOverride string, opts ...inference.LoadOption) (inference.TextModel, error) { + cfg := inference.ApplyLoadOpts(opts) + _ = cfg // reserved for future use (context length, etc.) + + var pyScript string + if scriptOverride != "" { + pyScript = scriptOverride + } else { + var err error + pyScript, err = extractScript() + if err != nil { + return nil, err + } + } + + cmd := exec.CommandContext(ctx, "python3", "-u", pyScript) + cmd.Stderr = nil // let stderr go to parent for debugging + + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return nil, fmt.Errorf("mlxlm: stdin pipe: %w", err) + } + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return nil, fmt.Errorf("mlxlm: stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("mlxlm: start python3: %w", err) + } + + scanner := bufio.NewScanner(stdoutPipe) + scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB line buffer + + m := &mlxlmModel{ + cmd: cmd, + stdin: stdinPipe, + stdout: scanner, + raw: stdoutPipe, + } + + // Send load command. + loadReq := map[string]any{ + "cmd": "load", + "path": modelPath, + } + if err := m.send(loadReq); err != nil { + m.kill() + return nil, fmt.Errorf("mlxlm: send load: %w", err) + } + + resp, err := m.recv() + if err != nil { + m.kill() + return nil, fmt.Errorf("mlxlm: recv load response: %w", err) + } + + if errMsg, ok := resp["error"].(string); ok { + m.kill() + return nil, fmt.Errorf("mlxlm: %s", errMsg) + } + + if modelType, ok := resp["model_type"].(string); ok { + m.modelType = modelType + } + if vocabSize, ok := resp["vocab_size"].(float64); ok { + m.vocabSize = int(vocabSize) + } + + return m, nil +} + +// mlxlmModel is a subprocess-backed TextModel. +type mlxlmModel struct { + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Scanner + raw io.ReadCloser + + modelType string + vocabSize int + + lastErr error + mu sync.Mutex // serialise Generate/Chat calls +} + +// send writes a JSON object as a single line to the subprocess stdin. +func (m *mlxlmModel) send(obj map[string]any) error { + data, err := json.Marshal(obj) + if err != nil { + return err + } + data = append(data, '\n') + _, err = m.stdin.Write(data) + return err +} + +// recv reads and parses a single JSON line from the subprocess stdout. +func (m *mlxlmModel) recv() (map[string]any, error) { + if !m.stdout.Scan() { + if err := m.stdout.Err(); err != nil { + return nil, fmt.Errorf("mlxlm: scanner: %w", err) + } + return nil, errors.New("mlxlm: subprocess closed stdout") + } + var obj map[string]any + if err := json.Unmarshal(m.stdout.Bytes(), &obj); err != nil { + return nil, fmt.Errorf("mlxlm: parse response: %w", err) + } + return obj, nil +} + +// Generate streams tokens from the subprocess for the given prompt. +// Only one Generate or Chat call runs at a time per model (mu lock). +func (m *mlxlmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] { + cfg := inference.ApplyGenerateOpts(opts) + + return func(yield func(inference.Token) bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.lastErr = nil + + req := map[string]any{ + "cmd": "generate", + "prompt": prompt, + "max_tokens": cfg.MaxTokens, + } + if cfg.Temperature > 0 { + req["temperature"] = cfg.Temperature + } + if cfg.TopK > 0 { + req["top_k"] = cfg.TopK + } + if cfg.TopP > 0 { + req["top_p"] = cfg.TopP + } + + if err := m.send(req); err != nil { + m.lastErr = fmt.Errorf("mlxlm: send generate: %w", err) + return + } + + for { + select { + case <-ctx.Done(): + m.lastErr = ctx.Err() + // Tell subprocess to stop. + _ = m.send(map[string]any{"cmd": "cancel"}) + // Drain until done or error. + m.drain() + return + default: + } + + resp, err := m.recv() + if err != nil { + m.lastErr = err + return + } + + if errMsg, ok := resp["error"].(string); ok { + m.lastErr = errors.New(errMsg) + return + } + + if _, ok := resp["done"]; ok { + return + } + + text, _ := resp["token"].(string) + var id int32 + if fid, ok := resp["token_id"].(float64); ok { + id = int32(fid) + } + + if !yield(inference.Token{ID: id, Text: text}) { + // Consumer stopped early — send cancel and drain. + _ = m.send(map[string]any{"cmd": "cancel"}) + m.drain() + return + } + } + } +} + +// Chat streams tokens from a multi-turn conversation. +func (m *mlxlmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] { + cfg := inference.ApplyGenerateOpts(opts) + + return func(yield func(inference.Token) bool) { + m.mu.Lock() + defer m.mu.Unlock() + m.lastErr = nil + + // Convert messages to JSON-safe format. + msgs := make([]map[string]string, len(messages)) + for i, msg := range messages { + msgs[i] = map[string]string{ + "role": msg.Role, + "content": msg.Content, + } + } + + req := map[string]any{ + "cmd": "chat", + "messages": msgs, + "max_tokens": cfg.MaxTokens, + } + if cfg.Temperature > 0 { + req["temperature"] = cfg.Temperature + } + if cfg.TopK > 0 { + req["top_k"] = cfg.TopK + } + if cfg.TopP > 0 { + req["top_p"] = cfg.TopP + } + + if err := m.send(req); err != nil { + m.lastErr = fmt.Errorf("mlxlm: send chat: %w", err) + return + } + + for { + select { + case <-ctx.Done(): + m.lastErr = ctx.Err() + _ = m.send(map[string]any{"cmd": "cancel"}) + m.drain() + return + default: + } + + resp, err := m.recv() + if err != nil { + m.lastErr = err + return + } + + if errMsg, ok := resp["error"].(string); ok { + m.lastErr = errors.New(errMsg) + return + } + + if _, ok := resp["done"]; ok { + return + } + + text, _ := resp["token"].(string) + var id int32 + if fid, ok := resp["token_id"].(float64); ok { + id = int32(fid) + } + + if !yield(inference.Token{ID: id, Text: text}) { + _ = m.send(map[string]any{"cmd": "cancel"}) + m.drain() + return + } + } + } +} + +// Classify is not supported by the subprocess backend. +// Returns an error indicating that classification requires the native Metal backend. +func (m *mlxlmModel) Classify(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + return nil, errors.New("mlxlm: Classify not supported (use native Metal backend)") +} + +// BatchGenerate is not supported by the subprocess backend. +// Returns an error indicating that batch generation requires the native Metal backend. +func (m *mlxlmModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) { + return nil, errors.New("mlxlm: BatchGenerate not supported (use native Metal backend)") +} + +// ModelType returns the architecture identifier reported by the subprocess. +func (m *mlxlmModel) ModelType() string { return m.modelType } + +// Info returns metadata about the loaded model. +func (m *mlxlmModel) Info() inference.ModelInfo { + m.mu.Lock() + defer m.mu.Unlock() + + if err := m.send(map[string]any{"cmd": "info"}); err != nil { + return inference.ModelInfo{} + } + resp, err := m.recv() + if err != nil { + return inference.ModelInfo{} + } + if _, ok := resp["error"]; ok { + return inference.ModelInfo{} + } + + info := inference.ModelInfo{ + Architecture: m.modelType, + VocabSize: m.vocabSize, + } + if layers, ok := resp["layers"].(float64); ok { + info.NumLayers = int(layers) + } + if hidden, ok := resp["hidden_size"].(float64); ok { + info.HiddenSize = int(hidden) + } + return info +} + +// Metrics returns empty metrics — the subprocess backend does not track timing. +func (m *mlxlmModel) Metrics() inference.GenerateMetrics { + return inference.GenerateMetrics{} +} + +// Err returns the error from the last Generate or Chat call. +func (m *mlxlmModel) Err() error { return m.lastErr } + +// Close sends a quit command and waits for the subprocess to exit. +// If the subprocess does not exit within 2 seconds, it is killed. +func (m *mlxlmModel) Close() error { + // Send quit — ignore errors (subprocess may already be dead). + _ = m.send(map[string]any{"cmd": "quit"}) + _ = m.stdin.Close() + + // Wait with timeout. + done := make(chan error, 1) + go func() { done <- m.cmd.Wait() }() + + select { + case err := <-done: + return err + case <-time.After(2 * time.Second): + _ = m.cmd.Process.Kill() + return <-done + } +} + +// drain reads and discards subprocess output until a "done" or "error" line, +// or the scanner stops. Used after cancellation to keep the protocol in sync. +func (m *mlxlmModel) drain() { + for { + resp, err := m.recv() + if err != nil { + return + } + if _, ok := resp["done"]; ok { + return + } + if _, ok := resp["error"]; ok { + return + } + } +} + +// kill terminates the subprocess immediately. Used during load failures. +func (m *mlxlmModel) kill() { + _ = m.stdin.Close() + if m.cmd.Process != nil { + _ = m.cmd.Process.Kill() + } + _ = m.cmd.Wait() +} diff --git a/mlxlm/backend_test.go b/mlxlm/backend_test.go new file mode 100644 index 0000000..2b915dd --- /dev/null +++ b/mlxlm/backend_test.go @@ -0,0 +1,284 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +//go:build !nomlxlm + +package mlxlm + +import ( + "context" + "path/filepath" + "runtime" + "strings" + "sync" + "testing" + + "forge.lthn.ai/core/go-inference" +) + +// mockScript returns the absolute path to testdata/mock_bridge.py. +func mockScript(t *testing.T) string { + t.Helper() + _, file, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("cannot determine test file path") + } + return filepath.Join(filepath.Dir(file), "testdata", "mock_bridge.py") +} + +// loadMock spawns a model backed by the mock Python script. +func loadMock(t *testing.T, modelPath string) inference.TextModel { + t.Helper() + m, err := loadModel(context.Background(), modelPath, mockScript(t)) + if err != nil { + t.Fatalf("loadModel: %v", err) + } + t.Cleanup(func() { m.Close() }) + return m +} + +// (a) Name returns "mlx_lm". +func TestBackendName(t *testing.T) { + b := &mlxlmBackend{} + if got := b.Name(); got != "mlx_lm" { + t.Errorf("Name() = %q, want %q", got, "mlx_lm") + } +} + +// (b) LoadModel spawns subprocess, sends load command, gets response. +func TestLoadModel_Good(t *testing.T) { + m := loadMock(t, "/fake/model/path") + if m.ModelType() != "mock_model" { + t.Errorf("ModelType() = %q, want %q", m.ModelType(), "mock_model") + } +} + +// (c) Generate streams tokens from subprocess, all tokens received. +func TestGenerate_Good(t *testing.T) { + m := loadMock(t, "/fake/model/path") + + ctx := context.Background() + var tokens []inference.Token + for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) { + tokens = append(tokens, tok) + } + if err := m.Err(); err != nil { + t.Fatalf("Err() = %v", err) + } + if len(tokens) != 5 { + t.Fatalf("got %d tokens, want 5", len(tokens)) + } + + // Verify token content matches mock. + expected := []string{"Hello", " ", "world", "!", "\n"} + for i, tok := range tokens { + if tok.Text != expected[i] { + t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i]) + } + wantID := int32(100 + i) + if tok.ID != wantID { + t.Errorf("token[%d].ID = %d, want %d", i, tok.ID, wantID) + } + } +} + +// (d) Generate with context cancellation stops early. +func TestGenerate_Cancel(t *testing.T) { + m := loadMock(t, "/fake/model/path") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var count int + for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) { + count++ + if count >= 2 { + cancel() + } + } + // We should have received at most a few tokens before cancellation took effect. + if count > 5 { + t.Errorf("expected early stop, got %d tokens", count) + } + if err := m.Err(); err != context.Canceled { + t.Logf("Err() = %v (expected context.Canceled)", err) + } +} + +// (e) Chat formats messages correctly and streams tokens. +func TestChat_Good(t *testing.T) { + m := loadMock(t, "/fake/model/path") + + ctx := context.Background() + var tokens []inference.Token + for tok := range m.Chat(ctx, []inference.Message{ + {Role: "user", Content: "Hi there"}, + }, inference.WithMaxTokens(5)) { + tokens = append(tokens, tok) + } + if err := m.Err(); err != nil { + t.Fatalf("Err() = %v", err) + } + if len(tokens) != 5 { + t.Fatalf("got %d tokens, want 5", len(tokens)) + } + + // Mock chat returns "I heard you". + expected := []string{"I", " ", "heard", " ", "you"} + for i, tok := range tokens { + if tok.Text != expected[i] { + t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i]) + } + } +} + +// (f) Close kills subprocess cleanly. +func TestClose_Good(t *testing.T) { + m, err := loadModel(context.Background(), "/fake/model/path", mockScript(t)) + if err != nil { + t.Fatalf("loadModel: %v", err) + } + // Close should not error. + if err := m.Close(); err != nil { + t.Errorf("Close() = %v", err) + } +} + +// (g) Err returns error on subprocess failure. +func TestGenerate_Error(t *testing.T) { + m := loadMock(t, "/fake/model/path") + + ctx := context.Background() + var count int + for range m.Generate(ctx, "ERROR trigger", inference.WithMaxTokens(5)) { + count++ + } + if count != 0 { + t.Errorf("expected 0 tokens on error, got %d", count) + } + if err := m.Err(); err == nil { + t.Fatal("expected non-nil Err()") + } else if !strings.Contains(err.Error(), "simulated model error") { + t.Errorf("Err() = %q, want to contain %q", err.Error(), "simulated model error") + } +} + +// (h) LoadModel with invalid path returns error. +func TestLoadModel_Bad(t *testing.T) { + _, err := loadModel(context.Background(), "/path/with/FAIL/in/it", mockScript(t)) + if err == nil { + t.Fatal("expected error for FAIL path") + } + if !strings.Contains(err.Error(), "cannot open model") { + t.Errorf("error = %q, want to contain %q", err.Error(), "cannot open model") + } +} + +// (i) Backend auto-registers (check inference.Get("mlx_lm")). +func TestAutoRegister(t *testing.T) { + b, ok := inference.Get("mlx_lm") + if !ok { + t.Fatal("mlx_lm backend not registered") + } + if b.Name() != "mlx_lm" { + t.Errorf("Name() = %q, want %q", b.Name(), "mlx_lm") + } +} + +// (j) Concurrent Generate calls are serialised (mu lock). +func TestGenerate_Concurrent(t *testing.T) { + m := loadMock(t, "/fake/model/path") + + ctx := context.Background() + const goroutines = 3 + var wg sync.WaitGroup + wg.Add(goroutines) + + results := make([]int, goroutines) + for i := range goroutines { + go func(idx int) { + defer wg.Done() + var count int + for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) { + count++ + } + results[idx] = count + }(i) + } + wg.Wait() + + // Each goroutine should have received all 5 tokens (serialised execution). + for i, count := range results { + if count != 5 { + t.Errorf("goroutine %d got %d tokens, want 5", i, count) + } + } +} + +// Additional: Classify returns unsupported error. +func TestClassify_Unsupported(t *testing.T) { + m := loadMock(t, "/fake/model/path") + _, err := m.Classify(context.Background(), []string{"test"}) + if err == nil { + t.Fatal("expected error from Classify") + } + if !strings.Contains(err.Error(), "not supported") { + t.Errorf("error = %q, want to contain %q", err.Error(), "not supported") + } +} + +// Additional: BatchGenerate returns unsupported error. +func TestBatchGenerate_Unsupported(t *testing.T) { + m := loadMock(t, "/fake/model/path") + _, err := m.BatchGenerate(context.Background(), []string{"test"}) + if err == nil { + t.Fatal("expected error from BatchGenerate") + } + if !strings.Contains(err.Error(), "not supported") { + t.Errorf("error = %q, want to contain %q", err.Error(), "not supported") + } +} + +// Additional: Info returns model metadata. +func TestInfo_Good(t *testing.T) { + m := loadMock(t, "/fake/model/path") + info := m.Info() + if info.Architecture != "mock_model" { + t.Errorf("Architecture = %q, want %q", info.Architecture, "mock_model") + } + if info.VocabSize != 32000 { + t.Errorf("VocabSize = %d, want %d", info.VocabSize, 32000) + } + if info.NumLayers != 24 { + t.Errorf("NumLayers = %d, want %d", info.NumLayers, 24) + } + if info.HiddenSize != 2048 { + t.Errorf("HiddenSize = %d, want %d", info.HiddenSize, 2048) + } +} + +// Additional: Metrics returns zero values (not tracked by subprocess). +func TestMetrics_Zero(t *testing.T) { + m := loadMock(t, "/fake/model/path") + met := m.Metrics() + if met.PromptTokens != 0 || met.GeneratedTokens != 0 { + t.Errorf("expected zero metrics, got prompt=%d generated=%d", + met.PromptTokens, met.GeneratedTokens) + } +} + +// Additional: Generate with fewer max_tokens than available tokens. +func TestGenerate_MaxTokens(t *testing.T) { + m := loadMock(t, "/fake/model/path") + + ctx := context.Background() + var count int + for range m.Generate(ctx, "Hello", inference.WithMaxTokens(3)) { + count++ + } + if err := m.Err(); err != nil { + t.Fatalf("Err() = %v", err) + } + if count != 3 { + t.Errorf("got %d tokens, want 3", count) + } +} diff --git a/mlxlm/bridge.py b/mlxlm/bridge.py new file mode 100644 index 0000000..cbb601d --- /dev/null +++ b/mlxlm/bridge.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 +""" +bridge.py — JSON Lines bridge between Go subprocess and mlx_lm. + +Reads JSON commands from stdin, writes JSON responses to stdout. +Each line is one JSON object. Flushes after every write (critical for streaming). + +Commands: + load — Load model + tokeniser from path + generate — Stream tokens for a prompt + chat — Stream tokens for a multi-turn conversation + info — Return model metadata + cancel — Interrupt current generation (no-op outside generation) + quit — Exit cleanly + +Requires: mlx-lm (pip install mlx-lm) + +SPDX-Licence-Identifier: EUPL-1.2 +""" + +import json +import sys + +_model = None +_tokeniser = None +_model_type = None +_vocab_size = 0 +_cancelled = False + + +def _write(obj): + """Write a JSON line to stdout and flush.""" + sys.stdout.write(json.dumps(obj) + "\n") + sys.stdout.flush() + + +def _error(msg): + """Write an error response.""" + _write({"error": str(msg)}) + + +def handle_load(req): + global _model, _tokeniser, _model_type, _vocab_size + + path = req.get("path", "") + if not path: + _error("load: missing 'path'") + return + + try: + import mlx_lm + _model, _tokeniser = mlx_lm.load(path) + except Exception as e: + _error(f"load: {e}") + return + + # Detect model type from config if available. + _model_type = getattr(_model, "model_type", "unknown") + _vocab_size = getattr(_tokeniser, "vocab_size", 0) + + _write({ + "ok": True, + "model_type": _model_type, + "vocab_size": _vocab_size, + }) + + +def handle_generate(req): + global _cancelled + + if _model is None or _tokeniser is None: + _error("generate: no model loaded") + return + + prompt = req.get("prompt", "") + max_tokens = req.get("max_tokens", 256) + temperature = req.get("temperature", 0.0) + top_p = req.get("top_p", 1.0) + + _cancelled = False + + try: + import mlx_lm + + kwargs = { + "max_tokens": max_tokens, + } + if temperature > 0: + kwargs["temp"] = temperature + if top_p < 1.0: + kwargs["top_p"] = top_p + + count = 0 + for response in mlx_lm.stream_generate( + _model, _tokeniser, prompt=prompt, **kwargs + ): + if _cancelled: + break + text = response.text if hasattr(response, "text") else str(response) + token_id = response.token if hasattr(response, "token") else 0 + _write({"token": text, "token_id": int(token_id)}) + count += 1 + + _write({"done": True, "tokens_generated": count}) + + except Exception as e: + _error(f"generate: {e}") + + +def handle_chat(req): + global _cancelled + + if _model is None or _tokeniser is None: + _error("chat: no model loaded") + return + + messages = req.get("messages", []) + max_tokens = req.get("max_tokens", 256) + temperature = req.get("temperature", 0.0) + top_p = req.get("top_p", 1.0) + + _cancelled = False + + try: + import mlx_lm + + # Apply chat template via tokeniser. + if hasattr(_tokeniser, "apply_chat_template"): + prompt = _tokeniser.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + else: + # Fallback: concatenate messages. + prompt = "\n".join( + f"{m.get('role', 'user')}: {m.get('content', '')}" + for m in messages + ) + + kwargs = { + "max_tokens": max_tokens, + } + if temperature > 0: + kwargs["temp"] = temperature + if top_p < 1.0: + kwargs["top_p"] = top_p + + count = 0 + for response in mlx_lm.stream_generate( + _model, _tokeniser, prompt=prompt, **kwargs + ): + if _cancelled: + break + text = response.text if hasattr(response, "text") else str(response) + token_id = response.token if hasattr(response, "token") else 0 + _write({"token": text, "token_id": int(token_id)}) + count += 1 + + _write({"done": True, "tokens_generated": count}) + + except Exception as e: + _error(f"chat: {e}") + + +def handle_info(_req): + if _model is None: + _error("info: no model loaded") + return + + num_layers = 0 + hidden_size = 0 + if hasattr(_model, "config"): + cfg = _model.config + num_layers = getattr(cfg, "num_hidden_layers", 0) + hidden_size = getattr(cfg, "hidden_size", 0) + + _write({ + "model_type": _model_type or "unknown", + "vocab_size": _vocab_size, + "layers": num_layers, + "hidden_size": hidden_size, + }) + + +def handle_cancel(_req): + global _cancelled + _cancelled = True + + +def main(): + handlers = { + "load": handle_load, + "generate": handle_generate, + "chat": handle_chat, + "info": handle_info, + "cancel": handle_cancel, + "quit": None, + } + + for line in sys.stdin: + line = line.strip() + if not line: + continue + + try: + req = json.loads(line) + except json.JSONDecodeError as e: + _error(f"parse error: {e}") + continue + + cmd = req.get("cmd", "") + + if cmd == "quit": + break + + handler = handlers.get(cmd) + if handler is None: + _error(f"unknown command: {cmd}") + continue + + handler(req) + + +if __name__ == "__main__": + main() diff --git a/mlxlm/testdata/mock_bridge.py b/mlxlm/testdata/mock_bridge.py new file mode 100644 index 0000000..048ca8b --- /dev/null +++ b/mlxlm/testdata/mock_bridge.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +""" +mock_bridge.py — Mock bridge for testing the mlxlm Go backend. + +Implements the same JSON Lines protocol as bridge.py but without mlx_lm. +Returns deterministic fake responses for testing. + +SPDX-Licence-Identifier: EUPL-1.2 +""" + +import json +import sys +import os + +_loaded = False +_model_path = "" + + +def _write(obj): + sys.stdout.write(json.dumps(obj) + "\n") + sys.stdout.flush() + + +def _error(msg): + _write({"error": str(msg)}) + + +def main(): + global _loaded, _model_path + + for line in sys.stdin: + line = line.strip() + if not line: + continue + + try: + req = json.loads(line) + except json.JSONDecodeError as e: + _error(f"parse error: {e}") + continue + + cmd = req.get("cmd", "") + + if cmd == "quit": + break + + elif cmd == "load": + path = req.get("path", "") + if not path: + _error("load: missing 'path'") + continue + # Simulate failure for paths containing "FAIL". + if "FAIL" in path: + _error(f"load: cannot open model at {path}") + continue + _loaded = True + _model_path = path + _write({ + "ok": True, + "model_type": "mock_model", + "vocab_size": 32000, + }) + + elif cmd == "generate": + if not _loaded: + _error("generate: no model loaded") + continue + + max_tokens = req.get("max_tokens", 5) + # Check for error trigger. + prompt = req.get("prompt", "") + if "ERROR" in prompt: + _error("generate: simulated model error") + continue + + # Emit fixed tokens. + tokens = ["Hello", " ", "world", "!", "\n"] + count = min(max_tokens, len(tokens)) + for i in range(count): + _write({"token": tokens[i], "token_id": 100 + i}) + _write({"done": True, "tokens_generated": count}) + + elif cmd == "chat": + if not _loaded: + _error("chat: no model loaded") + continue + + messages = req.get("messages", []) + max_tokens = req.get("max_tokens", 5) + + # Emit tokens reflecting the last user message. + tokens = ["I", " ", "heard", " ", "you"] + count = min(max_tokens, len(tokens)) + for i in range(count): + _write({"token": tokens[i], "token_id": 200 + i}) + _write({"done": True, "tokens_generated": count}) + + elif cmd == "info": + if not _loaded: + _error("info: no model loaded") + continue + _write({ + "model_type": "mock_model", + "vocab_size": 32000, + "layers": 24, + "hidden_size": 2048, + }) + + elif cmd == "cancel": + # No-op in mock — real bridge sets a flag. + pass + + else: + _error(f"unknown command: {cmd}") + + +if __name__ == "__main__": + main()