54 lines
1.3 KiB
Python
54 lines
1.3 KiB
Python
|
|
#!/usr/bin/env python3
"""Interactive chat with base Gemma3-4B-IT (no LEM training).

Flat REPL script: loads the MLX-converted Gemma3-4B-IT model from a local
path, then loops reading user turns from stdin, building a chat-template
prompt from the running history, and printing the sampled completion.

Commands:
    /clear          reset the conversation history
    Ctrl-D / Ctrl-C exit

No arguments; no return value — runs until EOF/interrupt.
"""

import sys

# Line-buffer stdout so prompts and responses appear immediately even when
# the script's output is piped or captured.
sys.stdout.reconfigure(line_buffering=True)

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

# Cap Metal GPU memory at 24 GiB and the allocator cache at 8 GiB so the
# model leaves headroom for the rest of the system.
# NOTE(review): newer MLX releases expose these as mx.set_memory_limit /
# mx.set_cache_limit and deprecate the mx.metal namespace — confirm against
# the pinned mlx version.
mx.metal.set_memory_limit(24 * 1024**3)
mx.metal.set_cache_limit(8 * 1024**3)

MODEL_PATH = '/Volumes/Data/lem/gemma-3-4b-it-mlx'

# Fix: was an f-string with no placeholders (ruff F541).
print('Loading Gemma3-4B-IT (base)...')
model, tokenizer = load(MODEL_PATH)
# Fix: was `_set_infer = getattr(model, 'eval'); _set_infer()` — needless
# indirection for a statically-known attribute; call it directly.
model.eval()
print('Ready.\n')

sampler = make_sampler(temp=0.7)
history = []  # alternating {'role': 'user'|'assistant', 'content': str} dicts

while True:
    try:
        user_input = input('You: ').strip()
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C: exit cleanly rather than traceback.
        print('\nBye.')
        break

    if not user_input:
        continue

    if user_input.lower() == '/clear':
        history = []
        print('History cleared.\n')
        continue

    history.append({'role': 'user', 'content': user_input})

    # Render the full history through the model's chat template, appending
    # the generation prompt for the assistant turn.
    prompt_text = tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True,
    )

    response = generate(model, tokenizer, prompt=prompt_text, max_tokens=512, sampler=sampler)

    history.append({'role': 'assistant', 'content': response})

    print(f'\nGemma: {response}\n')
    # Release the MLX allocator cache each turn to keep memory from creeping
    # up over a long session.
    mx.clear_cache()