From 802b7660f25810f0dad71695bf0ab1e91240f649 Mon Sep 17 00:00:00 2001
From: Snider
Date: Mon, 23 Feb 2026 18:37:03 +0000
Subject: [PATCH] feat: update mlxlm bridge for mlx-lm 0.30.7 API

- Rewrite bridge.py to use make_sampler() and make_logits_processors()
  instead of deprecated direct kwargs (temp, top_p, top_k)
- Add repeat_penalty forwarding in backend.go for Generate and Chat
- Extract _build_gen_kwargs() helper shared by generate and chat handlers

Co-Authored-By: Virgil
---
 mlxlm/backend.go |  6 ++++++
 mlxlm/bridge.py  | 46 ++++++++++++++++++++++++----------------------
 2 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/mlxlm/backend.go b/mlxlm/backend.go
index a8de7ed..45d21b7 100644
--- a/mlxlm/backend.go
+++ b/mlxlm/backend.go
@@ -233,6 +233,9 @@ func (m *mlxlmModel) Generate(ctx context.Context, prompt string, opts ...infere
 	if cfg.TopP > 0 {
 		req["top_p"] = cfg.TopP
 	}
+	if cfg.RepeatPenalty > 1.0 {
+		req["repeat_penalty"] = cfg.RepeatPenalty
+	}
 
 	if err := m.send(req); err != nil {
 		m.lastErr = fmt.Errorf("mlxlm: send generate: %w", err)
@@ -314,6 +317,9 @@ func (m *mlxlmModel) Chat(ctx context.Context, messages []inference.Message, opt
 	if cfg.TopP > 0 {
 		req["top_p"] = cfg.TopP
 	}
+	if cfg.RepeatPenalty > 1.0 {
+		req["repeat_penalty"] = cfg.RepeatPenalty
+	}
 
 	if err := m.send(req); err != nil {
 		m.lastErr = fmt.Errorf("mlxlm: send chat: %w", err)
diff --git a/mlxlm/bridge.py b/mlxlm/bridge.py
index cbb601d..5a93224 100644
--- a/mlxlm/bridge.py
+++ b/mlxlm/bridge.py
@@ -39,6 +39,28 @@ def _error(msg):
     _write({"error": str(msg)})
 
 
+def _build_gen_kwargs(req):
+    """Build sampler and logits_processors kwargs for stream_generate."""
+    from mlx_lm.sample_utils import make_sampler, make_logits_processors
+
+    temperature = req.get("temperature", 0.0)
+    top_p = req.get("top_p", 0.0)
+    top_k = req.get("top_k", 0)
+    repeat_penalty = req.get("repeat_penalty", 0.0)
+
+    kwargs = {
+        "max_tokens": req.get("max_tokens", 256),
+        "sampler": make_sampler(temp=temperature, top_p=top_p, top_k=top_k),
+    }
+
+    if repeat_penalty > 1.0:
+        kwargs["logits_processors"] = make_logits_processors(
+            repetition_penalty=repeat_penalty,
+        )
+
+    return kwargs
+
+
 def handle_load(req):
     global _model, _tokeniser, _model_type, _vocab_size
 
@@ -73,22 +95,12 @@ def handle_generate(req):
         return
 
     prompt = req.get("prompt", "")
-    max_tokens = req.get("max_tokens", 256)
-    temperature = req.get("temperature", 0.0)
-    top_p = req.get("top_p", 1.0)
-
     _cancelled = False
 
     try:
         import mlx_lm
 
-        kwargs = {
-            "max_tokens": max_tokens,
-        }
-        if temperature > 0:
-            kwargs["temp"] = temperature
-        if top_p < 1.0:
-            kwargs["top_p"] = top_p
+        kwargs = _build_gen_kwargs(req)
 
         count = 0
         for response in mlx_lm.stream_generate(
@@ -115,10 +127,6 @@ def handle_chat(req):
         return
 
     messages = req.get("messages", [])
-    max_tokens = req.get("max_tokens", 256)
-    temperature = req.get("temperature", 0.0)
-    top_p = req.get("top_p", 1.0)
-
     _cancelled = False
 
     try:
@@ -136,13 +144,7 @@
             for m in messages
         )
 
-        kwargs = {
-            "max_tokens": max_tokens,
-        }
-        if temperature > 0:
-            kwargs["temp"] = temperature
-        if top_p < 1.0:
-            kwargs["top_p"] = top_p
+        kwargs = _build_gen_kwargs(req)
 
         count = 0
         for response in mlx_lm.stream_generate(
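
Reviewer note: below is a minimal standalone sketch of the mlx-lm 0.30.x
sampling surface this patch moves to. The model name and sampling values
are illustrative only and are not part of the patch.

# Sketch, not part of the patch: drive stream_generate the way the
# bridge now does. stream_generate no longer takes temp/top_p/top_k
# directly; they are folded into a sampler callable, and the repetition
# penalty into a logits processor list.
import mlx_lm
from mlx_lm.sample_utils import make_logits_processors, make_sampler

# Example model; any mlx-community checkpoint works the same way.
model, tokenizer = mlx_lm.load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

sampler = make_sampler(temp=0.7, top_p=0.9, top_k=40)
processors = make_logits_processors(repetition_penalty=1.1)

for response in mlx_lm.stream_generate(
    model,
    tokenizer,
    prompt="Say hello.",
    max_tokens=64,
    sampler=sampler,
    logits_processors=processors,
):
    print(response.text, end="", flush=True)

Design note: make_sampler(temp=0.0, ...) returns a greedy argmax sampler,
so the bridge's 0.0 defaults keep the previous deterministic behaviour
whenever the Go side sends no sampling options.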
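
And a quick hypothetical round-trip of the new helper, using only request
fields backend.go actually sends (repeat_penalty appears only when
cfg.RepeatPenalty > 1.0). The asserts are a smoke check for review, not a
test that ships with this patch.

# Hypothetical smoke check for _build_gen_kwargs, not part of the patch.
req = {
    "prompt": "Hello",
    "max_tokens": 128,
    "temperature": 0.7,
    "top_p": 0.9,
    "repeat_penalty": 1.15,
}
kwargs = _build_gen_kwargs(req)

assert kwargs["max_tokens"] == 128
assert callable(kwargs["sampler"])        # result of make_sampler
assert "logits_processors" in kwargs      # present: repeat_penalty > 1.0

# Without the field the key is absent, so stream_generate applies no
# repetition penalty at all.
del req["repeat_penalty"]
assert "logits_processors" not in _build_gen_kwargs(req)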