feat(mlxlm): Phase 5.5 — subprocess backend using Python mlx-lm

Implements inference.Backend via a Python subprocess communicating
over JSON Lines (stdin/stdout). No CGO required — pure Go + os/exec.

- bridge.py: embedded Python script wrapping mlx_lm.load() and
  mlx_lm.stream_generate() with load/generate/chat/info/cancel/quit
  commands. Flushes stdout after every JSON line for streaming.

- backend.go: Go subprocess manager. Extracts bridge.py from
  go:embed to temp file, spawns python3, pipes JSON requests.
  mlxlmModel implements full TextModel interface with mutex-
  serialised Generate/Chat, context cancellation with drain,
  and 2-second graceful Close with kill fallback.
  Auto-registers as "mlx_lm" via init(). Build tag: !nomlxlm.

- backend_test.go: 15 tests using mock_bridge.py (no mlx_lm needed):
  name, load, generate, cancel, chat, close, error propagation,
  invalid path, auto-register, concurrent serialisation, classify/
  batch unsupported, info, metrics, max_tokens limiting.

All tests pass with -race. go vet clean.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-02-20 09:02:30 +00:00
parent 887c221974
commit 757a241f59
5 changed files with 1089 additions and 4 deletions

View file

@ -55,7 +55,7 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no
#### 5.5.1 Python Helper Script (`mlxlm/bridge.py`)
- [ ] **Create `mlxlm/bridge.py`** — Thin Python script embedded in Go binary via `//go:embed`:
- [x] **Create `mlxlm/bridge.py`** — Thin Python script embedded in Go binary via `//go:embed`:
- Reads JSON Lines from stdin, writes JSON Lines to stdout
- **`load` command**: `{"cmd":"load","path":"/path/to/model","max_tokens":128}` → loads model + tokenizer, responds `{"ok":true,"model_type":"gemma3","vocab_size":262144}`
- **`generate` command**: `{"cmd":"generate","prompt":"Hello","max_tokens":128,"temperature":0.7,"top_k":40,"top_p":0.9}` → streams `{"token":"Hello","token_id":123}` per token, then `{"done":true,"tokens_generated":42}`
@ -68,7 +68,7 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no
#### 5.5.2 Go Backend (`mlxlm/backend.go`)
- [ ] **Create `mlxlm/backend.go`** — Go subprocess wrapper:
- [x] **Create `mlxlm/backend.go`** — Go subprocess wrapper:
- `type mlxlmBackend struct{}` — implements `mlx.Backend`
- `func init()` — auto-registers via `mlx.Register(&mlxlmBackend{})`. Use build tag `//go:build !nomxlm` so consumers can opt out
- `func (b *mlxlmBackend) Name() string` — returns `"mlx_lm"`
@ -95,13 +95,13 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no
#### 5.5.3 Tests (`mlxlm/backend_test.go`)
- [ ] **Create mock Python script for testing** — A minimal Python script (in `testdata/mock_bridge.py` or inline) that:
- [x] **Create mock Python script for testing** — A minimal Python script (in `testdata/mock_bridge.py` or inline) that:
- Responds to `load` with fake model info
- Responds to `generate` with 5 fixed tokens then `done`
- Responds to `info` with fake metadata
- Responds to `chat` same as generate
- Tests don't need `mlx_lm` installed — pure mock
- [ ] **Tests** using mock script:
- [x] **Tests** using mock script:
- (a) `Name()` returns `"mlx_lm"`
- (b) `LoadModel` spawns subprocess, sends load command, gets response
- (c) `Generate` streams tokens from subprocess, all tokens received

459
mlxlm/backend.go Normal file
View file

