feat: update mlxlm bridge for mlx-lm 0.30.7 API

- Rewrite bridge.py to use make_sampler() and make_logits_processors()
  instead of the deprecated direct kwargs (temp, top_p, top_k); see the
  sketch after this list
- Add repeat_penalty forwarding in backend.go for Generate and Chat
- Extract _build_gen_kwargs() helper shared by generate and chat handlers
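
For context on the first bullet: older mlx-lm releases accepted temp/top_p/top_k as direct generation kwargs, while the current API expects a sampler callable plus an optional list of logits processors. A minimal sketch of the new-style call is below; the model path, sampling values, prompt, and the response.text field are illustrative assumptions, not taken from this repository.

    # Sketch only: placeholders throughout; mirrors what _build_gen_kwargs()
    # in bridge.py now constructs for the bridge's streaming calls.
    import mlx_lm
    from mlx_lm.sample_utils import make_logits_processors, make_sampler

    model, tokenizer = mlx_lm.load("/path/to/model")  # placeholder path

    # Sampling options travel as a sampler callable and logits processors,
    # not as temp/top_p/top_k kwargs on stream_generate itself.
    sampler = make_sampler(temp=0.7, top_p=0.9, top_k=40)
    logits_processors = make_logits_processors(repetition_penalty=1.1)

    for response in mlx_lm.stream_generate(
        model,
        tokenizer,
        "Hello",  # placeholder prompt
        max_tokens=256,
        sampler=sampler,
        logits_processors=logits_processors,
    ):
        # .text assumes a recent mlx-lm GenerationResponse
        print(response.text, end="", flush=True)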

Co-Authored-By: Virgil <virgil@lethean.io>
Snider 2026-02-23 18:37:03 +00:00
parent 71fe4bb5ac
commit 802b7660f2
2 changed files with 30 additions and 22 deletions

backend.go

@@ -233,6 +233,9 @@ func (m *mlxlmModel) Generate(ctx context.Context, prompt string, opts ...infere
 	if cfg.TopP > 0 {
 		req["top_p"] = cfg.TopP
 	}
+	if cfg.RepeatPenalty > 1.0 {
+		req["repeat_penalty"] = cfg.RepeatPenalty
+	}
 	if err := m.send(req); err != nil {
 		m.lastErr = fmt.Errorf("mlxlm: send generate: %w", err)
@@ -314,6 +317,9 @@ func (m *mlxlmModel) Chat(ctx context.Context, messages []inference.Message, opt
 	if cfg.TopP > 0 {
 		req["top_p"] = cfg.TopP
 	}
+	if cfg.RepeatPenalty > 1.0 {
+		req["repeat_penalty"] = cfg.RepeatPenalty
+	}
 	if err := m.send(req); err != nil {
 		m.lastErr = fmt.Errorf("mlxlm: send chat: %w", err)

bridge.py

@@ -39,6 +39,28 @@ def _error(msg):
     _write({"error": str(msg)})
 
+def _build_gen_kwargs(req):
+    """Build sampler and logits_processors kwargs for stream_generate."""
+    from mlx_lm.sample_utils import make_sampler, make_logits_processors
+    temperature = req.get("temperature", 0.0)
+    top_p = req.get("top_p", 0.0)
+    top_k = req.get("top_k", 0)
+    repeat_penalty = req.get("repeat_penalty", 0.0)
+    kwargs = {
+        "max_tokens": req.get("max_tokens", 256),
+        "sampler": make_sampler(temp=temperature, top_p=top_p, top_k=top_k),
+    }
+    if repeat_penalty > 1.0:
+        kwargs["logits_processors"] = make_logits_processors(
+            repetition_penalty=repeat_penalty,
+        )
+    return kwargs
+
 def handle_load(req):
     global _model, _tokeniser, _model_type, _vocab_size
@@ -73,22 +95,12 @@ def handle_generate(req):
         return
     prompt = req.get("prompt", "")
-    max_tokens = req.get("max_tokens", 256)
-    temperature = req.get("temperature", 0.0)
-    top_p = req.get("top_p", 1.0)
     _cancelled = False
     try:
         import mlx_lm
-        kwargs = {
-            "max_tokens": max_tokens,
-        }
-        if temperature > 0:
-            kwargs["temp"] = temperature
-        if top_p < 1.0:
-            kwargs["top_p"] = top_p
+        kwargs = _build_gen_kwargs(req)
         count = 0
         for response in mlx_lm.stream_generate(
@@ -115,10 +127,6 @@ def handle_chat(req):
         return
     messages = req.get("messages", [])
-    max_tokens = req.get("max_tokens", 256)
-    temperature = req.get("temperature", 0.0)
-    top_p = req.get("top_p", 1.0)
     _cancelled = False
     try:
@@ -136,13 +144,7 @@ def handle_chat(req):
             for m in messages
         )
-        kwargs = {
-            "max_tokens": max_tokens,
-        }
-        if temperature > 0:
-            kwargs["temp"] = temperature
-        if top_p < 1.0:
-            kwargs["top_p"] = top_p
+        kwargs = _build_gen_kwargs(req)
         count = 0
         for response in mlx_lm.stream_generate(
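
A rough usage note on the shared helper: assuming the _build_gen_kwargs() definition from the hunk above is in scope (i.e. this runs inside bridge.py with mlx-lm installed), the only behavioural switch is that logits_processors is attached solely when repeat_penalty is above 1.0, matching the > 1.0 guard backend.go applies before forwarding the field at all. The request values below are illustrative.

    # Illustrative request, shaped like what backend.go sends over the bridge.
    req = {
        "prompt": "Hello",
        "max_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 40,
        "repeat_penalty": 1.1,
    }
    kwargs = _build_gen_kwargs(req)
    assert set(kwargs) == {"max_tokens", "sampler", "logits_processors"}

    # Without repeat_penalty (or at exactly 1.0, a no-op penalty) the helper
    # skips the processors entirely.
    kwargs = _build_gen_kwargs({"max_tokens": 128, "temperature": 0.7})
    assert "logits_processors" not in kwargs

    # Both handlers splat the resulting dict into mlx_lm.stream_generate(...),
    # presumably alongside the _model and _tokeniser globals handle_load() sets.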