LEM/scripts/train-4b-p0.py
Snider 74ef174ec8 feat: add faithful 12B training scripts (P0-P6) — 1:1 port of 4B curriculum
Exact reproduction of all 7 CL-BPL phases for Gemma3-12B:
- P0: LEK sandwich ethics (400 iters, LR 2e-5)
- P1: Zen composure (300 iters, LR 1e-5)
- P2: LEK sandwich reinforcement (300 iters, LR 1e-5)
- P3: Freeflow multi-source (300 iters, LR 1e-5)
- P4: 1B teacher tension distillation (300 iters, LR 1e-5)
- P5: 1B teacher creative distillation (300 iters, LR 1e-5)
- P6: Golden set graduation (13479 iters, LR 1e-5)

Only model-size differences from 4B: 48GB/12GB Metal limits,
24 LoRA layers (vs 16), 12B base model path.

All phases score at checkpoint cadence via lem-scorer.
Previous wrong 12B models preserved as -no-axioms control group.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 20:44:03 +00:00

187 lines
7.1 KiB
Python

#!/usr/bin/env python3
"""P0 LoRA training for Gemma3-4B — LEK sandwich built in code."""
import sys
sys.stdout.reconfigure(line_buffering=True)
import json
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten, tree_map
from functools import partial
from pathlib import Path
from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers
from mlx_lm.tuner.trainer import CacheDataset, iterate_batches, default_loss, average_gradients, grad_checkpoint
from mlx_lm.tuner.datasets import ChatDataset
# ── Metal memory limits ──────────────────────────────────────────────
# 4B model budget: 24 GiB working set, 8 GiB allocator cache.
GIB = 1024 ** 3
mx.metal.set_memory_limit(24 * GIB)
mx.metal.set_cache_limit(8 * GIB)
# ── Paths ────────────────────────────────────────────────────────────
LEM_ROOT = Path('/Users/snider/Code/LEM')
MODEL_PATH = '/Volumes/Data/lem/gemma-3-4b-it-mlx'
ADAPTER_PATH = Path('/Volumes/Data/lem/adapters/gemma3-4b-p0')
# ── Build sandwich data in memory ────────────────────────────────────
print('Building P0 sandwich data...')
# The model sees the kernel JSON verbatim — read it as a raw string,
# never parse/re-serialize it.
kernel_text = (LEM_ROOT / 'data/kernels/lek-1-kernel.json').read_text().strip()
# Signature quote appended after every probe.
sig_text = (LEM_ROOT / 'data/kernels/lek-1-sig.txt').read_text().strip()
# 404 ethics probes.
with open(LEM_ROOT / 'training/lem/ethics/core.json') as f:
    probes = json.load(f)
# Pre-generated 1B responses, one JSON object per non-empty line;
# matched to probes purely by position.
with open(LEM_ROOT / 'training/lem/model/gemma3/4b/lesson-lem1b.jsonl') as f:
    responses = [json.loads(raw) for raw in f if raw.strip()]
print(f' Probes: {len(probes)} | Responses: {len(responses)}')
# Sandwich each probe between kernel JSON and sig quote (user turn),
# paired with the corresponding 1B response (assistant turn).
train_data = []
for probe, resp in zip(probes, responses):
    sandwich = kernel_text + '\n\n' + probe['prompt'] + '\n\n' + sig_text
    train_data.append({
        'messages': [
            {'role': 'user', 'content': sandwich},
            {'role': 'assistant', 'content': resp['messages'][1]['content']},
        ]
    })