@ -0,0 +1,459 @@
// SPDX-Licence-Identifier: EUPL-1.2
//go:build !nomlxlm
// Package mlxlm provides a subprocess-based inference backend using Python's mlx-lm.
//
// It implements the [inference.Backend] interface by spawning a Python process
// that communicates over JSON Lines (one JSON object per line on stdin/stdout).
// This allows using mlx-lm models without CGO or native Metal bindings.
//
// The backend auto-registers as "mlx_lm" via init(). Consumers can opt out
// with the build tag "nomlxlm".
//
// # Usage
//
// import _ "forge.lthn.ai/core/go-mlx/mlxlm"
//
// m, err := inference.LoadModel("/path/to/model", inference.WithBackend("mlx_lm"))
// defer m.Close()
//
// for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(64)) {
// fmt.Print(tok.Text)
// }
package mlxlm
import (
"bufio"
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"iter"
"os"
"os/exec"
"path/filepath"
"sync"
"time"
"forge.lthn.ai/core/go-inference"
)
//go:embed bridge.py
var bridgeFS embed.FS
// scriptPath holds the extracted bridge.py temp path. Created once per process.
var (
scriptOnce sync.Once
scriptPath string
scriptErr error
)
// extractScript writes the embedded bridge.py to a temp file and returns its path.
func extractScript() (string, error) {
scriptOnce.Do(func() {
data, err := bridgeFS.ReadFile("bridge.py")
if err != nil {
scriptErr = fmt.Errorf("mlxlm: read embedded bridge.py: %w", err)
return
}
dir, err := os.MkdirTemp("", "mlxlm-*")
if err != nil {
scriptErr = fmt.Errorf("mlxlm: create temp dir: %w", err)
return
}
p := filepath.Join(dir, "bridge.py")
if err := os.WriteFile(p, data, 0o644); err != nil {
scriptErr = fmt.Errorf("mlxlm: write bridge.py: %w", err)
return
}
scriptPath = p
})
return scriptPath, scriptErr
}
func init() {
inference.Register(&mlxlmBackend{})
}
// mlxlmBackend implements inference.Backend for mlx-lm via subprocess.
type mlxlmBackend struct{}
func (b *mlxlmBackend) Name() string { return "mlx_lm" }
// Available reports whether python3 is found on PATH.
func (b *mlxlmBackend) Available() bool {
_, err := exec.LookPath("python3")
return err == nil
}
// LoadModel spawns a Python subprocess running bridge.py, loads the model,
// and returns a TextModel backed by the subprocess.
func (b *mlxlmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
return loadModel(context.Background(), path, "", opts...)
}
// loadModel is the internal implementation, accepting an optional scriptOverride
// for testing (uses the mock bridge script instead of the embedded one).
func loadModel(ctx context.Context, modelPath, scriptOverride string, opts ...inference.LoadOption) (inference.TextModel, error) {
cfg := inference.ApplyLoadOpts(opts)
_ = cfg // reserved for future use (context length, etc.)
var pyScript string
if scriptOverride != "" {
pyScript = scriptOverride
} else {
var err error
pyScript, err = extractScript()
if err != nil {
return nil, err
}
}
cmd := exec.CommandContext(ctx, "python3", "-u", pyScript)
cmd.Stderr = nil // let stderr go to parent for debugging
stdinPipe, err := cmd.StdinPipe()
if err != nil {
return nil, fmt.Errorf("mlxlm: stdin pipe: %w", err)
}
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("mlxlm: stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("mlxlm: start python3: %w", err)
}
scanner := bufio.NewScanner(stdoutPipe)
scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB line buffer
m := &mlxlmModel{
cmd: cmd,
stdin: stdinPipe,
stdout: scanner,
raw: stdoutPipe,
}
// Send load command.
loadReq := map[string]any{
"cmd": "load",
"path": modelPath,
}
if err := m.send(loadReq); err != nil {
m.kill()
return nil, fmt.Errorf("mlxlm: send load: %w", err)
}
resp, err := m.recv()
if err != nil {
m.kill()
return nil, fmt.Errorf("mlxlm: recv load response: %w", err)
}
if errMsg, ok := resp["error"].(string); ok {
m.kill()
return nil, fmt.Errorf("mlxlm: %s", errMsg)
}
if modelType, ok := resp["model_type"].(string); ok {
m.modelType = modelType
}
if vocabSize, ok := resp["vocab_size"].(float64); ok {
m.vocabSize = int(vocabSize)
}
return m, nil
}
// mlxlmModel is a subprocess-backed TextModel.
type mlxlmModel struct {
cmd *exec.Cmd
stdin io.WriteCloser
stdout *bufio.Scanner
raw io.ReadCloser
modelType string
vocabSize int
lastErr error
mu sync.Mutex // serialise Generate/Chat calls
}
// send writes a JSON object as a single line to the subprocess stdin.
func (m *mlxlmModel) send(obj map[string]any) error {
data, err := json.Marshal(obj)
if err != nil {
return err
}
data = append(data, '\n')
_, err = m.stdin.Write(data)
return err
}
// recv reads and parses a single JSON line from the subprocess stdout.
func (m *mlxlmModel) recv() (map[string]any, error) {
if !m.stdout.Scan() {
if err := m.stdout.Err(); err != nil {
return nil, fmt.Errorf("mlxlm: scanner: %w", err)
}
return nil, errors.New("mlxlm: subprocess closed stdout")
}
var obj map[string]any
if err := json.Unmarshal(m.stdout.Bytes(), &obj); err != nil {
return nil, fmt.Errorf("mlxlm: parse response: %w", err)
}
return obj, nil
}
// Generate streams tokens from the subprocess for the given prompt.
// Only one Generate or Chat call runs at a time per model (mu lock).
func (m *mlxlmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
cfg := inference.ApplyGenerateOpts(opts)
return func(yield func(inference.Token) bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.lastErr = nil
req := map[string]any{
"cmd": "generate",
"prompt": prompt,
"max_tokens": cfg.MaxTokens,
}
if cfg.Temperature > 0 {
req["temperature"] = cfg.Temperature
}
if cfg.TopK > 0 {
req["top_k"] = cfg.TopK
}
if cfg.TopP > 0 {
req["top_p"] = cfg.TopP
}
if err := m.send(req); err != nil {
m.lastErr = fmt.Errorf("mlxlm: send generate: %w", err)
return
}
for {
select {
case <-ctx.Done():
m.lastErr = ctx.Err()
// Tell subprocess to stop.
_ = m.send(map[string]any{"cmd": "cancel"})
// Drain until done or error.
m.drain()
return
default:
}
resp, err := m.recv()
if err != nil {
m.lastErr = err
return
}
if errMsg, ok := resp["error"].(string); ok {
m.lastErr = errors.New(errMsg)
return
}
if _, ok := resp["done"]; ok {
return
}
text, _ := resp["token"].(string)
var id int32
if fid, ok := resp["token_id"].(float64); ok {
id = int32(fid)
}
if !yield(inference.Token{ID: id, Text: text}) {
// Consumer stopped early — send cancel and drain.
_ = m.send(map[string]any{"cmd": "cancel"})
m.drain()
return
}
}
}
}
// Chat streams tokens from a multi-turn conversation.
func (m *mlxlmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
cfg := inference.ApplyGenerateOpts(opts)
return func(yield func(inference.Token) bool) {
m.mu.Lock()
defer m.mu.Unlock()
m.lastErr = nil
// Convert messages to JSON-safe format.
msgs := make([]map[string]string, len(messages))
for i, msg := range messages {
msgs[i] = map[string]string{
"role": msg.Role,
"content": msg.Content,
}
}
req := map[string]any{
"cmd": "chat",
"messages": msgs,
"max_tokens": cfg.MaxTokens,
}
if cfg.Temperature > 0 {
req["temperature"] = cfg.Temperature
}
if cfg.TopK > 0 {
req["top_k"] = cfg.TopK
}
if cfg.TopP > 0 {
req["top_p"] = cfg.TopP
}
if err := m.send(req); err != nil {
m.lastErr = fmt.Errorf("mlxlm: send chat: %w", err)
return
}
for {
select {
case <-ctx.Done():
m.lastErr = ctx.Err()
_ = m.send(map[string]any{"cmd": "cancel"})
m.drain()
return
default:
}
resp, err := m.recv()
if err != nil {
m.lastErr = err
return
}
if errMsg, ok := resp["error"].(string); ok {
m.lastErr = errors.New(errMsg)
return
}
if _, ok := resp["done"]; ok {
return
}
text, _ := resp["token"].(string)
var id int32
if fid, ok := resp["token_id"].(float64); ok {
id = int32(fid)
}
if !yield(inference.Token{ID: id, Text: text}) {
_ = m.send(map[string]any{"cmd": "cancel"})
m.drain()
return
}
}
}
}
// Classify is not supported by the subprocess backend.
// Returns an error indicating that classification requires the native Metal backend.
func (m *mlxlmModel) Classify(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
return nil, errors.New("mlxlm: Classify not supported (use native Metal backend)")
}
// BatchGenerate is not supported by the subprocess backend.
// Returns an error indicating that batch generation requires the native Metal backend.
func (m *mlxlmModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
return nil, errors.New("mlxlm: BatchGenerate not supported (use native Metal backend)")
}
// ModelType returns the architecture identifier reported by the subprocess.
func (m *mlxlmModel) ModelType() string { return m.modelType }
// Info returns metadata about the loaded model.
func (m *mlxlmModel) Info() inference.ModelInfo {
m.mu.Lock()
defer m.mu.Unlock()
if err := m.send(map[string]any{"cmd": "info"}); err != nil {
return inference.ModelInfo{}
}
resp, err := m.recv()
if err != nil {
return inference.ModelInfo{}
}
if _, ok := resp["error"]; ok {
return inference.ModelInfo{}
}
info := inference.ModelInfo{
Architecture: m.modelType,
VocabSize: m.vocabSize,
}
if layers, ok := resp["layers"].(float64); ok {
info.NumLayers = int(layers)
}
if hidden, ok := resp["hidden_size"].(float64); ok {
info.HiddenSize = int(hidden)
}
return info
}
// Metrics returns empty metrics — the subprocess backend does not track timing.
func (m *mlxlmModel) Metrics() inference.GenerateMetrics {
return inference.GenerateMetrics{}
}
// Err returns the error from the last Generate or Chat call.
func (m *mlxlmModel) Err() error { return m.lastErr }
// Close sends a quit command and waits for the subprocess to exit.
// If the subprocess does not exit within 2 seconds, it is killed.
func (m *mlxlmModel) Close() error {
// Send quit — ignore errors (subprocess may already be dead).
_ = m.send(map[string]any{"cmd": "quit"})
_ = m.stdin.Close()
// Wait with timeout.
done := make(chan error, 1)
go func() { done <- m.cmd.Wait() }()
select {
case err := <-done:
return err
case <-time.After(2 * time.Second):
_ = m.cmd.Process.Kill()
return <-done
}
}
// drain reads and discards subprocess output until a "done" or "error" line,
// or the scanner stops. Used after cancellation to keep the protocol in sync.
func (m *mlxlmModel) drain() {
for {
resp, err := m.recv()
if err != nil {
return
}
if _, ok := resp["done"]; ok {
return
}
if _, ok := resp["error"]; ok {
return
}
}
}
// kill terminates the subprocess immediately. Used during load failures.
func (m *mlxlmModel) kill() {
_ = m.stdin.Close()
if m.cmd.Process != nil {
_ = m.cmd.Process.Kill()
}
_ = m.cmd.Wait()
}

