forked from lthn/LEM
LEM/scripts/extract_training.py
Snider 7bea00a401 feat: LEK-1 kernel A/B test — 29 models, P100 validation, curriculum pipeline
Full v2 scorer benchmark data across 29 models (20 base + 9 LEK-tuned):
- P20 (21 probes): All 29 models, 3 conditions each
- P100 (101 probes): Top 5 models + LEK-4B, publication-quality data

Key findings:
- LEK-1B (21.74) beats base 4B/12B/27B at P100 scale — no kernel needed
- Emergent realignment resistance: LEK models degrade with runtime kernel
- Gemma3-12B + JSON kernel = 23.66 (best kernel-boosted score)
- Family lineages: Mistral 3.80→14.58, Qwen regressed then recovered

New scripts: ab_test.py (v2 scorer), self_distill.py (curriculum generation),
extract_training.py, rephrase_probes.py, Phase 0/1 runners

New seeds: P01-P100 merged (101 probes), 404 rephrased variants,
50 creative prompts for Phase 0 baseline lock

27B curriculum design: 4-phase staged training targeting 25+ baseline

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-19 11:32:26 +00:00


#!/usr/bin/env python3
r"""Extract training data from self-distillation JSONL output.

Reads the raw distillation output and produces clean training JSONL
in the format MLX fine-tuning expects:

    {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Also supports deduplication (keeping the highest-scoring response per probe)
and statistics output.

Usage:
    python3 extract_training.py \
        --input /Volumes/Data/lem/training/phase1-raw.jsonl \
        --output /Volumes/Data/lem/training/phase1-train.jsonl \
        --dedup best \
        --stats

Options:
    --dedup  best | all | first (default: best)
    --stats  print statistics to stderr
"""

import argparse
import json
import sys
from collections import defaultdict


def extract(args):
    """Read raw distillation JSONL, apply the dedup strategy, write training JSONL."""
    records = []
    summary = None

    with open(args.input) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the JSONL input
            obj = json.loads(line)
            if obj.get("type") == "summary":
                summary = obj
            elif obj.get("type") == "training":
                records.append(obj)

    print(f"Loaded {len(records)} training records from {args.input}", file=sys.stderr)
    if summary:
        print(f"Source: {summary['model']}", file=sys.stderr)
        print(f"Generated: {summary['total_generated']}, Kept: {summary['total_kept']} ({summary['keep_rate']}%)", file=sys.stderr)

    # Group by probe
    by_probe = defaultdict(list)
    for r in records:
        by_probe[r["meta"]["probe_id"]].append(r)

    # Dedup strategy
    output_records = []
    if args.dedup == "best":
        # Keep only the highest-scoring response per probe
        for recs in by_probe.values():
            best = max(recs, key=lambda r: r["meta"]["lek_score"])
            output_records.append(best)
    elif args.dedup == "first":
        # Keep the first passing response per probe (input order)
        for recs in by_probe.values():
            output_records.append(recs[0])
    else:  # all
        output_records = records

    # Sort by probe ID for reproducibility
    output_records.sort(key=lambda r: r["meta"]["probe_id"])

    # Write clean training data: one {"messages": [...]} object per line
    with open(args.output, "w") as out:
        for r in output_records:
            out.write(json.dumps(r["training"]) + "\n")

    print(f"Wrote {len(output_records)} training examples to {args.output}", file=sys.stderr)

    if args.stats:
        scores = [r["meta"]["lek_score"] for r in output_records]
        categories = defaultdict(list)
        for r in output_records:
            categories[r["meta"]["category"]].append(r["meta"]["lek_score"])

        print("\n=== Training Data Statistics ===", file=sys.stderr)
        print(f"Examples: {len(scores)}", file=sys.stderr)
        print(f"Probes: {len(by_probe)}", file=sys.stderr)
        if not scores:
            # This script has no --threshold flag; filtering happens before this step.
            print("No training examples in input (none passed the threshold when the input was generated)", file=sys.stderr)
            return
        print(f"Avg score: {sum(scores)/len(scores):.2f}", file=sys.stderr)
        print(f"Min score: {min(scores):.2f}", file=sys.stderr)
        print(f"Max score: {max(scores):.2f}", file=sys.stderr)

        print("\nPer-category:", file=sys.stderr)
        for cat in sorted(categories.keys()):
            cat_scores = categories[cat]
            print(f"  {cat}: {len(cat_scores)} examples, avg {sum(cat_scores)/len(cat_scores):.2f}", file=sys.stderr)


def main():
    parser = argparse.ArgumentParser(description="Extract training data from distillation output")
    parser.add_argument("--input", required=True, help="Raw distillation JSONL")
    parser.add_argument("--output", required=True, help="Clean training JSONL")
    parser.add_argument("--dedup", default="best", choices=["best", "all", "first"],
                        help="Dedup strategy: best (highest score per probe), all, first")
    parser.add_argument("--stats", action="store_true", help="Print statistics")
    args = parser.parse_args()
    extract(args)


if __name__ == "__main__":
    main()
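
For orientation, the input record shape can be inferred from the field accesses in the script: each line is a JSON object with a `type`, a `meta` block (`probe_id`, `lek_score`, `category`), and a `training` payload holding the `messages` list. The sketch below replays the `best` strategy on a few in-memory records. The sample values, the category names, and the helper name `dedup_best` are hypothetical, and the real distillation output may carry fields not shown here.

```python
import json
from collections import defaultdict


def make_record(probe_id, score, category, answer):
    """Build one hypothetical 'training' line with the fields extract() reads."""
    return json.dumps({
        "type": "training",
        "meta": {"probe_id": probe_id, "lek_score": score, "category": category},
        "training": {"messages": [
            {"role": "user", "content": f"prompt for {probe_id}"},
            {"role": "assistant", "content": answer},
        ]},
    })


# Made-up sample input: two responses for P01, one for P02, plus a summary line.
RAW_LINES = [
    make_record("P02", 18.5, "category-b", "a2"),
    make_record("P01", 20.0, "category-a", "a1-low"),
    make_record("P01", 22.5, "category-a", "a1-high"),
    json.dumps({"type": "summary", "model": "example-model"}),
]


def dedup_best(lines):
    """Keep the highest-scoring training record per probe, sorted by probe ID."""
    by_probe = defaultdict(list)
    for line in lines:
        obj = json.loads(line)
        if obj.get("type") == "training":  # summary lines are skipped
            by_probe[obj["meta"]["probe_id"]].append(obj)
    best = [max(recs, key=lambda r: r["meta"]["lek_score"])
            for recs in by_probe.values()]
    best.sort(key=lambda r: r["meta"]["probe_id"])
    return best


kept = dedup_best(RAW_LINES)
for rec in kept:
    # Only the "training" payload would be written to the output JSONL.
    print(rec["meta"]["probe_id"], json.dumps(rec["training"]))
```

With this sample input, the lower-scoring P01 response is dropped and the two surviving records come out in probe-ID order, which is what makes repeated runs over the same input produce the same file.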