# Probes past the end of the response list are dropped; zip stops at the
# shorter sequence, so the skipped count is just the length difference.
skipped = max(len(probes) - len(responses), 0)
print(f' Training examples: {len(train_data)} (skipped {skipped})')
# 90/10 train/valid split (no shuffle — keeps source file order).
split = int(len(train_data) * 0.9)
train_messages = train_data[:split]
valid_messages = train_data[split:]
print(f' Train: {len(train_messages)} | Valid: {len(valid_messages)}')
# ── Load model ───────────────────────────────────────────────────────
print(f'\nModel: {MODEL_PATH}')
model, tokenizer = load(MODEL_PATH)
print('Model loaded.')
# ── Apply LoRA ───────────────────────────────────────────────────────
# Inject rank-16 LoRA adapters (dropout 0.05, scale 32) into 16 layers.
# NOTE(review): linear_to_lora_layers presumably targets the *last* 16
# transformer layers and freezes the base weights — confirm against the
# installed mlx_lm version.
linear_to_lora_layers(model, num_layers=16, config={'rank': 16, 'dropout': 0.05, 'scale': 32.0})
print('LoRA applied (16 layers, rank 16).')
# ── Create datasets directly in memory ───────────────────────────────
# ChatDataset tokenizes the chat-format examples built above;
# mask_prompt=True excludes the user/prompt tokens from the loss.
# CacheDataset memoizes tokenization so repeated epochs don't re-encode.
train_set = CacheDataset(ChatDataset(train_messages, tokenizer, mask_prompt=True))
valid_set = CacheDataset(ChatDataset(valid_messages, tokenizer, mask_prompt=True))
print(f'Datasets created: train={len(train_set)}, valid={len(valid_set)}')
# ── Training config ──────────────────────────────────────────────────
ITERS = 400      # P0 phase length (see header: 400 iters, LR 2e-5)
BATCH = 1        # single-example batches to fit the Metal memory budget
SEQ_LEN = 3072   # max tokens per example; longer sequences are handled by iterate_batches
ADAPTER_PATH.mkdir(parents=True, exist_ok=True)
ADAPTER_FILE = str(ADAPTER_PATH / 'adapters.safetensors')
# Cosine decay: init LR 2e-5, decayed over ITERS steps to an end value of 1e-6.
lr_schedule = optim.cosine_decay(2e-5, ITERS, 1e-6)
optimizer = optim.Adam(learning_rate=lr_schedule)
print(f'\nP0 Training: {ITERS} iters, batch {BATCH}, LR 2e-5 cosine, rank 16, seq {SEQ_LEN}')
# Grad checkpoint for memory.
# NOTE(review): mlx_lm's grad_checkpoint appears to patch the layer
# *class*, so passing layers[0] enables checkpointing for every layer of
# that type — confirm against the installed mlx_lm version.
grad_checkpoint(model.layers[0])
loss_value_and_grad = nn.value_and_grad(model, default_loss)
# Implicit state that the compiled step() reads/writes: model params,
# optimizer moments, and the RNG state (for LoRA dropout).
state = [model.state, optimizer.state, mx.random.state]
evaluate = mx.eval # MLX array evaluation function
@partial(mx.compile, inputs=state, outputs=state)
def step(batch, prev_grad, do_update):
    """One compiled training step: forward+backward, optional grad accumulation,
    optional optimizer update.

    batch: tokenized batch tuple, splatted into default_loss.
    prev_grad: gradient tree from a prior step to accumulate into, or None.
    do_update: if True, average gradients and apply the optimizer update,
        consuming the gradient (returns None in its place).

    Returns (loss, token_count, grad-or-None).

    NOTE(review): prev_grad/do_update are Python values, so the branches are
    resolved at trace time — mx.compile presumably specializes per
    (None/True/False) combination. The training loop below only ever calls
    step(batch, None, True), so the accumulation path is currently unused.
    """
    (lvalue, toks), grad = loss_value_and_grad(model, *batch)
    if prev_grad is not None:
        # Gradient accumulation: element-wise sum with the carried gradient.
        grad = tree_map(lambda x, y: x + y, grad, prev_grad)
    if do_update:
        grad = average_gradients(grad)
        optimizer.update(model, grad)
        grad = None  # gradient consumed by the update; nothing to carry forward
    return lvalue, toks, grad
# ── Train ────────────────────────────────────────────────────────────
model.train()
losses = 0            # running sum of per-iteration losses since last report
trained_tokens = 0    # cumulative token count across the whole run
print('Starting P0 training...\n')
for it, batch in zip(
    range(1, ITERS + 1),
    iterate_batches(dataset=train_set, batch_size=BATCH, max_seq_length=SEQ_LEN, loop=True),
):
    # Single-step update: no gradient accumulation (prev_grad=None, do_update=True).
    lvalue, toks, _ = step(batch, None, True)
    evaluate(state)  # force lazy MLX graph execution before reading values
    losses += lvalue.item()
    trained_tokens += toks.item()
    if it % 5 == 0:
        mx.clear_cache()
    if it % 10 == 0:
        train_loss = losses / 10  # mean over the 10-iter reporting window
        peak = mx.get_peak_memory() / 1e9
        print(f'Iter {it:>4d}: loss {train_loss:.3f} | peak {peak:.1f} GB | tokens {trained_tokens}')
        losses = 0
    if it % 50 == 0 and valid_set is not None:
        # Token-weighted validation loss over up to 25 batches.
        # FIX: the previous per-batch mean (sum of losses / batch count)
        # over-weighted short sequences; weighting each batch's mean loss
        # by its token count matches mlx_lm's own evaluate().
        val_loss = 0.0
        val_tokens = 0
        model.eval()
        for _, vbatch in zip(range(25), iterate_batches(dataset=valid_set, batch_size=BATCH, max_seq_length=SEQ_LEN)):
            lv, tv = default_loss(model, *vbatch)
            ntoks = tv.item()
            val_loss += lv.item() * ntoks
            val_tokens += ntoks
        if val_tokens > 0:
            print(f'Iter {it:>4d}: val_loss {val_loss/val_tokens:.3f}')
        model.train()
        mx.clear_cache()
    if it % 100 == 0:
        # Checkpoint: overwrite the live adapter file and keep a numbered copy.
        weights = dict(tree_flatten(model.trainable_parameters()))
        mx.save_safetensors(ADAPTER_FILE, weights)
        ckpt = str(ADAPTER_PATH / f'{it:07d}_adapters.safetensors')
        mx.save_safetensors(ckpt, weights)
        print(f'Iter {it:>4d}: checkpoint saved')
# ── Final save ───────────────────────────────────────────────────────
weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(ADAPTER_FILE, weights)
# Persist the adapter config next to the weights so mlx_lm.load() can
# re-attach this adapter later; values mirror the LoRA setup above.
adapter_config = {
    'fine_tune_type': 'lora',
    'num_layers': 16,
    'lora_parameters': {'rank': 16, 'dropout': 0.05, 'scale': 32.0},
}
(ADAPTER_PATH / 'adapter_config.json').write_text(json.dumps(adapter_config, indent=2))
print(f'\nP0 training complete. Adapter: {ADAPTER_FILE}')
print(f'Total tokens: {trained_tokens}')