#!/usr/bin/env python3
"""Extract training data from self-distillation JSONL output.

Reads the raw distillation output and produces clean training JSONL
in the format MLX fine-tuning expects:

{"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}

Also supports deduplication (keeping highest-scoring response per probe)
and statistics output.

Usage:
    python3 extract_training.py \
        --input /Volumes/Data/lem/training/phase1-raw.jsonl \
        --output /Volumes/Data/lem/training/phase1-train.jsonl \
        --dedup best   # best | all | first
        --stats        # print statistics
"""

import argparse
import json
import sys
from collections import defaultdict
|
def _load_records(path):
    """Read the raw distillation JSONL, returning (training_records, summary_or_None).

    Blank lines (e.g. a trailing newline) are tolerated and skipped;
    the original code crashed on them with a JSONDecodeError.
    """
    records = []
    summary = None
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank/trailing lines in the raw file
            obj = json.loads(line)
            if obj.get("type") == "summary":
                summary = obj
            elif obj.get("type") == "training":
                records.append(obj)
    return records, summary


def _dedup(by_probe, records, strategy):
    """Apply the dedup strategy, returning a fresh list of records to write."""
    if strategy == "best":
        # Keep only the highest-scoring response per probe.
        return [max(recs, key=lambda r: r["meta"]["lek_score"])
                for recs in by_probe.values()]
    if strategy == "first":
        # Keep first passing response per probe.
        return [recs[0] for recs in by_probe.values()]
    # "all": keep everything. Copy so the caller's in-place sort does not
    # mutate the original records list.
    return list(records)


def _print_stats(output_records, by_probe):
    """Print score and per-category statistics to stderr."""
    scores = [r["meta"]["lek_score"] for r in output_records]
    categories = defaultdict(list)
    for r in output_records:
        categories[r["meta"]["category"]].append(r["meta"]["lek_score"])

    print(f"\n=== Training Data Statistics ===", file=sys.stderr)
    print(f"Examples: {len(scores)}", file=sys.stderr)
    print(f"Probes: {len(by_probe)}", file=sys.stderr)
    if not scores:
        # Nothing survived upstream filtering — guard the avg/min/max below.
        print(f"No examples passed threshold — try lowering --threshold", file=sys.stderr)
        return
    print(f"Avg score: {sum(scores)/len(scores):.2f}", file=sys.stderr)
    print(f"Min score: {min(scores):.2f}", file=sys.stderr)
    print(f"Max score: {max(scores):.2f}", file=sys.stderr)

    print(f"\nPer-category:", file=sys.stderr)
    for cat in sorted(categories.keys()):
        cat_scores = categories[cat]
        print(f"  {cat}: {len(cat_scores)} examples, avg {sum(cat_scores)/len(cat_scores):.2f}", file=sys.stderr)


def extract(args):
    """Load raw distillation records, dedup them, and write clean training JSONL.

    Reads ``args.input`` line by line, keeping "training" records and the
    optional "summary" record, applies the dedup strategy from
    ``args.dedup`` ("best" | "first" | anything else means "all"), writes
    the MLX-format ``training`` payloads to ``args.output`` (one JSON
    object per line, sorted by probe ID for reproducibility), and prints
    progress/statistics to stderr.
    """
    records, summary = _load_records(args.input)

    print(f"Loaded {len(records)} training records from {args.input}", file=sys.stderr)

    if summary:
        # NOTE(review): assumes the summary record carries these keys —
        # holds for the distillation script's output; a partial summary
        # would raise KeyError here.
        print(f"Source: {summary['model']}", file=sys.stderr)
        print(f"Generated: {summary['total_generated']}, Kept: {summary['total_kept']} ({summary['keep_rate']}%)", file=sys.stderr)

    # Group by probe so dedup strategies can operate per probe.
    by_probe = defaultdict(list)
    for r in records:
        by_probe[r["meta"]["probe_id"]].append(r)

    output_records = _dedup(by_probe, records, args.dedup)

    # Sort by probe ID for reproducibility
    output_records.sort(key=lambda r: r["meta"]["probe_id"])

    # Write clean training data: only the MLX-format "training" payload.
    with open(args.output, "w") as out:
        for r in output_records:
            out.write(json.dumps(r["training"]) + "\n")

    print(f"Wrote {len(output_records)} training examples to {args.output}", file=sys.stderr)

    if args.stats:
        _print_stats(output_records, by_probe)
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Command-line entry point: parse CLI arguments and run the extraction."""
    cli = argparse.ArgumentParser(
        description="Extract training data from distillation output")
    cli.add_argument("--input", required=True, help="Raw distillation JSONL")
    cli.add_argument("--output", required=True, help="Clean training JSONL")
    cli.add_argument(
        "--dedup",
        default="best",
        choices=["best", "all", "first"],
        help="Dedup strategy: best (highest score per probe), all, first",
    )
    cli.add_argument("--stats", action="store_true", help="Print statistics")
    extract(cli.parse_args())
|
||
|
|
|
||
|
|
|
||
|
|
# Run the CLI only when executed directly, so the module can be imported
# (e.g. for testing) without side effects.
if __name__ == "__main__":
    main()
|