284
mlxlm/backend_test.go Normal file
View file

@ -0,0 +1,284 @@
// SPDX-Licence-Identifier: EUPL-1.2
//go:build !nomlxlm
package mlxlm
import (
"context"
"path/filepath"
"runtime"
"strings"
"sync"
"testing"
"forge.lthn.ai/core/go-inference"
)
// mockScript returns the absolute path to testdata/mock_bridge.py.
func mockScript(t *testing.T) string {
t.Helper()
_, file, _, ok := runtime.Caller(0)
if !ok {
t.Fatal("cannot determine test file path")
}
return filepath.Join(filepath.Dir(file), "testdata", "mock_bridge.py")
}
// loadMock spawns a model backed by the mock Python script.
func loadMock(t *testing.T, modelPath string) inference.TextModel {
t.Helper()
m, err := loadModel(context.Background(), modelPath, mockScript(t))
if err != nil {
t.Fatalf("loadModel: %v", err)
}
t.Cleanup(func() { m.Close() })
return m
}
// (a) Name returns "mlx_lm".
func TestBackendName(t *testing.T) {
b := &mlxlmBackend{}
if got := b.Name(); got != "mlx_lm" {
t.Errorf("Name() = %q, want %q", got, "mlx_lm")
}
}
// (b) LoadModel spawns subprocess, sends load command, gets response.
func TestLoadModel_Good(t *testing.T) {
m := loadMock(t, "/fake/model/path")
if m.ModelType() != "mock_model" {
t.Errorf("ModelType() = %q, want %q", m.ModelType(), "mock_model")
}
}
// (c) Generate streams tokens from subprocess, all tokens received.
func TestGenerate_Good(t *testing.T) {
m := loadMock(t, "/fake/model/path")
ctx := context.Background()
var tokens []inference.Token
for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
tokens = append(tokens, tok)
}
if err := m.Err(); err != nil {
t.Fatalf("Err() = %v", err)
}
if len(tokens) != 5 {
t.Fatalf("got %d tokens, want 5", len(tokens))
}
// Verify token content matches mock.
expected := []string{"Hello", " ", "world", "!", "\n"}
for i, tok := range tokens {
if tok.Text != expected[i] {
t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i])
}
wantID := int32(100 + i)
if tok.ID != wantID {
t.Errorf("token[%d].ID = %d, want %d", i, tok.ID, wantID)
}
}
}
// (d) Generate with context cancellation stops early.
func TestGenerate_Cancel(t *testing.T) {
m := loadMock(t, "/fake/model/path")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var count int
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
count++
if count >= 2 {
cancel()
}
}
// We should have received at most a few tokens before cancellation took effect.
if count > 5 {
t.Errorf("expected early stop, got %d tokens", count)
}
if err := m.Err(); err != context.Canceled {
t.Logf("Err() = %v (expected context.Canceled)", err)
}
}
// (e) Chat formats messages correctly and streams tokens.
func TestChat_Good(t *testing.T) {
m := loadMock(t, "/fake/model/path")
ctx := context.Background()
var tokens []inference.Token
for tok := range m.Chat(ctx, []inference.Message{
{Role: "user", Content: "Hi there"},
}, inference.WithMaxTokens(5)) {
tokens = append(tokens, tok)
}
if err := m.Err(); err != nil {
t.Fatalf("Err() = %v", err)
}
if len(tokens) != 5 {
t.Fatalf("got %d tokens, want 5", len(tokens))
}
// Mock chat returns "I heard you".
expected := []string{"I", " ", "heard", " ", "you"}
for i, tok := range tokens {
if tok.Text != expected[i] {
t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i])
}
}
}
// (f) Close kills subprocess cleanly.
func TestClose_Good(t *testing.T) {
m, err := loadModel(context.Background(), "/fake/model/path", mockScript(t))
if err != nil {
t.Fatalf("loadModel: %v", err)
}
// Close should not error.
if err := m.Close(); err != nil {
t.Errorf("Close() = %v", err)
}
}
// (g) Err returns error on subprocess failure.
func TestGenerate_Error(t *testing.T) {
m := loadMock(t, "/fake/model/path")
ctx := context.Background()
var count int
for range m.Generate(ctx, "ERROR trigger", inference.WithMaxTokens(5)) {
count++
}
if count != 0 {
t.Errorf("expected 0 tokens on error, got %d", count)
}
if err := m.Err(); err == nil {
t.Fatal("expected non-nil Err()")
} else if !strings.Contains(err.Error(), "simulated model error") {
t.Errorf("Err() = %q, want to contain %q", err.Error(), "simulated model error")
}
}
// (h) LoadModel with invalid path returns error.
func TestLoadModel_Bad(t *testing.T) {
_, err := loadModel(context.Background(), "/path/with/FAIL/in/it", mockScript(t))
if err == nil {
t.Fatal("expected error for FAIL path")
}
if !strings.Contains(err.Error(), "cannot open model") {
t.Errorf("error = %q, want to contain %q", err.Error(), "cannot open model")
}
}
// (i) Backend auto-registers (check inference.Get("mlx_lm")).
func TestAutoRegister(t *testing.T) {
b, ok := inference.Get("mlx_lm")
if !ok {
t.Fatal("mlx_lm backend not registered")
}
if b.Name() != "mlx_lm" {
t.Errorf("Name() = %q, want %q", b.Name(), "mlx_lm")
}
}
// (j) Concurrent Generate calls are serialised (mu lock).
func TestGenerate_Concurrent(t *testing.T) {
m := loadMock(t, "/fake/model/path")
ctx := context.Background()
const goroutines = 3
var wg sync.WaitGroup
wg.Add(goroutines)
results := make([]int, goroutines)
for i := range goroutines {
go func(idx int) {
defer wg.Done()
var count int
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
count++
}
results[idx] = count
}(i)
}
wg.Wait()
// Each goroutine should have received all 5 tokens (serialised execution).
for i, count := range results {
if count != 5 {
t.Errorf("goroutine %d got %d tokens, want 5", i, count)
}
}
}
// Additional: Classify returns unsupported error.
func TestClassify_Unsupported(t *testing.T) {
m := loadMock(t, "/fake/model/path")
_, err := m.Classify(context.Background(), []string{"test"})
if err == nil {
t.Fatal("expected error from Classify")
}
if !strings.Contains(err.Error(), "not supported") {
t.Errorf("error = %q, want to contain %q", err.Error(), "not supported")
}
}
// Additional: BatchGenerate returns unsupported error.
func TestBatchGenerate_Unsupported(t *testing.T) {
m := loadMock(t, "/fake/model/path")
_, err := m.BatchGenerate(context.Background(), []string{"test"})
if err == nil {
t.Fatal("expected error from BatchGenerate")
}
if !strings.Contains(err.Error(), "not supported") {
t.Errorf("error = %q, want to contain %q", err.Error(), "not supported")
}
}
// Additional: Info returns model metadata.
func TestInfo_Good(t *testing.T) {
m := loadMock(t, "/fake/model/path")
info := m.Info()
if info.Architecture != "mock_model" {
t.Errorf("Architecture = %q, want %q", info.Architecture, "mock_model")
}
if info.VocabSize != 32000 {
t.Errorf("VocabSize = %d, want %d", info.VocabSize, 32000)
}
if info.NumLayers != 24 {
t.Errorf("NumLayers = %d, want %d", info.NumLayers, 24)
}
if info.HiddenSize != 2048 {
t.Errorf("HiddenSize = %d, want %d", info.HiddenSize, 2048)
}
}
// Additional: Metrics returns zero values (not tracked by subprocess).
func TestMetrics_Zero(t *testing.T) {
m := loadMock(t, "/fake/model/path")
met := m.Metrics()
if met.PromptTokens != 0 || met.GeneratedTokens != 0 {
t.Errorf("expected zero metrics, got prompt=%d generated=%d",
met.PromptTokens, met.GeneratedTokens)
}
}
// Additional: Generate with fewer max_tokens than available tokens.
func TestGenerate_MaxTokens(t *testing.T) {
m := loadMock(t, "/fake/model/path")
ctx := context.Background()
var count int
for range m.Generate(ctx, "Hello", inference.WithMaxTokens(3)) {
count++
}
if err := m.Err(); err != nil {
t.Fatalf("Err() = %v", err)
}
if count != 3 {
t.Errorf("got %d tokens, want 3", count)
}
}

