Exact reproduction of all 7 CL-BPL phases for Gemma3-12B:
- P0: LEK sandwich ethics (400 iters, LR 2e-5)
- P1: Zen composure (300 iters, LR 1e-5)
- P2: LEK sandwich reinforcement (300 iters, LR 1e-5)
- P3: Freeflow multi-source (300 iters, LR 1e-5)
- P4: 1B teacher tension distillation (300 iters, LR 1e-5)
- P5: 1B teacher creative distillation (300 iters, LR 1e-5)
- P6: Golden set graduation (13479 iters, LR 1e-5)

The only differences from the 4B runs are model-size related: 48GB/12GB Metal memory/cache limits, 24 LoRA layers (vs 16), and the 12B base model path.
All phases are scored at checkpoint cadence via lem-scorer.
The previous (incorrect) 12B models are preserved as the -no-axioms control group.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
116 lines
3.8 KiB
Python
116 lines
3.8 KiB
Python
#!/usr/bin/env python3
"""LoRA training for LEK Gemma3-4B on MLX — memory-limited, correct save.

Top-level script. Running it caps MLX memory use, loads the base model at
MODEL_PATH, applies LoRA adapters, trains on the train/valid splits under
DATA_PATH for ITERS iterations, and writes only the trainable (LoRA)
parameters to ADAPTER_PATH/adapters.safetensors — periodically during
training (with a numbered checkpoint alongside) and once more at the end.
"""

import json
from functools import partial
from pathlib import Path
from types import SimpleNamespace

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map
from mlx_lm import load
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.tuner.trainer import (
    TrainingArgs,
    CacheDataset,
    iterate_batches,
    default_loss,
    average_gradients,
    grad_checkpoint,
)
from mlx_lm.tuner.utils import linear_to_lora_layers

# Metal memory limits: a 24 GiB guideline for memory used during graph
# evaluation and an 8 GiB cap on MLX's buffer cache.
# The top-level mx.set_memory_limit / mx.set_cache_limit belong to the same
# memory API as mx.clear_cache / mx.get_peak_memory; the mx.metal.* variants
# are deprecated in recent MLX releases.
mx.set_memory_limit(24 * 1024**3)
mx.set_cache_limit(8 * 1024**3)

# Run locations (absolute, machine-specific): base model, training data
# directory, and the directory that receives the trained adapters.
MODEL_PATH = '/Volumes/Data/lem/gemma-3-4b-it-mlx'
DATA_PATH = '/Users/snider/Code/LEM/training/lem/model/gemma3/4b'
ADAPTER_PATH = Path('/Volumes/Data/lem/adapters/gemma3-4b-lek')

# Echo the configuration, one item per line.
print(f'Model: {MODEL_PATH}', f'Data: {DATA_PATH}', f'Adapter: {ADAPTER_PATH}', sep='\n')

model, tokenizer = load(MODEL_PATH)
print('Model loaded.')

# LoRA hyper-parameters live in one place so the adapters that get trained and
# the adapter_config.json describing them cannot drift apart.
LORA_NUM_LAYERS = 16
LORA_PARAMETERS = {'rank': 16, 'dropout': 0.05, 'scale': 32.0}

linear_to_lora_layers(model, num_layers=LORA_NUM_LAYERS, config=LORA_PARAMETERS)
print('LoRA applied.')

# Write an mlx_lm-style adapter_config.json next to the weights. mlx_lm's
# adapter loading (load(..., adapter_path=...), fuse) reads this file to
# rebuild the LoRA layers before loading adapters.safetensors; without it the
# saved adapters cannot be reloaded by those tools.
ADAPTER_PATH.mkdir(parents=True, exist_ok=True)
(ADAPTER_PATH / 'adapter_config.json').write_text(
    json.dumps(
        {
            'fine_tune_type': 'lora',
            'model': MODEL_PATH,
            'num_layers': LORA_NUM_LAYERS,
            'lora_parameters': LORA_PARAMETERS,
        },
        indent=4,
    ),
    encoding='utf-8',
)

# Dataset options in the attribute form load_dataset expects: read the local
# train/valid splits under DATA_PATH (no test split, not a Hugging Face
# dataset) and mask prompt tokens out of the loss.
data_args = SimpleNamespace(
    data=DATA_PATH,
    train=True,
    test=False,
    hf_dataset=False,
    mask_prompt=True,
)

train_set, valid_set, _ = load_dataset(data_args, tokenizer)
print(f'Train: {len(train_set)} | Valid: {len(valid_set)}')

# Wrap both splits in mlx_lm's CacheDataset (presumably so each example is
# processed once and reused across epochs — confirm against mlx_lm).
train_set, valid_set = CacheDataset(train_set), CacheDataset(valid_set)

# Ensure the output directory exists, then name the adapter weights file in it.
ADAPTER_PATH.mkdir(exist_ok=True, parents=True)
ADAPTER_FILE = str(ADAPTER_PATH.joinpath('adapters.safetensors'))

# Training schedule. ITERS comes first because the cosine decay length must
# equal it. Decaying over fewer steps leaves the rest of training stuck at
# MIN_LEARNING_RATE; decaying over more never reaches the floor.
ITERS = 300
BATCH = 1
SEQ_LEN = 3072
LEARNING_RATE = 2e-5
MIN_LEARNING_RATE = 1e-6

