feat: update mlxlm bridge for mlx-lm 0.30.7 API
- Rewrite bridge.py to use make_sampler() and make_logits_processors() instead of the deprecated direct kwargs (temp, top_p, top_k).
- Add repeat_penalty forwarding in backend.go for Generate and Chat.
- Extract a _build_gen_kwargs() helper shared by the generate and chat handlers.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
71fe4bb5ac
commit
802b7660f2
2 changed files with 30 additions and 22 deletions
|
|
@ -233,6 +233,9 @@ func (m *mlxlmModel) Generate(ctx context.Context, prompt string, opts ...infere
|
|||
if cfg.TopP > 0 {
|
||||
req["top_p"] = cfg.TopP
|
||||
}
|
||||
if cfg.RepeatPenalty > 1.0 {
|
||||
req["repeat_penalty"] = cfg.RepeatPenalty
|
||||
}
|
||||
|
||||
if err := m.send(req); err != nil {
|
||||
m.lastErr = fmt.Errorf("mlxlm: send generate: %w", err)
|
||||
|
|
@ -314,6 +317,9 @@ func (m *mlxlmModel) Chat(ctx context.Context, messages []inference.Message, opt
|
|||
if cfg.TopP > 0 {
|
||||
req["top_p"] = cfg.TopP
|
||||
}
|
||||
if cfg.RepeatPenalty > 1.0 {
|
||||
req["repeat_penalty"] = cfg.RepeatPenalty
|
||||
}
|
||||
|
||||
if err := m.send(req); err != nil {
|
||||
m.lastErr = fmt.Errorf("mlxlm: send chat: %w", err)
|
||||
|
|
|
|||
|
|
@ -39,6 +39,28 @@ def _error(msg):
|
|||
_write({"error": str(msg)})
|
||||
|
||||
|
||||
def _build_gen_kwargs(req):
    """Build the sampler and logits_processors kwargs for stream_generate.

    Translates the request's sampling fields (temperature, top_p, top_k,
    repeat_penalty) into the mlx-lm 0.30.x API, where sampling parameters
    are passed as a sampler callable rather than direct kwargs.
    """
    # Imported lazily so the bridge can start even before mlx_lm is needed.
    from mlx_lm.sample_utils import make_sampler, make_logits_processors

    sampler = make_sampler(
        temp=req.get("temperature", 0.0),
        top_p=req.get("top_p", 0.0),
        top_k=req.get("top_k", 0),
    )
    gen_kwargs = {
        "max_tokens": req.get("max_tokens", 256),
        "sampler": sampler,
    }

    # A repetition penalty of 1.0 (or below) is a no-op, so only attach the
    # processor when it would actually change the logits.
    penalty = req.get("repeat_penalty", 0.0)
    if penalty > 1.0:
        gen_kwargs["logits_processors"] = make_logits_processors(
            repetition_penalty=penalty,
        )

    return gen_kwargs
|
||||
|
||||
|
||||
def handle_load(req):
|
||||
global _model, _tokeniser, _model_type, _vocab_size
|
||||
|
||||
|
|
@ -73,22 +95,12 @@ def handle_generate(req):
|
|||
return
|
||||
|
||||
prompt = req.get("prompt", "")
|
||||
max_tokens = req.get("max_tokens", 256)
|
||||
temperature = req.get("temperature", 0.0)
|
||||
top_p = req.get("top_p", 1.0)
|
||||
|
||||
_cancelled = False
|
||||
|
||||
try:
|
||||
import mlx_lm
|
||||
|
||||
kwargs = {
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if temperature > 0:
|
||||
kwargs["temp"] = temperature
|
||||
if top_p < 1.0:
|
||||
kwargs["top_p"] = top_p
|
||||
kwargs = _build_gen_kwargs(req)
|
||||
|
||||
count = 0
|
||||
for response in mlx_lm.stream_generate(
|
||||
|
|
@ -115,10 +127,6 @@ def handle_chat(req):
|
|||
return
|
||||
|
||||
messages = req.get("messages", [])
|
||||
max_tokens = req.get("max_tokens", 256)
|
||||
temperature = req.get("temperature", 0.0)
|
||||
top_p = req.get("top_p", 1.0)
|
||||
|
||||
_cancelled = False
|
||||
|
||||
try:
|
||||
|
|
@ -136,13 +144,7 @@ def handle_chat(req):
|
|||
for m in messages
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if temperature > 0:
|
||||
kwargs["temp"] = temperature
|
||||
if top_p < 1.0:
|
||||
kwargs["top_p"] = top_p
|
||||
kwargs = _build_gen_kwargs(req)
|
||||
|
||||
count = 0
|
||||
for response in mlx_lm.stream_generate(
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue