Full v2 scorer benchmark data across 29 models (20 base + 9 LEK-tuned):
- P20 (21 probes): All 29 models, 3 conditions each
- P100 (101 probes): Top 5 models + LEK-4B, publication-quality data

Key findings:
- LEK-1B (21.74) beats base 4B/12B/27B at P100 scale — no kernel needed
- Emergent realignment resistance: LEK models degrade with runtime kernel
- Gemma3-12B + JSON kernel = 23.66 (best kernel-boosted score)
- Family lineages: Mistral 3.80→14.58, Qwen regressed then recovered

New scripts: ab_test.py (v2 scorer), self_distill.py (curriculum generation), extract_training.py, rephrase_probes.py, Phase 0/1 runners
New seeds: P01-P100 merged (101 probes), 404 rephrased variants, 50 creative prompts for Phase 0 baseline lock
27B curriculum design: 4-phase staged training targeting 25+ baseline

Co-Authored-By: Virgil <virgil@lethean.io>
106 lines
3.9 KiB
Python
Executable file
#!/usr/bin/env python3
"""Extract training data from self-distillation JSONL output.

Reads the raw distillation output and produces clean training JSONL
in the format MLX fine-tuning expects:

    {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Also supports deduplication (keeping the highest-scoring response per probe)
and statistics output.

Usage:
    python3 extract_training.py \
        --input /Volumes/Data/lem/training/phase1-raw.jsonl \
        --output /Volumes/Data/lem/training/phase1-train.jsonl \
        --dedup best \
        --stats

    --dedup  best | all | first (default: best)
    --stats  print statistics to stderr
"""

import argparse
import json
import sys
from collections import defaultdict


def extract(args):
    records = []
    summary = None

    with open(args.input) as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines in the raw JSONL
            obj = json.loads(line)
            if obj.get("type") == "summary":
                summary = obj
            elif obj.get("type") == "training":
                records.append(obj)

    print(f"Loaded {len(records)} training records from {args.input}", file=sys.stderr)

    if summary:
        print(f"Source: {summary['model']}", file=sys.stderr)
        print(f"Generated: {summary['total_generated']}, Kept: {summary['total_kept']} ({summary['keep_rate']}%)", file=sys.stderr)

    # Group records by probe ID
    by_probe = defaultdict(list)
    for r in records:
        by_probe[r["meta"]["probe_id"]].append(r)

    # Apply dedup strategy
    output_records = []
    if args.dedup == "best":
        # Keep only the highest-scoring response per probe
        for probe_id, recs in by_probe.items():
            best = max(recs, key=lambda r: r["meta"]["lek_score"])
            output_records.append(best)
    elif args.dedup == "first":
        # Keep the first passing response per probe
        for probe_id, recs in by_probe.items():
            output_records.append(recs[0])
    else:  # all
        output_records = records

    # Sort by probe ID for reproducibility
    output_records.sort(key=lambda r: r["meta"]["probe_id"])

    # Write clean training data
    with open(args.output, "w") as out:
        for r in output_records:
            out.write(json.dumps(r["training"]) + "\n")

    print(f"Wrote {len(output_records)} training examples to {args.output}", file=sys.stderr)

    if args.stats:
        scores = [r["meta"]["lek_score"] for r in output_records]
        categories = defaultdict(list)
        for r in output_records:
            categories[r["meta"]["category"]].append(r["meta"]["lek_score"])

        print("\n=== Training Data Statistics ===", file=sys.stderr)
        print(f"Examples: {len(scores)}", file=sys.stderr)
        print(f"Probes: {len(by_probe)}", file=sys.stderr)
        if not scores:
            print("No examples passed threshold — try lowering --threshold", file=sys.stderr)
            return
        print(f"Avg score: {sum(scores)/len(scores):.2f}", file=sys.stderr)
        print(f"Min score: {min(scores):.2f}", file=sys.stderr)
        print(f"Max score: {max(scores):.2f}", file=sys.stderr)

        print("\nPer-category:", file=sys.stderr)
        for cat in sorted(categories.keys()):
            cat_scores = categories[cat]
            print(f"  {cat}: {len(cat_scores)} examples, avg {sum(cat_scores)/len(cat_scores):.2f}", file=sys.stderr)


def main():
    parser = argparse.ArgumentParser(description="Extract training data from distillation output")
    parser.add_argument("--input", required=True, help="Raw distillation JSONL")
    parser.add_argument("--output", required=True, help="Clean training JSONL")
    parser.add_argument("--dedup", default="best", choices=["best", "all", "first"],
                        help="Dedup strategy: best (highest score per probe), all, first")
    parser.add_argument("--stats", action="store_true", help="Print statistics")
    args = parser.parse_args()
    extract(args)


if __name__ == "__main__":
    main()