LEM/scripts/extract_training.py

107 lines
3.9 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Extract training data from self-distillation JSONL output.
Reads the raw distillation output and produces clean training JSONL
in the format MLX fine-tuning expects:
{"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
Also supports per-probe deduplication (by default keeping the highest-scoring
response for each probe) and optional statistics output.
Usage:
    python3 extract_training.py \
        --input /Volumes/Data/lem/training/phase1-raw.jsonl \
        --output /Volumes/Data/lem/training/phase1-train.jsonl \
        --dedup best \
        --stats

    --dedup  best (default; highest lek_score per probe) | all | first
    --stats  print statistics to stderr
"""
import argparse
import json
import sys
from collections import defaultdict
def _load_records(path):
    """Read a raw distillation JSONL file.

    Returns a ``(records, summary)`` pair. ``records`` holds every object whose
    ``"type"`` is ``"training"``, in file order. ``summary`` is the last object
    whose ``"type"`` is ``"summary"``, or ``None`` if there is none. Objects
    with any other ``"type"`` are ignored.

    Blank / whitespace-only lines are skipped. Any other malformed line raises
    ``json.JSONDecodeError``.
    """
    records = []
    summary = None
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing "\n\n"); json.loads("\n")
            # would otherwise abort the whole extraction.
            if not line.strip():
                continue
            obj = json.loads(line)
            kind = obj.get("type")
            if kind == "summary":
                summary = obj
            elif kind == "training":
                records.append(obj)
    return records, summary


def _select_records(records, by_probe, strategy):
    """Apply the dedup *strategy* and return the kept records sorted by probe ID.

    - ``"best"``: one record per probe, the one with the highest
      ``meta["lek_score"]``. ``max`` keeps the first of equal scores, so ties
      resolve to file order.
    - ``"first"``: one record per probe, the first seen in the file.
    - anything else (``"all"``): every record.

    The sort is stable, so records sharing a probe ID keep file order.
    """
    if strategy == "best":
        selected = [
            max(recs, key=lambda r: r["meta"]["lek_score"])
            for recs in by_probe.values()
        ]
    elif strategy == "first":
        selected = [recs[0] for recs in by_probe.values()]
    else:  # all
        selected = list(records)
    # Sort by probe ID for reproducibility.
    selected.sort(key=lambda r: r["meta"]["probe_id"])
    return selected


def _print_stats(output_records, probe_count):
    """Print overall and per-category lek_score statistics to stderr."""
    scores = [r["meta"]["lek_score"] for r in output_records]
    categories = defaultdict(list)
    for r in output_records:
        categories[r["meta"]["category"]].append(r["meta"]["lek_score"])

    print("\n=== Training Data Statistics ===", file=sys.stderr)
    print(f"Examples: {len(scores)}", file=sys.stderr)
    print(f"Probes: {probe_count}", file=sys.stderr)
    if not scores:
        # NOTE(review): this script has no --threshold option; the hint
        # presumably refers to the run that produced the raw input — confirm.
        print("No examples passed threshold — try lowering --threshold", file=sys.stderr)
        return
    print(f"Avg score: {sum(scores)/len(scores):.2f}", file=sys.stderr)
    print(f"Min score: {min(scores):.2f}", file=sys.stderr)
    print(f"Max score: {max(scores):.2f}", file=sys.stderr)
    print("\nPer-category:", file=sys.stderr)
    for cat in sorted(categories):
        cat_scores = categories[cat]
        print(f" {cat}: {len(cat_scores)} examples, avg {sum(cat_scores)/len(cat_scores):.2f}", file=sys.stderr)


def extract(args):
    """Convert raw distillation JSONL into clean MLX training JSONL.

    Args:
        args: Parsed CLI namespace with these attributes:
            ``input`` (path to the raw JSONL), ``output`` (path to write),
            ``dedup`` (``"best"``, ``"first"`` or ``"all"``) and ``stats``
            (bool; print statistics to stderr).

    Each output line is ``json.dumps(record["training"])`` for every kept
    record, ordered by ``meta["probe_id"]``. Progress messages go to stderr.

    Raises:
        json.JSONDecodeError: a non-blank input line is not valid JSON.
        KeyError: a training record lacks a ``meta``/``training`` field this
            function indexes directly.
    """
    records, summary = _load_records(args.input)
    print(f"Loaded {len(records)} training records from {args.input}", file=sys.stderr)
    if summary:
        print(f"Source: {summary['model']}", file=sys.stderr)
        print(f"Generated: {summary['total_generated']}, Kept: {summary['total_kept']} ({summary['keep_rate']}%)", file=sys.stderr)

    # Group by probe (dict insertion order = first appearance in the file).
    by_probe = defaultdict(list)
    for r in records:
        by_probe[r["meta"]["probe_id"]].append(r)

    output_records = _select_records(records, by_probe, args.dedup)

    # Write clean training data: only the "training" payload, one per line.
    with open(args.output, "w", encoding="utf-8") as out:
        out.writelines(json.dumps(r["training"]) + "\n" for r in output_records)
    print(f"Wrote {len(output_records)} training examples to {args.output}", file=sys.stderr)

    if args.stats:
        _print_stats(output_records, len(by_probe))
def main():
parser = argparse.ArgumentParser(description="Extract training data from distillation output")
parser.add_argument("--input", required=True, help="Raw distillation JSONL")
parser.add_argument("--output", required=True, help="Clean training JSONL")
parser.add_argument("--dedup", default="best", choices=["best", "all", "first"],
help="Dedup strategy: best (highest score per probe), all, first")
parser.add_argument("--stats", action="store_true", help="Print statistics")
args = parser.parse_args()
extract(args)
if __name__ == "__main__":
main()