#!/usr/bin/env python3
"""Interactive chat with the LEM-Gemma3-4B-P1 base model plus the P2 LoRA adapter.

Pass --iter N to choose which adapter checkpoint to load and --sandwich to wrap
each prompt in the LEK kernel/signature sandwich.
"""
import argparse
import sys
from pathlib import Path

import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tuner.utils import linear_to_lora_layers

# Flush prints promptly so progress messages appear as they happen.
sys.stdout.reconfigure(line_buffering=True)

# Cap Metal memory and cache usage (bytes).
mx.metal.set_memory_limit(24 * 1024**3)
mx.metal.set_cache_limit(8 * 1024**3)

MODEL_PATH = '/Volumes/Data/lem/models/LEM-Gemma3-4B-P1'
ADAPTER_PATH = '/Volumes/Data/lem/adapters/gemma3-4b-p2'

# Which checkpoint to load; 300 is the best P2 checkpoint.
parser = argparse.ArgumentParser()
parser.add_argument('--iter', type=int, default=300, help='Checkpoint iteration to load')
parser.add_argument('--sandwich', action='store_true', help='Wrap prompts in LEK sandwich')
args = parser.parse_args()
CKPT_ITER = args.iter

print('Loading P1 base model...')
model, tokenizer = load(MODEL_PATH)
print('P1 loaded.')

# Recreate the LoRA layers used during P2 training, then load the adapter weights
# from the chosen checkpoint on top of the P1 base.
linear_to_lora_layers(model, num_layers=16, config={'rank': 16, 'dropout': 0.05, 'scale': 32.0})
ckpt = f'{ADAPTER_PATH}/{CKPT_ITER:07d}_adapters.safetensors'
model.load_weights(ckpt, strict=False)
print(f'P2 adapter loaded (iter {CKPT_ITER}).')

# Switch to inference mode (disables LoRA dropout).
model.eval()

# Optionally load the LEK sandwich ingredients.
kernel_text = None
sig_text = None
if args.sandwich:
    LEM_ROOT = Path('/Users/snider/Code/LEM')
    kernel_text = (LEM_ROOT / 'data/kernels/lek-1-kernel.json').read_text().strip()
    sig_text = (LEM_ROOT / 'data/kernels/lek-1-sig.txt').read_text().strip()
    print('LEK sandwich mode enabled.')

sampler = make_sampler(temp=0.7)

print('\nReady. Type your message (Ctrl+D to quit).\n')

history = []
while True:
    try:
        user_input = input('You: ').strip()
    except (EOFError, KeyboardInterrupt):
        print('\nBye.')
        break
    if not user_input:
        continue

    # In sandwich mode, wrap the user turn between the LEK kernel and signature.
    if args.sandwich and kernel_text and sig_text:
        content = kernel_text + '\n\n' + user_input + '\n\n' + sig_text
    else:
        content = user_input

    history.append({'role': 'user', 'content': content})

    # Render the full conversation through the model's chat template.
    prompt_text = tokenizer.apply_chat_template(
        history,
        tokenize=False,
        add_generation_prompt=True,
    )

    response = generate(model, tokenizer, prompt=prompt_text, max_tokens=512, sampler=sampler)

    history.append({'role': 'assistant', 'content': response})
    print(f'\nLEM: {response}\n')

    # Free cached Metal buffers between turns.
    mx.clear_cache()