LEM/train.py

#!/usr/bin/env python3
"""LoRA training for LEM — direct API, no CLI wrapper."""
import yaml
from pathlib import Path
from types import SimpleNamespace

import mlx.optimizers as optim
from mlx.utils import tree_flatten
from mlx_lm import load
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.tuner.trainer import TrainingArgs, train
from mlx_lm.tuner.utils import linear_to_lora_layers
ROOT = Path(__file__).parent
# Load config.
with open(ROOT / "training/lem/model/gemma3/4b/lora-config.yaml") as f:
    cfg = yaml.safe_load(f)
print(f"Model: {cfg['model']}")
print(f"Data: {cfg['data']}")
print(f"Adapter: {cfg['adapter_path']}")
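# For reference, lora-config.yaml is expected to provide the keys read below.
# The values here are illustrative placeholders, not the actual LEM settings:
#
#   model: training/lem/model/gemma3/4b      # local model directory
#   data: training/lem/data                  # directory holding the JSONL splits
#   adapter_path: adapters/gemma3-4b-lem
#   num_layers: 16
#   lora_parameters: {rank: 8, dropout: 0.05, scale: 20.0}
#   batch_size: 1
#   iters: 1000
#   val_batches: 25
#   steps_per_report: 10
#   steps_per_eval: 200
#   save_every: 100
#   max_seq_length: 2048
#   grad_checkpoint: true
#   learning_rate: 1.0e-5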
# Load model + tokenizer.
model, tokenizer = load(str(ROOT / cfg["model"]))
# Apply LoRA: freeze the base weights, then convert linear layers in the last
# num_layers blocks to LoRA layers so only the adapter weights stay trainable.
lora = cfg["lora_parameters"]
model.freeze()
linear_to_lora_layers(
    model,
    cfg["num_layers"],
    {"rank": lora["rank"], "dropout": lora["dropout"], "scale": lora["scale"]},
)
# Parameter counts (tree_flatten walks the nested parameter dicts).
p_trainable = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
p_total = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"Params: {p_trainable/1e6:.1f}M trainable / {p_total/1e6:.0f}M total")
# Load data. load_dataset takes an argparse-style namespace and reads
# train.jsonl / valid.jsonl (and optionally test.jsonl) from the data directory.
train_set, valid_set, _ = load_dataset(
    SimpleNamespace(data=str(ROOT / cfg["data"]), train=True, test=False),
    tokenizer,
)
print(f"Train: {len(train_set)} examples")
print(f"Valid: {len(valid_set)} examples")
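# Each split is a JSONL file; depending on the mlx_lm version, a record can be
# plain text, prompt/completion, or chat-style, e.g. (illustrative):
#   {"text": "..."}
#   {"prompt": "...", "completion": "..."}
#   {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}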
# Adapter output path.
adapter_path = Path(cfg["adapter_path"])
adapter_path.mkdir(parents=True, exist_ok=True)
# Train. TrainingArgs carries the schedule; the learning rate is set on the
# optimizer passed to train() below.
args = TrainingArgs(
    batch_size=cfg["batch_size"],
    iters=cfg["iters"],
    val_batches=cfg["val_batches"],
    steps_per_report=cfg["steps_per_report"],
    steps_per_eval=cfg["steps_per_eval"],
    steps_per_save=cfg["save_every"],
    adapter_file=str(adapter_path / "adapters.safetensors"),
    max_seq_length=cfg["max_seq_length"],
    grad_checkpoint=cfg.get("grad_checkpoint", False),
)
optimizer = optim.Adam(learning_rate=cfg["learning_rate"])
print(f"\nStarting LoRA training: {cfg['iters']} iters, batch {cfg['batch_size']}")
print(f"LR: {cfg['learning_rate']}, rank: {lora['rank']}, layers: {cfg['num_layers']}")
print()
train(
    model=model,
    tokenizer=tokenizer,
    optimizer=optimizer,
    args=args,
    train_dataset=train_set,
    val_dataset=valid_set,
)
print("\nDone.")
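
# To use the result, the adapter can be loaded back on top of the base model,
# e.g. (sketch; generate()'s exact keyword arguments vary across mlx_lm versions):
#   from mlx_lm import load, generate
#   model, tokenizer = load(str(ROOT / cfg["model"]), adapter_path=cfg["adapter_path"])
#   print(generate(model, tokenizer, prompt="Hello", max_tokens=64))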