Implements inference.Backend via a Python subprocess communicating over JSON Lines (stdin/stdout). No CGO required — pure Go + os/exec.

- bridge.py: embedded Python script wrapping mlx_lm.load() and mlx_lm.stream_generate() with load/generate/chat/info/cancel/quit commands. Flushes stdout after every JSON line for streaming.
- backend.go: Go subprocess manager. Extracts bridge.py from go:embed to temp file, spawns python3, pipes JSON requests. mlxlmModel implements the full TextModel interface with mutex-serialised Generate/Chat, context cancellation with drain, and 2-second graceful Close with kill fallback. Auto-registers as "mlx_lm" via init(). Build tag: !nomlxlm.
- backend_test.go: 15 tests using mock_bridge.py (no mlx_lm needed): name, load, generate, cancel, chat, close, error propagation, invalid path, auto-register, concurrent serialisation, classify/batch unsupported, info, metrics, max_tokens limiting. All tests pass with -race. go vet clean.

Co-Authored-By: Virgil <virgil@lethean.io>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
224 lines · 5.4 KiB · Python
#!/usr/bin/env python3
"""
bridge.py — JSON Lines bridge between Go subprocess and mlx_lm.

Reads JSON commands from stdin, writes JSON responses to stdout.
Each line is one JSON object. Flushes after every write (critical for streaming).

Commands:
    load     — Load model + tokeniser from path
    generate — Stream tokens for a prompt
    chat     — Stream tokens for a multi-turn conversation
    info     — Return model metadata
    cancel   — Interrupt current generation (no-op outside generation)
    quit     — Exit cleanly

Requires: mlx-lm (pip install mlx-lm)

SPDX-Licence-Identifier: EUPL-1.2
"""
|
|
|
|
import json
|
|
import sys
|
|
|
|
_model = None
|
|
_tokeniser = None
|
|
_model_type = None
|
|
_vocab_size = 0
|
|
_cancelled = False
|
|
|
|
|
|
def _write(obj):
|
|
"""Write a JSON line to stdout and flush."""
|
|
sys.stdout.write(json.dumps(obj) + "\n")
|
|
sys.stdout.flush()
|
|
|
|
|
|
def _error(msg):
|
|
"""Write an error response."""
|
|
_write({"error": str(msg)})
|
|
|
|
|
|
def handle_load(req):
|
|
global _model, _tokeniser, _model_type, _vocab_size
|
|
|
|
path = req.get("path", "")
|
|
if not path:
|
|
_error("load: missing 'path'")
|
|
return
|
|
|
|
try:
|
|
import mlx_lm
|
|
_model, _tokeniser = mlx_lm.load(path)
|
|
except Exception as e:
|
|
_error(f"load: {e}")
|
|
return
|
|
|
|
# Detect model type from config if available.
|
|
_model_type = getattr(_model, "model_type", "unknown")
|
|
_vocab_size = getattr(_tokeniser, "vocab_size", 0)
|
|
|
|
_write({
|
|
"ok": True,
|
|
"model_type": _model_type,
|
|
"vocab_size": _vocab_size,
|
|
})
|
|
|
|
|
|
def handle_generate(req):
|
|
global _cancelled
|
|
|
|
if _model is None or _tokeniser is None:
|
|
_error("generate: no model loaded")
|
|
return
|
|
|
|
prompt = req.get("prompt", "")
|
|
max_tokens = req.get("max_tokens", 256)
|
|
temperature = req.get("temperature", 0.0)
|
|
top_p = req.get("top_p", 1.0)
|
|
|
|
_cancelled = False
|
|
|
|
try:
|
|
import mlx_lm
|
|
|
|
kwargs = {
|
|
"max_tokens": max_tokens,
|
|
}
|
|
if temperature > 0:
|
|
kwargs["temp"] = temperature
|
|
if top_p < 1.0:
|
|
kwargs["top_p"] = top_p
|
|
|
|
count = 0
|
|
for response in mlx_lm.stream_generate(
|
|
_model, _tokeniser, prompt=prompt, **kwargs
|
|
):
|
|
if _cancelled:
|
|
break
|
|
text = response.text if hasattr(response, "text") else str(response)
|
|
token_id = response.token if hasattr(response, "token") else 0
|
|
_write({"token": text, "token_id": int(token_id)})
|
|
count += 1
|
|
|
|
_write({"done": True, "tokens_generated": count})
|
|
|
|
except Exception as e:
|
|
_error(f"generate: {e}")
|
|
|
|
|
|
def handle_chat(req):
|
|
global _cancelled
|
|
|
|
if _model is None or _tokeniser is None:
|
|
_error("chat: no model loaded")
|
|
return
|
|
|
|
messages = req.get("messages", [])
|
|
max_tokens = req.get("max_tokens", 256)
|
|
temperature = req.get("temperature", 0.0)
|
|
top_p = req.get("top_p", 1.0)
|
|
|
|
_cancelled = False
|
|
|
|
try:
|
|
import mlx_lm
|
|
|
|
# Apply chat template via tokeniser.
|
|
if hasattr(_tokeniser, "apply_chat_template"):
|
|
prompt = _tokeniser.apply_chat_template(
|
|
messages, tokenize=False, add_generation_prompt=True
|
|
)
|
|
else:
|
|
# Fallback: concatenate messages.
|
|
prompt = "\n".join(
|
|
f"{m.get('role', 'user')}: {m.get('content', '')}"
|
|
for m in messages
|
|
)
|
|
|
|
kwargs = {
|
|
"max_tokens": max_tokens,
|
|
}
|
|
if temperature > 0:
|
|
kwargs["temp"] = temperature
|
|
if top_p < 1.0:
|
|
kwargs["top_p"] = top_p
|
|
|
|
count = 0
|
|
for response in mlx_lm.stream_generate(
|
|
_model, _tokeniser, prompt=prompt, **kwargs
|
|
):
|
|
if _cancelled:
|
|
break
|
|
text = response.text if hasattr(response, "text") else str(response)
|
|
token_id = response.token if hasattr(response, "token") else 0
|
|
_write({"token": text, "token_id": int(token_id)})
|
|
count += 1
|
|
|
|
_write({"done": True, "tokens_generated": count})
|
|
|
|
except Exception as e:
|
|
_error(f"chat: {e}")
|
|
|
|
|
|
def handle_info(_req):
|
|
if _model is None:
|
|
_error("info: no model loaded")
|
|
return
|
|
|
|
num_layers = 0
|
|
hidden_size = 0
|
|
if hasattr(_model, "config"):
|
|
cfg = _model.config
|
|
num_layers = getattr(cfg, "num_hidden_layers", 0)
|
|
hidden_size = getattr(cfg, "hidden_size", 0)
|
|
|
|
_write({
|
|
"model_type": _model_type or "unknown",
|
|
"vocab_size": _vocab_size,
|
|
"layers": num_layers,
|
|
"hidden_size": hidden_size,
|
|
})
|
|
|
|
|
|
def handle_cancel(_req):
|
|
global _cancelled
|
|
_cancelled = True
|
|
|
|
|
|
def main():
|
|
handlers = {
|
|
"load": handle_load,
|
|
"generate": handle_generate,
|
|
"chat": handle_chat,
|
|
"info": handle_info,
|
|
"cancel": handle_cancel,
|
|
"quit": None,
|
|
}
|
|
|
|
for line in sys.stdin:
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
|
|
try:
|
|
req = json.loads(line)
|
|
except json.JSONDecodeError as e:
|
|
_error(f"parse error: {e}")
|
|
continue
|
|
|
|
cmd = req.get("cmd", "")
|
|
|
|
if cmd == "quit":
|
|
break
|
|
|
|
handler = handlers.get(cmd)
|
|
if handler is None:
|
|
_error(f"unknown command: {cmd}")
|
|
continue
|
|
|
|
handler(req)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|