go-mlx/mlxlm/backend.go
Snider 757a241f59 feat(mlxlm): Phase 5.5 — subprocess backend using Python mlx-lm
Implements inference.Backend via a Python subprocess communicating
over JSON Lines (stdin/stdout). No CGO required — pure Go + os/exec.

- bridge.py: embedded Python script wrapping mlx_lm.load() and
  mlx_lm.stream_generate() with load/generate/chat/info/cancel/quit
  commands. Flushes stdout after every JSON line for streaming.

- backend.go: Go subprocess manager. Extracts bridge.py from
  go:embed to temp file, spawns python3, pipes JSON requests.
  mlxlmModel implements full TextModel interface with mutex-
  serialised Generate/Chat, context cancellation with drain,
  and 2-second graceful Close with kill fallback.
  Auto-registers as "mlx_lm" via init(). Build tag: !nomlxlm.

- backend_test.go: 15 tests using mock_bridge.py (no mlx_lm needed):
  name, load, generate, cancel, chat, close, error propagation,
  invalid path, auto-register, concurrent serialisation, classify/
  batch unsupported, info, metrics, max_tokens limiting.

All tests pass with -race. go vet clean.
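
For illustration, one generate call on the wire (request fields as sent by
backend.go; the reply shape is what its parser expects back from bridge.py,
values invented):

  > {"cmd": "generate", "prompt": "Hello", "max_tokens": 64}
  < {"token": " Hi", "token_id": 13347}
  < {"token": "!", "token_id": 33}
  < {"done": true}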

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 09:02:30 +00:00

// SPDX-License-Identifier: EUPL-1.2

//go:build !nomlxlm

// Package mlxlm provides a subprocess-based inference backend using Python's mlx-lm.
//
// It implements the [inference.Backend] interface by spawning a Python process
// that communicates over JSON Lines (one JSON object per line on stdin/stdout).
// This allows using mlx-lm models without CGO or native Metal bindings.
//
// The backend auto-registers as "mlx_lm" via init(). Consumers can opt out
// with the build tag "nomlxlm".
//
// # Usage
//
//	import _ "forge.lthn.ai/core/go-mlx/mlxlm"
//
//	m, err := inference.LoadModel("/path/to/model", inference.WithBackend("mlx_lm"))
//	defer m.Close()
//
//	for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(64)) {
//		fmt.Print(tok.Text)
//	}
package mlxlm

import (
	"bufio"
	"context"
	"embed"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"iter"
	"os"
	"os/exec"
	"path/filepath"
	"sync"
	"time"

	"forge.lthn.ai/core/go-inference"
)

//go:embed bridge.py
var bridgeFS embed.FS

// scriptPath holds the extracted bridge.py temp path. Created once per process.
var (
	scriptOnce sync.Once
	scriptPath string
	scriptErr  error
)

// extractScript writes the embedded bridge.py to a temp file and returns its path.
func extractScript() (string, error) {
	scriptOnce.Do(func() {
		data, err := bridgeFS.ReadFile("bridge.py")
		if err != nil {
			scriptErr = fmt.Errorf("mlxlm: read embedded bridge.py: %w", err)
			return
		}
		dir, err := os.MkdirTemp("", "mlxlm-*")
		if err != nil {
			scriptErr = fmt.Errorf("mlxlm: create temp dir: %w", err)
			return
		}
		p := filepath.Join(dir, "bridge.py")
		if err := os.WriteFile(p, data, 0o644); err != nil {
			scriptErr = fmt.Errorf("mlxlm: write bridge.py: %w", err)
			return
		}
		scriptPath = p
	})
	return scriptPath, scriptErr
}

func init() {
	inference.Register(&mlxlmBackend{})
}

// mlxlmBackend implements inference.Backend for mlx-lm via subprocess.
type mlxlmBackend struct{}

func (b *mlxlmBackend) Name() string { return "mlx_lm" }

// Available reports whether python3 is found on PATH.
func (b *mlxlmBackend) Available() bool {
	_, err := exec.LookPath("python3")
	return err == nil
}

// LoadModel spawns a Python subprocess running bridge.py, loads the model,
// and returns a TextModel backed by the subprocess.
func (b *mlxlmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
	return loadModel(context.Background(), path, "", opts...)
}

// loadModel is the internal implementation, accepting an optional scriptOverride
// for testing (uses the mock bridge script instead of the embedded one).
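//
// For example, a test can point it at a stub bridge and a dummy model path
// (both paths illustrative):
//
//	m, err := loadModel(context.Background(), "testdata/model", "testdata/mock_bridge.py")
//	if err != nil {
//		t.Fatal(err)
//	}
//	defer m.Close()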
func loadModel(ctx context.Context, modelPath, scriptOverride string, opts ...inference.LoadOption) (inference.TextModel, error) {
	cfg := inference.ApplyLoadOpts(opts)
	_ = cfg // reserved for future use (context length, etc.)

	var pyScript string
	if scriptOverride != "" {
		pyScript = scriptOverride
	} else {
		var err error
		pyScript, err = extractScript()
		if err != nil {
			return nil, err
		}
	}

	cmd := exec.CommandContext(ctx, "python3", "-u", pyScript)
	cmd.Stderr = os.Stderr // forward the bridge's stderr to the parent for debugging
	stdinPipe, err := cmd.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("mlxlm: stdin pipe: %w", err)
	}
	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		return nil, fmt.Errorf("mlxlm: stdout pipe: %w", err)
	}
	if err := cmd.Start(); err != nil {
		return nil, fmt.Errorf("mlxlm: start python3: %w", err)
	}

	scanner := bufio.NewScanner(stdoutPipe)
	scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB line buffer

	m := &mlxlmModel{
		cmd:    cmd,
		stdin:  stdinPipe,
		stdout: scanner,
		raw:    stdoutPipe,
	}

	// Send load command.
	loadReq := map[string]any{
		"cmd":  "load",
		"path": modelPath,
	}
	if err := m.send(loadReq); err != nil {
		m.kill()
		return nil, fmt.Errorf("mlxlm: send load: %w", err)
	}
	resp, err := m.recv()
	if err != nil {
		m.kill()
		return nil, fmt.Errorf("mlxlm: recv load response: %w", err)
	}
	if errMsg, ok := resp["error"].(string); ok {
		m.kill()
		return nil, fmt.Errorf("mlxlm: %s", errMsg)
	}
	if modelType, ok := resp["model_type"].(string); ok {
		m.modelType = modelType
	}
	if vocabSize, ok := resp["vocab_size"].(float64); ok {
		m.vocabSize = int(vocabSize)
	}
	return m, nil
}

// mlxlmModel is a subprocess-backed TextModel.
type mlxlmModel struct {
	cmd       *exec.Cmd
	stdin     io.WriteCloser
	stdout    *bufio.Scanner
	raw       io.ReadCloser
	modelType string
	vocabSize int
	lastErr   error
	mu        sync.Mutex // serialise Generate/Chat calls
}

// send writes a JSON object as a single line to the subprocess stdin.
func (m *mlxlmModel) send(obj map[string]any) error {
	data, err := json.Marshal(obj)
	if err != nil {
		return err
	}
	data = append(data, '\n')
	_, err = m.stdin.Write(data)
	return err
}

// recv reads and parses a single JSON line from the subprocess stdout.
func (m *mlxlmModel) recv() (map[string]any, error) {
	if !m.stdout.Scan() {
		if err := m.stdout.Err(); err != nil {
			return nil, fmt.Errorf("mlxlm: scanner: %w", err)
		}
		return nil, errors.New("mlxlm: subprocess closed stdout")
	}
	var obj map[string]any
	if err := json.Unmarshal(m.stdout.Bytes(), &obj); err != nil {
		return nil, fmt.Errorf("mlxlm: parse response: %w", err)
	}
	return obj, nil
}

// Generate streams tokens from the subprocess for the given prompt.
// Only one Generate or Chat call runs at a time per model (mu lock).
func (m *mlxlmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
	cfg := inference.ApplyGenerateOpts(opts)
	return func(yield func(inference.Token) bool) {
		m.mu.Lock()
		defer m.mu.Unlock()
		m.lastErr = nil

		req := map[string]any{
			"cmd":        "generate",
			"prompt":     prompt,
			"max_tokens": cfg.MaxTokens,
		}
		if cfg.Temperature > 0 {
			req["temperature"] = cfg.Temperature
		}
		if cfg.TopK > 0 {
			req["top_k"] = cfg.TopK
		}
		if cfg.TopP > 0 {
			req["top_p"] = cfg.TopP
		}
		if err := m.send(req); err != nil {
			m.lastErr = fmt.Errorf("mlxlm: send generate: %w", err)
			return
		}

		for {
			select {
			case <-ctx.Done():
				m.lastErr = ctx.Err()
				// Tell subprocess to stop.
				_ = m.send(map[string]any{"cmd": "cancel"})
				// Drain until done or error.
				m.drain()
				return
			default:
			}
			resp, err := m.recv()
			if err != nil {
				m.lastErr = err
				return
			}
			if errMsg, ok := resp["error"].(string); ok {
				m.lastErr = errors.New(errMsg)
				return
			}
			if _, ok := resp["done"]; ok {
				return
			}
			text, _ := resp["token"].(string)
			var id int32
			if fid, ok := resp["token_id"].(float64); ok {
				id = int32(fid)
			}
			if !yield(inference.Token{ID: id, Text: text}) {
				// Consumer stopped early — send cancel and drain.
				_ = m.send(map[string]any{"cmd": "cancel"})
				m.drain()
				return
			}
		}
	}
}

// Chat streams tokens from a multi-turn conversation.
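//
// For example (role names follow the usual chat-template convention):
//
//	msgs := []inference.Message{
//		{Role: "system", Content: "You are concise."},
//		{Role: "user", Content: "Hello"},
//	}
//	for tok := range m.Chat(ctx, msgs, inference.WithMaxTokens(64)) {
//		fmt.Print(tok.Text)
//	}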
func (m *mlxlmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
	cfg := inference.ApplyGenerateOpts(opts)
	return func(yield func(inference.Token) bool) {
		m.mu.Lock()
		defer m.mu.Unlock()
		m.lastErr = nil

		// Convert messages to JSON-safe format.
		msgs := make([]map[string]string, len(messages))
		for i, msg := range messages {
			msgs[i] = map[string]string{
				"role":    msg.Role,
				"content": msg.Content,
			}
		}
		req := map[string]any{
			"cmd":        "chat",
			"messages":   msgs,
			"max_tokens": cfg.MaxTokens,
		}
		if cfg.Temperature > 0 {
			req["temperature"] = cfg.Temperature
		}
		if cfg.TopK > 0 {
			req["top_k"] = cfg.TopK
		}
		if cfg.TopP > 0 {
			req["top_p"] = cfg.TopP
		}
		if err := m.send(req); err != nil {
			m.lastErr = fmt.Errorf("mlxlm: send chat: %w", err)
			return
		}

		for {
			select {
			case <-ctx.Done():
				m.lastErr = ctx.Err()
				_ = m.send(map[string]any{"cmd": "cancel"})
				m.drain()
				return
			default:
			}
			resp, err := m.recv()
			if err != nil {
				m.lastErr = err
				return
			}
			if errMsg, ok := resp["error"].(string); ok {
				m.lastErr = errors.New(errMsg)
				return
			}
			if _, ok := resp["done"]; ok {
				return
			}
			text, _ := resp["token"].(string)
			var id int32
			if fid, ok := resp["token_id"].(float64); ok {
				id = int32(fid)
			}
			if !yield(inference.Token{ID: id, Text: text}) {
				_ = m.send(map[string]any{"cmd": "cancel"})
				m.drain()
				return
			}
		}
	}
}

// Classify is not supported by the subprocess backend.
// Returns an error indicating that classification requires the native Metal backend.
func (m *mlxlmModel) Classify(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
	return nil, errors.New("mlxlm: Classify not supported (use native Metal backend)")
}

// BatchGenerate is not supported by the subprocess backend.
// Returns an error indicating that batch generation requires the native Metal backend.
func (m *mlxlmModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
	return nil, errors.New("mlxlm: BatchGenerate not supported (use native Metal backend)")
}

// ModelType returns the architecture identifier reported by the subprocess.
func (m *mlxlmModel) ModelType() string { return m.modelType }

// Info returns metadata about the loaded model.
func (m *mlxlmModel) Info() inference.ModelInfo {
	m.mu.Lock()
	defer m.mu.Unlock()
	if err := m.send(map[string]any{"cmd": "info"}); err != nil {
		return inference.ModelInfo{}
	}
	resp, err := m.recv()
	if err != nil {
		return inference.ModelInfo{}
	}
	if _, ok := resp["error"]; ok {
		return inference.ModelInfo{}
	}
	info := inference.ModelInfo{
		Architecture: m.modelType,
		VocabSize:    m.vocabSize,
	}
	if layers, ok := resp["layers"].(float64); ok {
		info.NumLayers = int(layers)
	}
	if hidden, ok := resp["hidden_size"].(float64); ok {
		info.HiddenSize = int(hidden)
	}
	return info
}

// Metrics returns empty metrics — the subprocess backend does not track timing.
func (m *mlxlmModel) Metrics() inference.GenerateMetrics {
	return inference.GenerateMetrics{}
}

// Err returns the error from the last Generate or Chat call.
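//
// Iterate the sequence to completion (or break out), then check Err:
//
//	for tok := range m.Generate(ctx, prompt) {
//		fmt.Print(tok.Text)
//	}
//	if err := m.Err(); err != nil {
//		// handle the error (context cancellation, bridge failure, ...)
//	}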
func (m *mlxlmModel) Err() error { return m.lastErr }

// Close sends a quit command and waits for the subprocess to exit.
// If the subprocess does not exit within 2 seconds, it is killed.
func (m *mlxlmModel) Close() error {
	// Send quit — ignore errors (subprocess may already be dead).
	_ = m.send(map[string]any{"cmd": "quit"})
	_ = m.stdin.Close()

	// Wait with timeout.
	done := make(chan error, 1)
	go func() { done <- m.cmd.Wait() }()
	select {
	case err := <-done:
		return err
	case <-time.After(2 * time.Second):
		_ = m.cmd.Process.Kill()
		return <-done
	}
}

// drain reads and discards subprocess output until a "done" or "error" line,
// or the scanner stops. Used after cancellation to keep the protocol in sync.
func (m *mlxlmModel) drain() {
	for {
		resp, err := m.recv()
		if err != nil {
			return
		}
		if _, ok := resp["done"]; ok {
			return
		}
		if _, ok := resp["error"]; ok {
			return
		}
	}
}

// kill terminates the subprocess immediately. Used during load failures.
func (m *mlxlmModel) kill() {
	_ = m.stdin.Close()
	if m.cmd.Process != nil {
		_ = m.cmd.Process.Kill()
	}
	_ = m.cmd.Wait()
}