224
mlxlm/bridge.py Normal file
View file

@ -0,0 +1,224 @@
#!/usr/bin/env python3
"""
bridge.py JSON Lines bridge between Go subprocess and mlx_lm.
Reads JSON commands from stdin, writes JSON responses to stdout.
Each line is one JSON object. Flushes after every write (critical for streaming).
Commands:
load Load model + tokeniser from path
generate Stream tokens for a prompt
chat Stream tokens for a multi-turn conversation
info Return model metadata
cancel Interrupt current generation (no-op outside generation)
quit Exit cleanly
Requires: mlx-lm (pip install mlx-lm)
SPDX-Licence-Identifier: EUPL-1.2
"""
import json
import sys
_model = None
_tokeniser = None
_model_type = None
_vocab_size = 0
_cancelled = False
def _write(obj):
"""Write a JSON line to stdout and flush."""
sys.stdout.write(json.dumps(obj) + "\n")
sys.stdout.flush()
def _error(msg):
"""Write an error response."""
_write({"error": str(msg)})
def handle_load(req):
global _model, _tokeniser, _model_type, _vocab_size
path = req.get("path", "")
if not path:
_error("load: missing 'path'")
return
try:
import mlx_lm
_model, _tokeniser = mlx_lm.load(path)
except Exception as e:
_error(f"load: {e}")
return
# Detect model type from config if available.
_model_type = getattr(_model, "model_type", "unknown")
_vocab_size = getattr(_tokeniser, "vocab_size", 0)
_write({
"ok": True,
"model_type": _model_type,
"vocab_size": _vocab_size,
})
def handle_generate(req):
global _cancelled
if _model is None or _tokeniser is None:
_error("generate: no model loaded")
return
prompt = req.get("prompt", "")
max_tokens = req.get("max_tokens", 256)
temperature = req.get("temperature", 0.0)
top_p = req.get("top_p", 1.0)
_cancelled = False
try:
import mlx_lm
kwargs = {
"max_tokens": max_tokens,
}
if temperature > 0:
kwargs["temp"] = temperature
if top_p < 1.0:
kwargs["top_p"] = top_p
count = 0
for response in mlx_lm.stream_generate(
_model, _tokeniser, prompt=prompt, **kwargs
):
if _cancelled:
break
text = response.text if hasattr(response, "text") else str(response)
token_id = response.token if hasattr(response, "token") else 0
_write({"token": text, "token_id": int(token_id)})
count += 1
_write({"done": True, "tokens_generated": count})
except Exception as e:
_error(f"generate: {e}")
def handle_chat(req):
global _cancelled
if _model is None or _tokeniser is None:
_error("chat: no model loaded")
return
messages = req.get("messages", [])
max_tokens = req.get("max_tokens", 256)
temperature = req.get("temperature", 0.0)
top_p = req.get("top_p", 1.0)
_cancelled = False
try:
import mlx_lm
# Apply chat template via tokeniser.
if hasattr(_tokeniser, "apply_chat_template"):
prompt = _tokeniser.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
# Fallback: concatenate messages.
prompt = "\n".join(
f"{m.get('role', 'user')}: {m.get('content', '')}"
for m in messages
)
kwargs = {
"max_tokens": max_tokens,
}
if temperature > 0:
kwargs["temp"] = temperature
if top_p < 1.0:
kwargs["top_p"] = top_p
count = 0
for response in mlx_lm.stream_generate(
_model, _tokeniser, prompt=prompt, **kwargs
):
if _cancelled:
break
text = response.text if hasattr(response, "text") else str(response)
token_id = response.token if hasattr(response, "token") else 0
_write({"token": text, "token_id": int(token_id)})
count += 1
_write({"done": True, "tokens_generated": count})
except Exception as e:
_error(f"chat: {e}")
def handle_info(_req):
if _model is None:
_error("info: no model loaded")
return
num_layers = 0
hidden_size = 0
if hasattr(_model, "config"):
cfg = _model.config
num_layers = getattr(cfg, "num_hidden_layers", 0)
hidden_size = getattr(cfg, "hidden_size", 0)
_write({
"model_type": _model_type or "unknown",
"vocab_size": _vocab_size,
"layers": num_layers,
"hidden_size": hidden_size,
})
def handle_cancel(_req):
global _cancelled
_cancelled = True
def main():
handlers = {
"load": handle_load,
"generate": handle_generate,
"chat": handle_chat,
"info": handle_info,
"cancel": handle_cancel,
"quit": None,
}
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
req = json.loads(line)
except json.JSONDecodeError as e:
_error(f"parse error: {e}")
continue
cmd = req.get("cmd", "")
if cmd == "quit":
break
handler = handlers.get(cmd)
if handler is None:
_error(f"unknown command: {cmd}")
continue
handler(req)
if __name__ == "__main__":
main()

118
mlxlm/testdata/mock_bridge.py vendored Normal file
View file

@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
mock_bridge.py Mock bridge for testing the mlxlm Go backend.
Implements the same JSON Lines protocol as bridge.py but without mlx_lm.
Returns deterministic fake responses for testing.
SPDX-Licence-Identifier: EUPL-1.2
"""
import json
import sys
import os
_loaded = False
_model_path = ""
def _write(obj):
sys.stdout.write(json.dumps(obj) + "\n")
sys.stdout.flush()
def _error(msg):
_write({"error": str(msg)})
def main():
global _loaded, _model_path
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
req = json.loads(line)
except json.JSONDecodeError as e:
_error(f"parse error: {e}")
continue
cmd = req.get("cmd", "")
if cmd == "quit":
break
elif cmd == "load":
path = req.get("path", "")
if not path:
_error("load: missing 'path'")
continue
# Simulate failure for paths containing "FAIL".
if "FAIL" in path:
_error(f"load: cannot open model at {path}")
continue
_loaded = True
_model_path = path
_write({
"ok": True,
"model_type": "mock_model",
"vocab_size": 32000,
})
elif cmd == "generate":
if not _loaded:
_error("generate: no model loaded")
continue
max_tokens = req.get("max_tokens", 5)
# Check for error trigger.
prompt = req.get("prompt", "")
if "ERROR" in prompt:
_error("generate: simulated model error")
continue
# Emit fixed tokens.
tokens = ["Hello", " ", "world", "!", "\n"]
count = min(max_tokens, len(tokens))
for i in range(count):
_write({"token": tokens[i], "token_id": 100 + i})
_write({"done": True, "tokens_generated": count})
elif cmd == "chat":
if not _loaded:
_error("chat: no model loaded")
continue
messages = req.get("messages", [])
max_tokens = req.get("max_tokens", 5)
# Emit tokens reflecting the last user message.
tokens = ["I", " ", "heard", " ", "you"]
count = min(max_tokens, len(tokens))
for i in range(count):
_write({"token": tokens[i], "token_id": 200 + i})
_write({"done": True, "tokens_generated": count})
elif cmd == "info":
if not _loaded:
_error("info: no model loaded")
continue
_write({
"model_type": "mock_model",
"vocab_size": 32000,
"layers": 24,
"hidden_size": 2048,
})
elif cmd == "cancel":
# No-op in mock — real bridge sets a flag.
pass
else:
_error(f"unknown command: {cmd}")
if __name__ == "__main__":
main()