feat(mlxlm): Phase 5.5 — subprocess backend using Python mlx-lm
Implements inference.Backend via a Python subprocess communicating over JSON Lines
(stdin/stdout). No CGO required — pure Go + os/exec.

- bridge.py: embedded Python script wrapping mlx_lm.load() and mlx_lm.stream_generate()
  with load/generate/chat/info/cancel/quit commands. Flushes stdout after every JSON
  line for streaming.
- backend.go: Go subprocess manager. Extracts bridge.py from go:embed to a temp file,
  spawns python3, pipes JSON requests. mlxlmModel implements the full TextModel
  interface with mutex-serialised Generate/Chat, context cancellation with drain, and
  2-second graceful Close with kill fallback. Auto-registers as "mlx_lm" via init().
  Build tag: !nomlxlm.
- backend_test.go: 15 tests using mock_bridge.py (no mlx_lm needed): name, load,
  generate, cancel, chat, close, error propagation, invalid path, auto-register,
  concurrent serialisation, classify/batch unsupported, info, metrics, max_tokens
  limiting.

All tests pass with -race. go vet clean.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent 887c221974
commit 757a241f59
5 changed files with 1089 additions and 4 deletions
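For reference, the protocol the commit message describes is one JSON object per line in each direction. A minimal Go sketch of what crosses the pipe for a load-then-generate exchange; the field values are illustrative examples taken from the TODO spec below, not output from a real run:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Requests the Go side writes to the bridge's stdin, one JSON object per line.
	requests := []map[string]any{
		{"cmd": "load", "path": "/path/to/model"},
		{"cmd": "generate", "prompt": "Hello", "max_tokens": 8, "temperature": 0.7},
		{"cmd": "quit"},
	}
	for _, req := range requests {
		line, _ := json.Marshal(req)
		fmt.Println(string(line))
	}

	// Responses the bridge streams back on stdout, also one JSON object per line:
	//   {"ok":true,"model_type":"gemma3","vocab_size":262144}
	//   {"token":"Hello","token_id":123}
	//   ... one line per generated token ...
	//   {"done":true,"tokens_generated":8}
}
```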
8  TODO.md
@@ -55,7 +55,7 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no
 #### 5.5.1 Python Helper Script (`mlxlm/bridge.py`)
 
-- [ ] **Create `mlxlm/bridge.py`** — Thin Python script embedded in Go binary via `//go:embed`:
+- [x] **Create `mlxlm/bridge.py`** — Thin Python script embedded in Go binary via `//go:embed`:
   - Reads JSON Lines from stdin, writes JSON Lines to stdout
   - **`load` command**: `{"cmd":"load","path":"/path/to/model","max_tokens":128}` → loads model + tokenizer, responds `{"ok":true,"model_type":"gemma3","vocab_size":262144}`
   - **`generate` command**: `{"cmd":"generate","prompt":"Hello","max_tokens":128,"temperature":0.7,"top_k":40,"top_p":0.9}` → streams `{"token":"Hello","token_id":123}` per token, then `{"done":true,"tokens_generated":42}`
@@ -68,7 +68,7 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no
 #### 5.5.2 Go Backend (`mlxlm/backend.go`)
 
-- [ ] **Create `mlxlm/backend.go`** — Go subprocess wrapper:
+- [x] **Create `mlxlm/backend.go`** — Go subprocess wrapper:
   - `type mlxlmBackend struct{}` — implements `mlx.Backend`
   - `func init()` — auto-registers via `mlx.Register(&mlxlmBackend{})`. Use build tag `//go:build !nomxlm` so consumers can opt out
   - `func (b *mlxlmBackend) Name() string` — returns `"mlx_lm"`
@@ -95,13 +95,13 @@ Python subprocess wrapper implementing `mlx.Backend`. Uses stdlib `os/exec` (no
 #### 5.5.3 Tests (`mlxlm/backend_test.go`)
 
-- [ ] **Create mock Python script for testing** — A minimal Python script (in `testdata/mock_bridge.py` or inline) that:
+- [x] **Create mock Python script for testing** — A minimal Python script (in `testdata/mock_bridge.py` or inline) that:
   - Responds to `load` with fake model info
   - Responds to `generate` with 5 fixed tokens then `done`
   - Responds to `info` with fake metadata
   - Responds to `chat` same as generate
   - Tests don't need `mlx_lm` installed — pure mock
-- [ ] **Tests** using mock script:
+- [x] **Tests** using mock script:
   - (a) `Name()` returns `"mlx_lm"`
   - (b) `LoadModel` spawns subprocess, sends load command, gets response
   - (c) `Generate` streams tokens from subprocess, all tokens received
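The 5.5.2 items above cover init()-time registration and the nomlxlm opt-out tag. A minimal consumer-side sketch of checking registration and Python availability before loading; it assumes the registry accessor inference.Get that the tests below use, and that the returned backend exposes the Available method defined on mlxlmBackend in backend.go:

```go
package main

import (
	"fmt"

	"forge.lthn.ai/core/go-inference"
	_ "forge.lthn.ai/core/go-mlx/mlxlm" // blank import triggers init() registration
)

func main() {
	b, ok := inference.Get("mlx_lm")
	if !ok {
		// Not registered, most likely a build with the opt-out tag (go build -tags nomlxlm).
		fmt.Println("mlx_lm backend not available in this build")
		return
	}
	// Assumption: the registered backend exposes Available(), as mlxlmBackend does.
	if !b.Available() {
		fmt.Println("python3 not found on PATH; install Python 3 and mlx-lm")
		return
	}
	fmt.Println("backend ready:", b.Name())
}
```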
459  mlxlm/backend.go  Normal file
@@ -0,0 +1,459 @@
// SPDX-Licence-Identifier: EUPL-1.2

//go:build !nomlxlm

// Package mlxlm provides a subprocess-based inference backend using Python's mlx-lm.
//
// It implements the [inference.Backend] interface by spawning a Python process
// that communicates over JSON Lines (one JSON object per line on stdin/stdout).
// This allows using mlx-lm models without CGO or native Metal bindings.
//
// The backend auto-registers as "mlx_lm" via init(). Consumers can opt out
// with the build tag "nomlxlm".
//
// # Usage
//
//	import _ "forge.lthn.ai/core/go-mlx/mlxlm"
//
//	m, err := inference.LoadModel("/path/to/model", inference.WithBackend("mlx_lm"))
//	defer m.Close()
//
//	for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(64)) {
//		fmt.Print(tok.Text)
//	}
package mlxlm

import (
    "bufio"
    "context"
    "embed"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "iter"
    "os"
    "os/exec"
    "path/filepath"
    "sync"
    "time"

    "forge.lthn.ai/core/go-inference"
)

//go:embed bridge.py
var bridgeFS embed.FS

// scriptPath holds the extracted bridge.py temp path. Created once per process.
var (
    scriptOnce sync.Once
    scriptPath string
    scriptErr  error
)

// extractScript writes the embedded bridge.py to a temp file and returns its path.
func extractScript() (string, error) {
    scriptOnce.Do(func() {
        data, err := bridgeFS.ReadFile("bridge.py")
        if err != nil {
            scriptErr = fmt.Errorf("mlxlm: read embedded bridge.py: %w", err)
            return
        }
        dir, err := os.MkdirTemp("", "mlxlm-*")
        if err != nil {
            scriptErr = fmt.Errorf("mlxlm: create temp dir: %w", err)
            return
        }
        p := filepath.Join(dir, "bridge.py")
        if err := os.WriteFile(p, data, 0o644); err != nil {
            scriptErr = fmt.Errorf("mlxlm: write bridge.py: %w", err)
            return
        }
        scriptPath = p
    })
    return scriptPath, scriptErr
}

func init() {
    inference.Register(&mlxlmBackend{})
}

// mlxlmBackend implements inference.Backend for mlx-lm via subprocess.
type mlxlmBackend struct{}

func (b *mlxlmBackend) Name() string { return "mlx_lm" }

// Available reports whether python3 is found on PATH.
func (b *mlxlmBackend) Available() bool {
    _, err := exec.LookPath("python3")
    return err == nil
}

// LoadModel spawns a Python subprocess running bridge.py, loads the model,
// and returns a TextModel backed by the subprocess.
func (b *mlxlmBackend) LoadModel(path string, opts ...inference.LoadOption) (inference.TextModel, error) {
    return loadModel(context.Background(), path, "", opts...)
}

// loadModel is the internal implementation, accepting an optional scriptOverride
// for testing (uses the mock bridge script instead of the embedded one).
func loadModel(ctx context.Context, modelPath, scriptOverride string, opts ...inference.LoadOption) (inference.TextModel, error) {
    cfg := inference.ApplyLoadOpts(opts)
    _ = cfg // reserved for future use (context length, etc.)

    var pyScript string
    if scriptOverride != "" {
        pyScript = scriptOverride
    } else {
        var err error
        pyScript, err = extractScript()
        if err != nil {
            return nil, err
        }
    }

    cmd := exec.CommandContext(ctx, "python3", "-u", pyScript)
    cmd.Stderr = os.Stderr // forward the bridge's stderr (Python tracebacks) to the parent for debugging

    stdinPipe, err := cmd.StdinPipe()
    if err != nil {
        return nil, fmt.Errorf("mlxlm: stdin pipe: %w", err)
    }
    stdoutPipe, err := cmd.StdoutPipe()
    if err != nil {
        return nil, fmt.Errorf("mlxlm: stdout pipe: %w", err)
    }

    if err := cmd.Start(); err != nil {
        return nil, fmt.Errorf("mlxlm: start python3: %w", err)
    }

    scanner := bufio.NewScanner(stdoutPipe)
    scanner.Buffer(make([]byte, 0, 1024*1024), 1024*1024) // 1MB line buffer

    m := &mlxlmModel{
        cmd:    cmd,
        stdin:  stdinPipe,
        stdout: scanner,
        raw:    stdoutPipe,
    }

    // Send load command.
    loadReq := map[string]any{
        "cmd":  "load",
        "path": modelPath,
    }
    if err := m.send(loadReq); err != nil {
        m.kill()
        return nil, fmt.Errorf("mlxlm: send load: %w", err)
    }

    resp, err := m.recv()
    if err != nil {
        m.kill()
        return nil, fmt.Errorf("mlxlm: recv load response: %w", err)
    }

    if errMsg, ok := resp["error"].(string); ok {
        m.kill()
        return nil, fmt.Errorf("mlxlm: %s", errMsg)
    }

    if modelType, ok := resp["model_type"].(string); ok {
        m.modelType = modelType
    }
    if vocabSize, ok := resp["vocab_size"].(float64); ok {
        m.vocabSize = int(vocabSize)
    }

    return m, nil
}

// mlxlmModel is a subprocess-backed TextModel.
type mlxlmModel struct {
    cmd    *exec.Cmd
    stdin  io.WriteCloser
    stdout *bufio.Scanner
    raw    io.ReadCloser

    modelType string
    vocabSize int

    lastErr error
    mu      sync.Mutex // serialise Generate/Chat calls
}

// send writes a JSON object as a single line to the subprocess stdin.
func (m *mlxlmModel) send(obj map[string]any) error {
    data, err := json.Marshal(obj)
    if err != nil {
        return err
    }
    data = append(data, '\n')
    _, err = m.stdin.Write(data)
    return err
}

// recv reads and parses a single JSON line from the subprocess stdout.
func (m *mlxlmModel) recv() (map[string]any, error) {
    if !m.stdout.Scan() {
        if err := m.stdout.Err(); err != nil {
            return nil, fmt.Errorf("mlxlm: scanner: %w", err)
        }
        return nil, errors.New("mlxlm: subprocess closed stdout")
    }
    var obj map[string]any
    if err := json.Unmarshal(m.stdout.Bytes(), &obj); err != nil {
        return nil, fmt.Errorf("mlxlm: parse response: %w", err)
    }
    return obj, nil
}

// Generate streams tokens from the subprocess for the given prompt.
// Only one Generate or Chat call runs at a time per model (mu lock).
func (m *mlxlmModel) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
    cfg := inference.ApplyGenerateOpts(opts)

    return func(yield func(inference.Token) bool) {
        m.mu.Lock()
        defer m.mu.Unlock()
        m.lastErr = nil

        req := map[string]any{
            "cmd":        "generate",
            "prompt":     prompt,
            "max_tokens": cfg.MaxTokens,
        }
        if cfg.Temperature > 0 {
            req["temperature"] = cfg.Temperature
        }
        if cfg.TopK > 0 {
            req["top_k"] = cfg.TopK
        }
        if cfg.TopP > 0 {
            req["top_p"] = cfg.TopP
        }

        if err := m.send(req); err != nil {
            m.lastErr = fmt.Errorf("mlxlm: send generate: %w", err)
            return
        }

        for {
            select {
            case <-ctx.Done():
                m.lastErr = ctx.Err()
                // Tell subprocess to stop.
                _ = m.send(map[string]any{"cmd": "cancel"})
                // Drain until done or error.
                m.drain()
                return
            default:
            }

            resp, err := m.recv()
            if err != nil {
                m.lastErr = err
                return
            }

            if errMsg, ok := resp["error"].(string); ok {
                m.lastErr = errors.New(errMsg)
                return
            }

            if _, ok := resp["done"]; ok {
                return
            }

            text, _ := resp["token"].(string)
            var id int32
            if fid, ok := resp["token_id"].(float64); ok {
                id = int32(fid)
            }

            if !yield(inference.Token{ID: id, Text: text}) {
                // Consumer stopped early — send cancel and drain.
                _ = m.send(map[string]any{"cmd": "cancel"})
                m.drain()
                return
            }
        }
    }
}

// Chat streams tokens from a multi-turn conversation.
func (m *mlxlmModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] {
    cfg := inference.ApplyGenerateOpts(opts)

    return func(yield func(inference.Token) bool) {
        m.mu.Lock()
        defer m.mu.Unlock()
        m.lastErr = nil

        // Convert messages to JSON-safe format.
        msgs := make([]map[string]string, len(messages))
        for i, msg := range messages {
            msgs[i] = map[string]string{
                "role":    msg.Role,
                "content": msg.Content,
            }
        }

        req := map[string]any{
            "cmd":        "chat",
            "messages":   msgs,
            "max_tokens": cfg.MaxTokens,
        }
        if cfg.Temperature > 0 {
            req["temperature"] = cfg.Temperature
        }
        if cfg.TopK > 0 {
            req["top_k"] = cfg.TopK
        }
        if cfg.TopP > 0 {
            req["top_p"] = cfg.TopP
        }

        if err := m.send(req); err != nil {
            m.lastErr = fmt.Errorf("mlxlm: send chat: %w", err)
            return
        }

        for {
            select {
            case <-ctx.Done():
                m.lastErr = ctx.Err()
                _ = m.send(map[string]any{"cmd": "cancel"})
                m.drain()
                return
            default:
            }

            resp, err := m.recv()
            if err != nil {
                m.lastErr = err
                return
            }

            if errMsg, ok := resp["error"].(string); ok {
                m.lastErr = errors.New(errMsg)
                return
            }

            if _, ok := resp["done"]; ok {
                return
            }

            text, _ := resp["token"].(string)
            var id int32
            if fid, ok := resp["token_id"].(float64); ok {
                id = int32(fid)
            }

            if !yield(inference.Token{ID: id, Text: text}) {
                _ = m.send(map[string]any{"cmd": "cancel"})
                m.drain()
                return
            }
        }
    }
}

// Classify is not supported by the subprocess backend.
// Returns an error indicating that classification requires the native Metal backend.
func (m *mlxlmModel) Classify(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) {
    return nil, errors.New("mlxlm: Classify not supported (use native Metal backend)")
}

// BatchGenerate is not supported by the subprocess backend.
// Returns an error indicating that batch generation requires the native Metal backend.
func (m *mlxlmModel) BatchGenerate(_ context.Context, _ []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) {
    return nil, errors.New("mlxlm: BatchGenerate not supported (use native Metal backend)")
}

// ModelType returns the architecture identifier reported by the subprocess.
func (m *mlxlmModel) ModelType() string { return m.modelType }

// Info returns metadata about the loaded model.
func (m *mlxlmModel) Info() inference.ModelInfo {
    m.mu.Lock()
    defer m.mu.Unlock()

    if err := m.send(map[string]any{"cmd": "info"}); err != nil {
        return inference.ModelInfo{}
    }
    resp, err := m.recv()
    if err != nil {
        return inference.ModelInfo{}
    }
    if _, ok := resp["error"]; ok {
        return inference.ModelInfo{}
    }

    info := inference.ModelInfo{
        Architecture: m.modelType,
        VocabSize:    m.vocabSize,
    }
    if layers, ok := resp["layers"].(float64); ok {
        info.NumLayers = int(layers)
    }
    if hidden, ok := resp["hidden_size"].(float64); ok {
        info.HiddenSize = int(hidden)
    }
    return info
}

// Metrics returns empty metrics — the subprocess backend does not track timing.
func (m *mlxlmModel) Metrics() inference.GenerateMetrics {
    return inference.GenerateMetrics{}
}

// Err returns the error from the last Generate or Chat call.
func (m *mlxlmModel) Err() error { return m.lastErr }

// Close sends a quit command and waits for the subprocess to exit.
// If the subprocess does not exit within 2 seconds, it is killed.
func (m *mlxlmModel) Close() error {
    // Send quit — ignore errors (subprocess may already be dead).
    _ = m.send(map[string]any{"cmd": "quit"})
    _ = m.stdin.Close()

    // Wait with timeout.
    done := make(chan error, 1)
    go func() { done <- m.cmd.Wait() }()

    select {
    case err := <-done:
        return err
    case <-time.After(2 * time.Second):
        _ = m.cmd.Process.Kill()
        return <-done
    }
}

// drain reads and discards subprocess output until a "done" or "error" line,
// or the scanner stops. Used after cancellation to keep the protocol in sync.
func (m *mlxlmModel) drain() {
    for {
        resp, err := m.recv()
        if err != nil {
            return
        }
        if _, ok := resp["done"]; ok {
            return
        }
        if _, ok := resp["error"]; ok {
            return
        }
    }
}

// kill terminates the subprocess immediately. Used during load failures.
func (m *mlxlmModel) kill() {
    _ = m.stdin.Close()
    if m.cmd.Process != nil {
        _ = m.cmd.Process.Kill()
    }
    _ = m.cmd.Wait()
}
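To illustrate the early-stop branch above (when yield returns false, Generate sends a cancel command and drains until done so the protocol stays in sync), here is a hedged consumer-side sketch; the model path, prompt, and stop condition are placeholders, not values from this commit:

```go
package main

import (
	"context"
	"fmt"
	"strings"

	"forge.lthn.ai/core/go-inference"
	_ "forge.lthn.ai/core/go-mlx/mlxlm"
)

func main() {
	m, err := inference.LoadModel("/path/to/model", inference.WithBackend("mlx_lm"))
	if err != nil {
		panic(err)
	}
	defer m.Close()

	var out strings.Builder
	for tok := range m.Generate(context.Background(), "Hello", inference.WithMaxTokens(256)) {
		out.WriteString(tok.Text)
		// Breaking out of the loop early exercises the cancel-and-drain path in Generate.
		if strings.Contains(out.String(), ".") {
			break
		}
	}
	if err := m.Err(); err != nil {
		fmt.Println("generate error:", err)
	}
	fmt.Println(out.String())
}
```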
284  mlxlm/backend_test.go  Normal file
@@ -0,0 +1,284 @@
// SPDX-Licence-Identifier: EUPL-1.2

//go:build !nomlxlm

package mlxlm

import (
    "context"
    "path/filepath"
    "runtime"
    "strings"
    "sync"
    "testing"

    "forge.lthn.ai/core/go-inference"
)

// mockScript returns the absolute path to testdata/mock_bridge.py.
func mockScript(t *testing.T) string {
    t.Helper()
    _, file, _, ok := runtime.Caller(0)
    if !ok {
        t.Fatal("cannot determine test file path")
    }
    return filepath.Join(filepath.Dir(file), "testdata", "mock_bridge.py")
}

// loadMock spawns a model backed by the mock Python script.
func loadMock(t *testing.T, modelPath string) inference.TextModel {
    t.Helper()
    m, err := loadModel(context.Background(), modelPath, mockScript(t))
    if err != nil {
        t.Fatalf("loadModel: %v", err)
    }
    t.Cleanup(func() { m.Close() })
    return m
}

// (a) Name returns "mlx_lm".
func TestBackendName(t *testing.T) {
    b := &mlxlmBackend{}
    if got := b.Name(); got != "mlx_lm" {
        t.Errorf("Name() = %q, want %q", got, "mlx_lm")
    }
}

// (b) LoadModel spawns subprocess, sends load command, gets response.
func TestLoadModel_Good(t *testing.T) {
    m := loadMock(t, "/fake/model/path")
    if m.ModelType() != "mock_model" {
        t.Errorf("ModelType() = %q, want %q", m.ModelType(), "mock_model")
    }
}

// (c) Generate streams tokens from subprocess, all tokens received.
func TestGenerate_Good(t *testing.T) {
    m := loadMock(t, "/fake/model/path")

    ctx := context.Background()
    var tokens []inference.Token
    for tok := range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
        tokens = append(tokens, tok)
    }
    if err := m.Err(); err != nil {
        t.Fatalf("Err() = %v", err)
    }
    if len(tokens) != 5 {
        t.Fatalf("got %d tokens, want 5", len(tokens))
    }

    // Verify token content matches mock.
    expected := []string{"Hello", " ", "world", "!", "\n"}
    for i, tok := range tokens {
        if tok.Text != expected[i] {
            t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i])
        }
        wantID := int32(100 + i)
        if tok.ID != wantID {
            t.Errorf("token[%d].ID = %d, want %d", i, tok.ID, wantID)
        }
    }
}

// (d) Generate with context cancellation stops early.
func TestGenerate_Cancel(t *testing.T) {
    m := loadMock(t, "/fake/model/path")

    ctx, cancel := context.WithCancel(context.Background())
    defer cancel()

    var count int
    for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
        count++
        if count >= 2 {
            cancel()
        }
    }
    // We should have received at most a few tokens before cancellation took effect.
    if count > 5 {
        t.Errorf("expected early stop, got %d tokens", count)
    }
    if err := m.Err(); err != context.Canceled {
        t.Logf("Err() = %v (expected context.Canceled)", err)
    }
}

// (e) Chat formats messages correctly and streams tokens.
func TestChat_Good(t *testing.T) {
    m := loadMock(t, "/fake/model/path")

    ctx := context.Background()
    var tokens []inference.Token
    for tok := range m.Chat(ctx, []inference.Message{
        {Role: "user", Content: "Hi there"},
    }, inference.WithMaxTokens(5)) {
        tokens = append(tokens, tok)
    }
    if err := m.Err(); err != nil {
        t.Fatalf("Err() = %v", err)
    }
    if len(tokens) != 5 {
        t.Fatalf("got %d tokens, want 5", len(tokens))
    }

    // Mock chat returns "I heard you".
    expected := []string{"I", " ", "heard", " ", "you"}
    for i, tok := range tokens {
        if tok.Text != expected[i] {
            t.Errorf("token[%d].Text = %q, want %q", i, tok.Text, expected[i])
        }
    }
}

// (f) Close kills subprocess cleanly.
func TestClose_Good(t *testing.T) {
    m, err := loadModel(context.Background(), "/fake/model/path", mockScript(t))
    if err != nil {
        t.Fatalf("loadModel: %v", err)
    }
    // Close should not error.
    if err := m.Close(); err != nil {
        t.Errorf("Close() = %v", err)
    }
}

// (g) Err returns error on subprocess failure.
func TestGenerate_Error(t *testing.T) {
    m := loadMock(t, "/fake/model/path")

    ctx := context.Background()
    var count int
    for range m.Generate(ctx, "ERROR trigger", inference.WithMaxTokens(5)) {
        count++
    }
    if count != 0 {
        t.Errorf("expected 0 tokens on error, got %d", count)
    }
    if err := m.Err(); err == nil {
        t.Fatal("expected non-nil Err()")
    } else if !strings.Contains(err.Error(), "simulated model error") {
        t.Errorf("Err() = %q, want to contain %q", err.Error(), "simulated model error")
    }
}

// (h) LoadModel with invalid path returns error.
func TestLoadModel_Bad(t *testing.T) {
    _, err := loadModel(context.Background(), "/path/with/FAIL/in/it", mockScript(t))
    if err == nil {
        t.Fatal("expected error for FAIL path")
    }
    if !strings.Contains(err.Error(), "cannot open model") {
        t.Errorf("error = %q, want to contain %q", err.Error(), "cannot open model")
    }
}

// (i) Backend auto-registers (check inference.Get("mlx_lm")).
func TestAutoRegister(t *testing.T) {
    b, ok := inference.Get("mlx_lm")
    if !ok {
        t.Fatal("mlx_lm backend not registered")
    }
    if b.Name() != "mlx_lm" {
        t.Errorf("Name() = %q, want %q", b.Name(), "mlx_lm")
    }
}

// (j) Concurrent Generate calls are serialised (mu lock).
func TestGenerate_Concurrent(t *testing.T) {
    m := loadMock(t, "/fake/model/path")

    ctx := context.Background()
    const goroutines = 3
    var wg sync.WaitGroup
    wg.Add(goroutines)

    results := make([]int, goroutines)
    for i := range goroutines {
        go func(idx int) {
            defer wg.Done()
            var count int
            for range m.Generate(ctx, "Hello", inference.WithMaxTokens(5)) {
                count++
            }
            results[idx] = count
        }(i)
    }
    wg.Wait()

    // Each goroutine should have received all 5 tokens (serialised execution).
    for i, count := range results {
        if count != 5 {
            t.Errorf("goroutine %d got %d tokens, want 5", i, count)
        }
    }
}

// Additional: Classify returns unsupported error.
func TestClassify_Unsupported(t *testing.T) {
    m := loadMock(t, "/fake/model/path")
    _, err := m.Classify(context.Background(), []string{"test"})
    if err == nil {
        t.Fatal("expected error from Classify")
    }
    if !strings.Contains(err.Error(), "not supported") {
        t.Errorf("error = %q, want to contain %q", err.Error(), "not supported")
    }
}

// Additional: BatchGenerate returns unsupported error.
func TestBatchGenerate_Unsupported(t *testing.T) {
    m := loadMock(t, "/fake/model/path")
    _, err := m.BatchGenerate(context.Background(), []string{"test"})
    if err == nil {
        t.Fatal("expected error from BatchGenerate")
    }
    if !strings.Contains(err.Error(), "not supported") {
        t.Errorf("error = %q, want to contain %q", err.Error(), "not supported")
    }
}

// Additional: Info returns model metadata.
func TestInfo_Good(t *testing.T) {
    m := loadMock(t, "/fake/model/path")
    info := m.Info()
    if info.Architecture != "mock_model" {
        t.Errorf("Architecture = %q, want %q", info.Architecture, "mock_model")
    }
    if info.VocabSize != 32000 {
        t.Errorf("VocabSize = %d, want %d", info.VocabSize, 32000)
    }
    if info.NumLayers != 24 {
        t.Errorf("NumLayers = %d, want %d", info.NumLayers, 24)
    }
    if info.HiddenSize != 2048 {
        t.Errorf("HiddenSize = %d, want %d", info.HiddenSize, 2048)
    }
}

// Additional: Metrics returns zero values (not tracked by subprocess).
func TestMetrics_Zero(t *testing.T) {
    m := loadMock(t, "/fake/model/path")
    met := m.Metrics()
    if met.PromptTokens != 0 || met.GeneratedTokens != 0 {
        t.Errorf("expected zero metrics, got prompt=%d generated=%d",
            met.PromptTokens, met.GeneratedTokens)
    }
}

// Additional: Generate with fewer max_tokens than available tokens.
func TestGenerate_MaxTokens(t *testing.T) {
    m := loadMock(t, "/fake/model/path")

    ctx := context.Background()
    var count int
    for range m.Generate(ctx, "Hello", inference.WithMaxTokens(3)) {
        count++
    }
    if err := m.Err(); err != nil {
        t.Fatalf("Err() = %v", err)
    }
    if count != 3 {
        t.Errorf("got %d tokens, want 3", count)
    }
}
224  mlxlm/bridge.py  Normal file
@@ -0,0 +1,224 @@
#!/usr/bin/env python3
"""
bridge.py — JSON Lines bridge between Go subprocess and mlx_lm.

Reads JSON commands from stdin, writes JSON responses to stdout.
Each line is one JSON object. Flushes after every write (critical for streaming).

Commands:
    load     — Load model + tokeniser from path
    generate — Stream tokens for a prompt
    chat     — Stream tokens for a multi-turn conversation
    info     — Return model metadata
    cancel   — Interrupt current generation (no-op outside generation)
    quit     — Exit cleanly

Requires: mlx-lm (pip install mlx-lm)

SPDX-Licence-Identifier: EUPL-1.2
"""

import json
import sys

_model = None
_tokeniser = None
_model_type = None
_vocab_size = 0
_cancelled = False


def _write(obj):
    """Write a JSON line to stdout and flush."""
    sys.stdout.write(json.dumps(obj) + "\n")
    sys.stdout.flush()


def _error(msg):
    """Write an error response."""
    _write({"error": str(msg)})


def handle_load(req):
    global _model, _tokeniser, _model_type, _vocab_size

    path = req.get("path", "")
    if not path:
        _error("load: missing 'path'")
        return

    try:
        import mlx_lm
        _model, _tokeniser = mlx_lm.load(path)
    except Exception as e:
        _error(f"load: {e}")
        return

    # Detect model type from config if available.
    _model_type = getattr(_model, "model_type", "unknown")
    _vocab_size = getattr(_tokeniser, "vocab_size", 0)

    _write({
        "ok": True,
        "model_type": _model_type,
        "vocab_size": _vocab_size,
    })


def handle_generate(req):
    global _cancelled

    if _model is None or _tokeniser is None:
        _error("generate: no model loaded")
        return

    prompt = req.get("prompt", "")
    max_tokens = req.get("max_tokens", 256)
    temperature = req.get("temperature", 0.0)
    top_p = req.get("top_p", 1.0)

    _cancelled = False

    try:
        import mlx_lm

        kwargs = {
            "max_tokens": max_tokens,
        }
        if temperature > 0:
            kwargs["temp"] = temperature
        if top_p < 1.0:
            kwargs["top_p"] = top_p

        count = 0
        for response in mlx_lm.stream_generate(
            _model, _tokeniser, prompt=prompt, **kwargs
        ):
            if _cancelled:
                break
            text = response.text if hasattr(response, "text") else str(response)
            token_id = response.token if hasattr(response, "token") else 0
            _write({"token": text, "token_id": int(token_id)})
            count += 1

        _write({"done": True, "tokens_generated": count})

    except Exception as e:
        _error(f"generate: {e}")


def handle_chat(req):
    global _cancelled

    if _model is None or _tokeniser is None:
        _error("chat: no model loaded")
        return

    messages = req.get("messages", [])
    max_tokens = req.get("max_tokens", 256)
    temperature = req.get("temperature", 0.0)
    top_p = req.get("top_p", 1.0)

    _cancelled = False

    try:
        import mlx_lm

        # Apply chat template via tokeniser.
        if hasattr(_tokeniser, "apply_chat_template"):
            prompt = _tokeniser.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            # Fallback: concatenate messages.
            prompt = "\n".join(
                f"{m.get('role', 'user')}: {m.get('content', '')}"
                for m in messages
            )

        kwargs = {
            "max_tokens": max_tokens,
        }
        if temperature > 0:
            kwargs["temp"] = temperature
        if top_p < 1.0:
            kwargs["top_p"] = top_p

        count = 0
        for response in mlx_lm.stream_generate(
            _model, _tokeniser, prompt=prompt, **kwargs
        ):
            if _cancelled:
                break
            text = response.text if hasattr(response, "text") else str(response)
            token_id = response.token if hasattr(response, "token") else 0
            _write({"token": text, "token_id": int(token_id)})
            count += 1

        _write({"done": True, "tokens_generated": count})

    except Exception as e:
        _error(f"chat: {e}")


def handle_info(_req):
    if _model is None:
        _error("info: no model loaded")
        return

    num_layers = 0
    hidden_size = 0
    if hasattr(_model, "config"):
        cfg = _model.config
        num_layers = getattr(cfg, "num_hidden_layers", 0)
        hidden_size = getattr(cfg, "hidden_size", 0)

    _write({
        "model_type": _model_type or "unknown",
        "vocab_size": _vocab_size,
        "layers": num_layers,
        "hidden_size": hidden_size,
    })


def handle_cancel(_req):
    global _cancelled
    _cancelled = True


def main():
    handlers = {
        "load": handle_load,
        "generate": handle_generate,
        "chat": handle_chat,
        "info": handle_info,
        "cancel": handle_cancel,
        "quit": None,
    }

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        try:
            req = json.loads(line)
        except json.JSONDecodeError as e:
            _error(f"parse error: {e}")
            continue

        cmd = req.get("cmd", "")

        if cmd == "quit":
            break

        handler = handlers.get(cmd)
        if handler is None:
            _error(f"unknown command: {cmd}")
            continue

        handler(req)


if __name__ == "__main__":
    main()
118  mlxlm/testdata/mock_bridge.py  vendored  Normal file
@@ -0,0 +1,118 @@
#!/usr/bin/env python3
"""
mock_bridge.py — Mock bridge for testing the mlxlm Go backend.

Implements the same JSON Lines protocol as bridge.py but without mlx_lm.
Returns deterministic fake responses for testing.

SPDX-Licence-Identifier: EUPL-1.2
"""

import json
import sys
import os

_loaded = False
_model_path = ""


def _write(obj):
    sys.stdout.write(json.dumps(obj) + "\n")
    sys.stdout.flush()


def _error(msg):
    _write({"error": str(msg)})


def main():
    global _loaded, _model_path

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        try:
            req = json.loads(line)
        except json.JSONDecodeError as e:
            _error(f"parse error: {e}")
            continue

        cmd = req.get("cmd", "")

        if cmd == "quit":
            break

        elif cmd == "load":
            path = req.get("path", "")
            if not path:
                _error("load: missing 'path'")
                continue
            # Simulate failure for paths containing "FAIL".
            if "FAIL" in path:
                _error(f"load: cannot open model at {path}")
                continue
            _loaded = True
            _model_path = path
            _write({
                "ok": True,
                "model_type": "mock_model",
                "vocab_size": 32000,
            })

        elif cmd == "generate":
            if not _loaded:
                _error("generate: no model loaded")
                continue

            max_tokens = req.get("max_tokens", 5)
            # Check for error trigger.
            prompt = req.get("prompt", "")
            if "ERROR" in prompt:
                _error("generate: simulated model error")
                continue

            # Emit fixed tokens.
            tokens = ["Hello", " ", "world", "!", "\n"]
            count = min(max_tokens, len(tokens))
            for i in range(count):
                _write({"token": tokens[i], "token_id": 100 + i})
            _write({"done": True, "tokens_generated": count})

        elif cmd == "chat":
            if not _loaded:
                _error("chat: no model loaded")
                continue

            messages = req.get("messages", [])
            max_tokens = req.get("max_tokens", 5)

            # Emit tokens reflecting the last user message.
            tokens = ["I", " ", "heard", " ", "you"]
            count = min(max_tokens, len(tokens))
            for i in range(count):
                _write({"token": tokens[i], "token_id": 200 + i})
            _write({"done": True, "tokens_generated": count})

        elif cmd == "info":
            if not _loaded:
                _error("info: no model loaded")
                continue
            _write({
                "model_type": "mock_model",
                "vocab_size": 32000,
                "layers": 24,
                "hidden_size": 2048,
            })

        elif cmd == "cancel":
            # No-op in mock — real bridge sets a flag.
            pass

        else:
            _error(f"unknown command: {cmd}")


if __name__ == "__main__":
    main()