1
0
Fork 0
forked from lthn/LEM
LEM/scripts/chat-4b-base.py

54 lines
1.3 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Interactive chat with base Gemma3-4B-IT (no LEM training)."""
import sys

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

MODEL_PATH = '/Volumes/Data/lem/gemma-3-4b-it-mlx'
MAX_TOKENS = 512                    # cap on generated tokens per reply
TEMPERATURE = 0.7                   # sampling temperature passed to make_sampler
MEMORY_LIMIT_BYTES = 24 * 1024**3   # 24 GiB
CACHE_LIMIT_BYTES = 8 * 1024**3     # 8 GiB


def main() -> None:
    """Load the base model and run a chat REPL.

    The loop reads a line from stdin, sends the whole conversation through the
    tokenizer's chat template, generates a reply, and prints it. It exits on
    EOF or Ctrl-C at the prompt. Typing ``/clear`` drops the conversation
    history. Ctrl-C while a reply is being generated aborts only that reply.
    """
    # Flush every line immediately so output shows up promptly when piped.
    sys.stdout.reconfigure(line_buffering=True)
    # Use the top-level memory API, consistent with mx.clear_cache() below;
    # the mx.metal.* spellings are deprecated in recent MLX releases.
    mx.set_memory_limit(MEMORY_LIMIT_BYTES)
    mx.set_cache_limit(CACHE_LIMIT_BYTES)

    print('Loading Gemma3-4B-IT (base)...')
    model, tokenizer = load(MODEL_PATH)
    # Switch to inference mode. In MLX, Module.eval() is train(False); calling
    # it directly avoids the getattr indirection.
    model.train(False)
    print('Ready.\n')

    sampler = make_sampler(temp=TEMPERATURE)
    history = []
    while True:
        try:
            user_input = input('You: ').strip()
        except (EOFError, KeyboardInterrupt):
            print('\nBye.')
            break
        if not user_input:
            continue
        if user_input.lower() == '/clear':
            history = []
            print('History cleared.\n')
            continue

        # Build the candidate conversation without touching `history`. If
        # generation fails or is interrupted, history must not be left ending
        # in a user turn: the next prompt would then contain two consecutive
        # user messages, which some chat templates (e.g. Gemma's) may reject.
        messages = history + [{'role': 'user', 'content': user_input}]
        prompt_text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        try:
            response = generate(
                model,
                tokenizer,
                prompt=prompt_text,
                max_tokens=MAX_TOKENS,
                sampler=sampler,
            )
        except KeyboardInterrupt:
            # Abort only this reply and return to the prompt.
            print('\n[generation interrupted]\n')
            mx.clear_cache()
            continue

        history = messages + [{'role': 'assistant', 'content': response}]
        print(f'\nGemma: {response}\n')
        # Release cached buffers between turns to keep memory use down.
        mx.clear_cache()


if __name__ == '__main__':
    main()