#!/usr/bin/env python3
"""LoRA training for LEM — direct API, no CLI wrapper."""

from pathlib import Path
from types import SimpleNamespace

import mlx.optimizers as optim
import yaml
from mlx.utils import tree_flatten
from mlx_lm import load
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.tuner.trainer import TrainingArgs, train
from mlx_lm.tuner.utils import linear_to_lora_layers

ROOT = Path(__file__).parent

# Load config.
with open(ROOT / "training/lem/model/gemma3/4b/lora-config.yaml") as f:
    cfg = yaml.safe_load(f)

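# For reference, a sketch of the keys this script reads from lora-config.yaml.
# Values below are illustrative placeholders, not the project's actual settings:
#
#   model: models/gemma-3-4b-it            # path relative to ROOT
#   data: training/lem/data                # dir with train.jsonl / valid.jsonl
#   adapter_path: adapters/lem-gemma3-4b
#   num_layers: 16
#   lora_parameters: {rank: 8, dropout: 0.0, scale: 20.0}
#   batch_size: 2
#   iters: 1000
#   val_batches: 25
#   steps_per_report: 10
#   steps_per_eval: 200
#   save_every: 100
#   max_seq_length: 2048
#   grad_checkpoint: true
#   learning_rate: 1e-5
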
print(f"Model: {cfg['model']}")
|
|
print(f"Data: {cfg['data']}")
|
|
print(f"Adapter: {cfg['adapter_path']}")
|
|
|
|
# Load model + tokenizer.
model, tokenizer = load(str(ROOT / cfg["model"]))

# Apply LoRA: freeze the base weights, then swap linear projections in the
# last `num_layers` blocks for low-rank adapters; only those stay trainable.
lora = cfg["lora_parameters"]
model.freeze()
linear_to_lora_layers(
    model,
    cfg["num_layers"],
    {"rank": lora["rank"], "dropout": lora["dropout"], "scale": lora["scale"]},
)

# Parameters are nested dicts/lists, so flatten before counting.
p_trainable = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
p_total = sum(v.size for _, v in tree_flatten(model.parameters()))
print(f"Params: {p_trainable/1e6:.1f}M trainable / {p_total/1e6:.0f}M total")

# Load data. load_dataset expects an argparse-style namespace (data path plus
# train/test flags), so build a minimal one from the config.
train_set, valid_set, _ = load_dataset(
    SimpleNamespace(data=str(ROOT / cfg["data"]), train=True, test=False),
    tokenizer,
)
print(f"Train: {len(train_set)} examples")
print(f"Valid: {len(valid_set)} examples")

# Adapter output path.
adapter_path = Path(cfg["adapter_path"])
adapter_path.mkdir(parents=True, exist_ok=True)

# Train. TrainingArgs has no learning-rate field; the optimizer carries it.
optimizer = optim.Adam(learning_rate=cfg["learning_rate"])
args = TrainingArgs(
    batch_size=cfg["batch_size"],
    iters=cfg["iters"],
    val_batches=cfg["val_batches"],
    steps_per_report=cfg["steps_per_report"],
    steps_per_eval=cfg["steps_per_eval"],
    steps_per_save=cfg["save_every"],
    adapter_file=str(adapter_path / "adapters.safetensors"),
    max_seq_length=cfg["max_seq_length"],
    grad_checkpoint=cfg.get("grad_checkpoint", False),
)

print(f"\nStarting LoRA training: {cfg['iters']} iters, batch {cfg['batch_size']}")
|
|
print(f"LR: {cfg['learning_rate']}, rank: {lora['rank']}, layers: {cfg['num_layers']}")
|
|
print()
|
|
|
|
train(
    model=model,
    tokenizer=tokenizer,
    optimizer=optimizer,
    args=args,
    train_dataset=train_set,
    val_dataset=valid_set,
)
print("\nDone.")
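
# After training, the adapter can be sanity-checked with a quick generation.
# A minimal sketch (paths are illustrative; mlx_lm.load accepts an adapter_path):
#
#   from mlx_lm import load, generate
#   model, tokenizer = load("models/gemma-3-4b-it", adapter_path="adapters/lem-gemma3-4b")
#   print(generate(model, tokenizer, prompt="Hello", max_tokens=64))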