lr_schedule = optim.cosine_decay(LEARNING_RATE, ITERS, MIN_LEARNING_RATE)
optimizer = optim.Adam(learning_rate=lr_schedule)

# Banner derived from the constants above so it cannot go stale when they change.
print(
    f'\nMetal: mem=24GB, cache=8GB | LoRA: {ITERS} iters, batch {BATCH}, '
    f'LR {LEARNING_RATE:g} cosine, rank 16\n'
)

# Gradient checkpointing: recompute activations during the backward pass
# instead of keeping them, trading compute for memory.
# NOTE(review): only layers[0] is passed in; presumably the helper patches that
# layer's class so every layer of the same type is checkpointed — confirm.
grad_checkpoint(model.layers[0])

# Callable returning ((loss, token_count), grads) for default_loss with respect
# to the model's trainable parameters.
loss_value_and_grad = nn.value_and_grad(model, default_loss)

# Everything the compiled step reads and mutates, declared so mx.compile tracks
# it: model parameters, optimizer state and the random-number state.
state = [model.state, optimizer.state, mx.random.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(batch, prev_grad, do_update):
    """Run one forward/backward pass and optionally apply an optimizer update.

    Args:
        batch: one item from iterate_batches, unpacked into default_loss.
        prev_grad: gradient tree accumulated from earlier batches, or None.
        do_update: if true, pass the gradients through average_gradients and
            apply them with the optimizer.

    Returns:
        (loss, token_count, grad): grad is None after an update; otherwise it
        is the accumulated gradient tree to pass back in as prev_grad.
    """
    (loss, ntoks), grads = loss_value_and_grad(model, *batch)
    if prev_grad is not None:
        grads = tree_map(lambda new, old: new + old, grads, prev_grad)
    if not do_update:
        return loss, ntoks, grads
    optimizer.update(model, average_gradients(grads))
    return loss, ntoks, None

# Logging / validation / checkpoint cadence, in training iterations.
REPORT_EVERY = 10
CLEAR_CACHE_EVERY = 5
EVAL_EVERY = 50
VAL_BATCHES = 25
SAVE_EVERY = 100


def _validation_loss():
    """Return the mean validation loss over at most VAL_BATCHES batches.

    Each batch's mean loss is weighted by its token count. The result is a
    per-token average, the same definition mlx_lm's own evaluation uses,
    rather than an average of per-batch means that over-weights short
    examples. The model is put in eval mode (so dropout is off) for the pass
    and is always returned to train mode, even if evaluation raises.

    Returns:
        The token-weighted mean loss as a float, or None if no tokens were seen.
    """
    model.eval()
    try:
        total_loss = 0.0
        total_tokens = 0
        val_batches = iterate_batches(dataset=valid_set, batch_size=BATCH, max_seq_length=SEQ_LEN)
        for _, vbatch in zip(range(VAL_BATCHES), val_batches):
            loss, ntoks = default_loss(model, *vbatch)
            ntoks = ntoks.item()
            total_loss += loss.item() * ntoks
            total_tokens += ntoks
    finally:
        model.train()
    return total_loss / total_tokens if total_tokens else None


def _save_adapters(it):
    """Save only the trainable (LoRA) parameters, as ADAPTER_FILE and as a numbered checkpoint."""
    weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(ADAPTER_FILE, weights)
    mx.save_safetensors(str(ADAPTER_PATH / f'{it:07d}_adapters.safetensors'), weights)
    print(f'Iter {it}: Saved to {ADAPTER_FILE}')


model.train()
losses = 0
trained_tokens = 0

print(f'Starting training..., iters: {ITERS}')

train_batches = iterate_batches(
    dataset=train_set, batch_size=BATCH, max_seq_length=SEQ_LEN, loop=True
)
for it, batch in zip(range(1, ITERS + 1), train_batches):
    # No gradient accumulation: every step applies an optimizer update.
    lvalue, toks, _ = step(batch, None, True)
    # Evaluate the step outputs together with the updated model/optimizer/RNG state.
    mx.eval(state, lvalue, toks)
    losses += lvalue.item()
    trained_tokens += toks.item()

    if it % CLEAR_CACHE_EVERY == 0:
        mx.clear_cache()

    if it % REPORT_EVERY == 0:
        # losses holds exactly REPORT_EVERY steps since the previous report.
        train_loss = losses / REPORT_EVERY
        peak = mx.get_peak_memory() / 1e9
        print(f'Iter {it}: Train loss {train_loss:.3f}, Peak mem {peak:.1f} GB, Tokens {trained_tokens}')
        losses = 0

    if it % EVAL_EVERY == 0:
        val_loss = _validation_loss()
        if val_loss is not None:
            print(f'Iter {it}: Val loss {val_loss:.3f}')
        mx.clear_cache()

    if it % SAVE_EVERY == 0:
        _save_adapters(it)

# Final save: flatten the trainable (LoRA) parameter tree into a name -> array
# mapping and write it to ADAPTER_FILE.
trainable = model.trainable_parameters()
weights = dict(tree_flatten(trainable))
mx.save_safetensors(ADAPTER_FILE, weights)
print(f'\nTraining complete. Final adapter: {ADAPTER_FILE}')