#!/usr/bin/env python3
"""LoRA training for LEM — direct API, no CLI wrapper."""
import types
from pathlib import Path

import yaml
from mlx.utils import tree_flatten
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers
from mlx_lm.tuner.trainer import TrainingArgs, train
from mlx_lm.tuner.datasets import load_dataset

ROOT = Path(__file__).parent

# Load config.
with open(ROOT / "training/lem/model/gemma3/4b/lora-config.yaml") as f:
    cfg = yaml.safe_load(f)

print(f"Model: {cfg['model']}")
print(f"Data: {cfg['data']}")
print(f"Adapter: {cfg['adapter_path']}")

# Load model + tokenizer.
model, tokenizer = load(str(ROOT / cfg["model"]))

# Apply LoRA.
# Positional call: mlx_lm names the second parameter `num_layers`
# (not `num_lora_layers`), so passing it positionally avoids a TypeError.
lora = cfg["lora_parameters"]
linear_to_lora_layers(
    model,
    cfg["num_layers"],
    {"rank": lora["rank"], "dropout": lora["dropout"], "scale": lora["scale"]},
)

# Parameter counts. MLX parameters are a *nested* dict (module tree), so
# summing `.size` over top-level `.items()` either raises AttributeError on a
# dict value or skips every array nested inside submodules — flatten first.
p_trainable = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
p_total = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"Params: {p_trainable/1e6:.1f}M trainable / {p_total/1e6:.0f}M total")

# Load data.
# NOTE(review): mlx_lm's load_dataset signature is (args, tokenizer) where
# `args` is a namespace exposing .data/.train/.test — a bare `data=` keyword
# raises TypeError. The exact attribute set is version-dependent; confirm
# against the installed mlx_lm.
dataset_args = types.SimpleNamespace(
    data=str(ROOT / cfg["data"]),
    train=True,
    test=False,
    hf_dataset=None,
)
train_set, valid_set, _ = load_dataset(dataset_args, tokenizer)

print(f"Train: {len(train_set)} examples")
print(f"Valid: {len(valid_set)} examples")

# Adapter output path.
adapter_path = Path(cfg["adapter_path"])
adapter_path.mkdir(parents=True, exist_ok=True)

# Train.
# Optimizer — mlx_lm's `train()` requires one explicitly; the learning rate
# belongs to the optimizer, not to TrainingArgs.
import mlx.optimizers as optim

opt = optim.Adam(learning_rate=cfg["learning_rate"])

# Training hyper-parameters.
# NOTE(review): TrainingArgs is a dataclass with no `learning_rate` or
# `save_every` field — unknown kwargs raise TypeError. The save interval is
# `steps_per_save`; confirm against the installed mlx_lm version.
args = TrainingArgs(
    batch_size=cfg["batch_size"],
    iters=cfg["iters"],
    val_batches=cfg["val_batches"],
    steps_per_report=cfg["steps_per_report"],
    steps_per_eval=cfg["steps_per_eval"],
    steps_per_save=cfg["save_every"],
    adapter_file=str(adapter_path / "adapters.safetensors"),
    max_seq_length=cfg["max_seq_length"],
    grad_checkpoint=cfg.get("grad_checkpoint", False),
)

print(f"\nStarting LoRA training: {cfg['iters']} iters, batch {cfg['batch_size']}")
print(f"LR: {cfg['learning_rate']}, rank: {lora['rank']}, layers: {cfg['num_layers']}")
print()

# `train()` mutates the model in place and checkpoints adapters to
# args.adapter_file every `steps_per_save` iterations.
train(
    model=model,
    tokenizer=tokenizer,
    optimizer=opt,
    args=args,
    train_dataset=train_set,
    val_dataset=valid_set,
)

print("\nDone.")