#!/usr/bin/env python3
"""P0 LoRA training for Gemma3-4B — LEK sandwich built in code."""
import sys

sys.stdout.reconfigure(line_buffering=True)

import json
from functools import partial
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map

from mlx_lm import load
from mlx_lm.tuner.datasets import ChatDataset
from mlx_lm.tuner.trainer import (
    CacheDataset,
    average_gradients,
    default_loss,
    grad_checkpoint,
    iterate_batches,
)
from mlx_lm.tuner.utils import linear_to_lora_layers

# ── Metal memory limits ──────────────────────────────────────────────
mx.metal.set_memory_limit(24 * 1024**3)
mx.metal.set_cache_limit(8 * 1024**3)

# ── Paths ────────────────────────────────────────────────────────────
LEM_ROOT = Path('/Users/snider/Code/LEM')
MODEL_PATH = '/Volumes/Data/lem/gemma-3-4b-it-mlx'
ADAPTER_PATH = Path('/Volumes/Data/lem/adapters/gemma3-4b-p0')

# ── Build sandwich data in memory ────────────────────────────────────
print('Building P0 sandwich data...')

# Read kernel JSON as raw string (the model sees the full JSON).
kernel_text = (LEM_ROOT / 'data/kernels/lek-1-kernel.json').read_text().strip()

# Read sig quote.
sig_text = (LEM_ROOT / 'data/kernels/lek-1-sig.txt').read_text().strip()

# Read 404 probes.
with open(LEM_ROOT / 'training/lem/ethics/core.json') as f:
    probes = json.load(f)

# Read existing 1B responses (bare format — prompt matched by index).
responses = []
with open(LEM_ROOT / 'training/lem/model/gemma3/4b/lesson-lem1b.jsonl') as f:
    for line in f:
        line = line.strip()
        if line:
            responses.append(json.loads(line))

print(f' Probes: {len(probes)} | Responses: {len(responses)}')

# Build sandwich messages: kernel + probe + sig → user, response → assistant.
train_data = []
skipped = 0
for i, probe in enumerate(probes):
    if i >= len(responses):
        skipped += 1
        continue
    resp = responses[i]
    assistant_content = resp['messages'][1]['content']
    # Sandwich: kernel JSON + probe + sig.
    sandwich = kernel_text + '\n\n' + probe['prompt'] + '\n\n' + sig_text
    train_data.append({
        'messages': [
            {'role': 'user', 'content': sandwich},
            {'role': 'assistant', 'content': assistant_content},
        ]
    })

print(f' Training examples: {len(train_data)} (skipped {skipped})')

# 90/10 train/valid split.
split = int(len(train_data) * 0.9)
train_messages = train_data[:split]
valid_messages = train_data[split:]
print(f' Train: {len(train_messages)} | Valid: {len(valid_messages)}')

# ── Load model ───────────────────────────────────────────────────────
print(f'\nModel: {MODEL_PATH}')
model, tokenizer = load(MODEL_PATH)
print('Model loaded.')

# ── Apply LoRA ───────────────────────────────────────────────────────
linear_to_lora_layers(model, num_layers=16, config={'rank': 16, 'dropout': 0.05, 'scale': 32.0})
print('LoRA applied (16 layers, rank 16).')

# ── Create datasets directly in memory ───────────────────────────────
train_set = CacheDataset(ChatDataset(train_messages, tokenizer, mask_prompt=True))
valid_set = CacheDataset(ChatDataset(valid_messages, tokenizer, mask_prompt=True))
print(f'Datasets created: train={len(train_set)}, valid={len(valid_set)}')

# ── Training config ──────────────────────────────────────────────────
ITERS = 400
BATCH = 1
SEQ_LEN = 3072
ADAPTER_PATH.mkdir(parents=True, exist_ok=True)
ADAPTER_FILE = str(ADAPTER_PATH / 'adapters.safetensors')

lr_schedule = optim.cosine_decay(2e-5, ITERS, 1e-6)
optimizer = optim.Adam(learning_rate=lr_schedule)
print(f'\nP0 Training: {ITERS} iters, batch {BATCH}, LR 2e-5 cosine, rank 16, seq {SEQ_LEN}')
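# Note on step() below: it carries gradient-accumulation plumbing
# (prev_grad summed into grad, do_update gating the optimizer step) so
# several micro-batches could be accumulated before an update. This run
# always calls step(batch, None, True), so every iteration applies its
# gradient immediately.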
# Gradient checkpointing for memory: grad_checkpoint patches the layer
# class, so checkpointing one block covers every transformer layer.
grad_checkpoint(model.layers[0])

loss_value_and_grad = nn.value_and_grad(model, default_loss)
state = [model.state, optimizer.state, mx.random.state]
evaluate = mx.eval  # MLX array evaluation function


@partial(mx.compile, inputs=state, outputs=state)
def step(batch, prev_grad, do_update):
    (lvalue, toks), grad = loss_value_and_grad(model, *batch)
    if prev_grad is not None:
        grad = tree_map(lambda x, y: x + y, grad, prev_grad)
    if do_update:
        grad = average_gradients(grad)
        optimizer.update(model, grad)
        grad = None
    return lvalue, toks, grad


# ── Train ────────────────────────────────────────────────────────────
model.train()
losses = 0
trained_tokens = 0
print('Starting P0 training...\n')

# train=True makes mlx_lm's iterate_batches shuffle and loop indefinitely
# (its signature has no `loop` kwarg).
for it, batch in zip(
    range(1, ITERS + 1),
    iterate_batches(dataset=train_set, batch_size=BATCH, max_seq_length=SEQ_LEN, train=True),
):
    lvalue, toks, _ = step(batch, None, True)
    evaluate(state)
    losses += lvalue.item()
    trained_tokens += toks.item()

    if it % 5 == 0:
        mx.clear_cache()

    if it % 10 == 0:
        train_loss = losses / 10
        peak = mx.get_peak_memory() / 1e9
        print(f'Iter {it:>4d}: loss {train_loss:.3f} | peak {peak:.1f} GB | tokens {trained_tokens}')
        losses = 0

    if it % 50 == 0 and valid_set is not None:
        val_loss = 0
        val_n = 0
        model.eval()
        # Validate on up to 25 batches; no gradient is taken here.
        for _, vbatch in zip(range(25), iterate_batches(dataset=valid_set, batch_size=BATCH, max_seq_length=SEQ_LEN)):
            lv, _tv = default_loss(model, *vbatch)
            val_loss += lv.item()
            val_n += 1
        if val_n > 0:
            print(f'Iter {it:>4d}: val_loss {val_loss / val_n:.3f}')
        model.train()
        mx.clear_cache()

    if it % 100 == 0:
        # Overwrite the rolling adapter file and keep a numbered checkpoint.
        weights = dict(tree_flatten(model.trainable_parameters()))
        mx.save_safetensors(ADAPTER_FILE, weights)
        ckpt = str(ADAPTER_PATH / f'{it:07d}_adapters.safetensors')
        mx.save_safetensors(ckpt, weights)
        print(f'Iter {it:>4d}: checkpoint saved')

# ── Final save ───────────────────────────────────────────────────────
weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(ADAPTER_FILE, weights)

# Write adapter config so mlx_lm.load() can reload the adapter.
adapter_config = {
    'fine_tune_type': 'lora',
    'num_layers': 16,
    'lora_parameters': {'rank': 16, 'dropout': 0.05, 'scale': 32.0},
}
with open(ADAPTER_PATH / 'adapter_config.json', 'w') as f:
    json.dump(adapter_config, f, indent=2)

print(f'\nP0 training complete. Adapter: {ADAPTER_FILE}')
print(f'Total tokens: {trained_tokens}')
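# ── Post-training sanity check (sketch) ──────────────────────────────
# A minimal smoke test, assuming mlx_lm's load()/generate() API; the
# prompt placeholder below is hypothetical. Run it in a separate process
# after training completes:
#
#   from mlx_lm import load, generate
#   m, tok = load(MODEL_PATH, adapter_path=str(ADAPTER_PATH))
#   out = generate(m, tok, prompt='<kernel JSON + probe + sig sandwich>',
#                  max_tokens=128)
#   print(out)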