#!/usr/bin/env python3
"""Interactive chat with LEM-Gemma3-4B-P1 + P2 adapter loaded."""

import sys

# Line-buffer stdout so prompts and replies flush immediately even when
# the script's output is piped or captured.
sys.stdout.reconfigure(line_buffering=True)

from pathlib import Path

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
# Cap Metal GPU memory (24 GiB working set, 8 GiB buffer cache) so the
# model plus KV cache stay within the device budget.
GIB = 1024 ** 3
mx.metal.set_memory_limit(24 * GIB)
mx.metal.set_cache_limit(8 * GIB)

# On-disk locations of the P1 base model and the P2 LoRA adapter run.
MODEL_PATH = '/Volumes/Data/lem/models/LEM-Gemma3-4B-P1'
ADAPTER_PATH = '/Volumes/Data/lem/adapters/gemma3-4b-p2'
# CLI: which P2 checkpoint to load (default 300, the best P2 checkpoint)
# and whether to wrap each prompt in the LEK kernel/signature sandwich.
import argparse

_cli = argparse.ArgumentParser()
_cli.add_argument('--iter', type=int, default=300, help='Checkpoint iteration to load')
_cli.add_argument('--sandwich', action='store_true', help='Wrap prompts in LEK sandwich')
args = _cli.parse_args()

CKPT_ITER = args.iter
# --- Load the P1 base model, then graft LoRA layers so the P2 adapter
# --- weights (loaded below) have matching parameter names.
print('Loading P1 base model...')
model, tokenizer = load(MODEL_PATH)
print('P1 loaded.')

from mlx_lm.tuner.utils import linear_to_lora_layers

# Apply LoRA to the top 16 layers.
# NOTE(review): rank/dropout/scale must match the values used when the P2
# adapter was trained, or the checkpoint weights will not line up — confirm
# against the P2 training config.
linear_to_lora_layers(model, num_layers=16, config={'rank': 16, 'dropout': 0.05, 'scale': 32.0})
# Load the P2 adapter checkpoint on top of the LoRA-ified P1 model.
# strict=False: the checkpoint contains only the LoRA parameters, not the
# full model, so missing keys are expected and must not raise.
ckpt = f'{ADAPTER_PATH}/{CKPT_ITER:07d}_adapters.safetensors'
model.load_weights(ckpt, strict=False)
print(f'P2 adapter loaded (iter {CKPT_ITER}).')

# Switch to inference mode (disables LoRA dropout). The original fetched
# `eval` via getattr and called it through a temp name — needless indirection.
model.eval()
# Optionally load the LEK "sandwich" ingredients: a kernel preamble and a
# signature suffix that wrap every user prompt when --sandwich is passed.
kernel_text = None
sig_text = None
if args.sandwich:
    LEM_ROOT = Path('/Users/snider/Code/LEM')
    # Explicit UTF-8: read_text() otherwise uses the locale's preferred
    # encoding, which can vary between shells/launch environments.
    kernel_text = (LEM_ROOT / 'data/kernels/lek-1-kernel.json').read_text(encoding='utf-8').strip()
    sig_text = (LEM_ROOT / 'data/kernels/lek-1-sig.txt').read_text(encoding='utf-8').strip()
    print('LEK sandwich mode enabled.')

# Moderate temperature for conversational variety without rambling.
sampler = make_sampler(temp=0.7)
# ---------------------------------------------------------------------------
# Interactive REPL: read a prompt, optionally wrap it in the LEK sandwich,
# generate over the full running conversation history, print the reply.
# ---------------------------------------------------------------------------
print(f'\nReady. Type your message (Ctrl+D to quit).\n')

history = []
while True:
    try:
        user_msg = input('You: ').strip()
    except (EOFError, KeyboardInterrupt):
        print('\nBye.')
        break

    # Ignore blank lines rather than sending empty turns to the model.
    if not user_msg:
        continue

    # Wrap the prompt only when sandwich mode is on and both pieces loaded.
    wrapped = args.sandwich and kernel_text and sig_text
    content = kernel_text + '\n\n' + user_msg + '\n\n' + sig_text if wrapped else user_msg

    history.append({'role': 'user', 'content': content})

    # Render the whole history through the model's chat template, leaving
    # the assistant turn open for generation.
    prompt_text = tokenizer.apply_chat_template(
        history, tokenize=False, add_generation_prompt=True,
    )

    response = generate(model, tokenizer, prompt=prompt_text, max_tokens=512, sampler=sampler)
    history.append({'role': 'assistant', 'content': response})

    print(f'\nLEM: {response}\n')
    # Release cached Metal buffers between turns to keep memory flat.
    mx.clear_cache()