#!/usr/bin/env python3
"""LoRA training for LEK Gemma3-4B: memory-limited, with a correct adapter save."""
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map
from functools import partial
from types import SimpleNamespace
from pathlib import Path

from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers
from mlx_lm.tuner.trainer import (
    CacheDataset,
    average_gradients,
    default_loss,
    grad_checkpoint,
    iterate_batches,
)
from mlx_lm.tuner.datasets import load_dataset

# Cap Metal memory and cache so the run stays inside the machine's budget.
mx.metal.set_memory_limit(24 * 1024**3)
mx.metal.set_cache_limit(8 * 1024**3)

MODEL_PATH = '/Volumes/Data/lem/gemma-3-4b-it-mlx'
DATA_PATH = '/Users/snider/Code/LEM/training/lem/model/gemma3/4b'
ADAPTER_PATH = Path('/Volumes/Data/lem/adapters/gemma3-4b-lek')

print(f'Model: {MODEL_PATH}')
print(f'Data: {DATA_PATH}')
print(f'Adapter: {ADAPTER_PATH}')

model, tokenizer = load(MODEL_PATH)
print('Model loaded.')

# Convert the linear layers in the last 16 transformer blocks to LoRA.
linear_to_lora_layers(
    model,
    num_layers=16,
    config={'rank': 16, 'dropout': 0.05, 'scale': 32.0},
)
print('LoRA applied.')

data_args = SimpleNamespace(
    data=DATA_PATH, train=True, test=False, hf_dataset=False, mask_prompt=True
)
train_set, valid_set, _ = load_dataset(data_args, tokenizer)
print(f'Train: {len(train_set)} | Valid: {len(valid_set)}')

# Cache tokenized examples so they are only encoded once.
train_set = CacheDataset(train_set)
valid_set = CacheDataset(valid_set)

ADAPTER_PATH.mkdir(parents=True, exist_ok=True)
ADAPTER_FILE = str(ADAPTER_PATH / 'adapters.safetensors')

ITERS = 300
BATCH = 1
SEQ_LEN = 3072

# Cosine decay from 2e-5 to 1e-6 over the full run.
lr_schedule = optim.cosine_decay(2e-5, ITERS, 1e-6)
optimizer = optim.Adam(learning_rate=lr_schedule)

print(f'\nMetal: mem=24GB, cache=8GB | LoRA: {ITERS} iters, batch {BATCH}, LR 2e-5 cosine, rank 16\n')

# Gradient checkpointing: patching one block patches every block, since
# mlx_lm's grad_checkpoint rewrites the shared class __call__.
grad_checkpoint(model.layers[0])

loss_value_and_grad = nn.value_and_grad(model, default_loss)

state = [model.state, optimizer.state, mx.random.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(batch, prev_grad, do_update):
    # Gradient-accumulation scaffolding: fold in a carried gradient and only
    # touch the optimizer when do_update is set. The loop below always calls
    # step(batch, None, True), so every batch triggers an update.
    (lvalue, toks), grad = loss_value_and_grad(model, *batch)
    if prev_grad is not None:
        grad = tree_map(lambda x, y: x + y, grad, prev_grad)
    if do_update:
        grad = average_gradients(grad)
        optimizer.update(model, grad)
        grad = None
    return lvalue, toks, grad

model.train()
losses = 0.0
trained_tokens = 0

print(f'Starting training: {ITERS} iters')
for it, batch in zip(
    range(1, ITERS + 1),
    iterate_batches(dataset=train_set, batch_size=BATCH, max_seq_length=SEQ_LEN, loop=True),
):
    lvalue, toks, _ = step(batch, None, True)
    mx.eval(state)
    losses += lvalue.item()
    trained_tokens += toks.item()

    # Release cached Metal buffers periodically to keep memory bounded.
    if it % 5 == 0:
        mx.clear_cache()

    if it % 10 == 0:
        train_loss = losses / 10
        peak = mx.get_peak_memory() / 1e9
        print(f'Iter {it}: Train loss {train_loss:.3f}, Peak mem {peak:.1f} GB, Tokens {trained_tokens}')
        losses = 0.0

    # Validate every 50 iters on up to 25 batches.
    if it % 50 == 0 and valid_set is not None:
        val_loss = 0.0
        val_n = 0
        model.eval()
        for _, vbatch in zip(
            range(25),
            iterate_batches(dataset=valid_set, batch_size=BATCH, max_seq_length=SEQ_LEN),
        ):
            lv, _ = default_loss(model, *vbatch)
            val_loss += lv.item()
            val_n += 1
        if val_n > 0:
            print(f'Iter {it}: Val loss {val_loss / val_n:.3f}')
        model.train()
        mx.clear_cache()

    # Checkpoint every 100 iters: refresh the live adapter file and keep a
    # numbered copy alongside it.
    if it % 100 == 0:
        weights = dict(tree_flatten(model.trainable_parameters()))
        mx.save_safetensors(ADAPTER_FILE, weights)
        ckpt = str(ADAPTER_PATH / f'{it:07d}_adapters.safetensors')
        mx.save_safetensors(ckpt, weights)
        print(f'Iter {it}: Saved to {ADAPTER_FILE}')
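# mlx_lm reloads an adapter from an adapter_config.json stored next to the
# safetensors file; without it, load(..., adapter_path=...) cannot rebuild the
# LoRA layers. A minimal sketch of writing that config, assuming the
# 'fine_tune_type' / 'num_layers' / 'lora_parameters' keys that recent mlx_lm
# releases read; verify against your installed version.
import json

adapter_config = {
    'fine_tune_type': 'lora',
    'num_layers': 16,
    'lora_parameters': {'rank': 16, 'dropout': 0.05, 'scale': 32.0},
}
with open(ADAPTER_PATH / 'adapter_config.json', 'w') as f:
    json.dump(adapter_config, f, indent=2)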
# Final save.
weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(ADAPTER_FILE, weights)
print(f'\nTraining complete. Final adapter: {ADAPTER_FILE}')
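# Quick smoke test of the saved adapter (left commented; a sketch only).
# `adapter_path` on load and this `generate` signature are assumptions about
# recent mlx_lm releases; check them against your installed version:
#
#   from mlx_lm import load, generate
#   m, t = load(MODEL_PATH, adapter_path=str(ADAPTER_PATH))
#   print(generate(m, t, prompt='Hello', max_tokens=64))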