#!/usr/bin/env python3
"""
Distill sandwich training data for ethics phases (P0, P2).

Assembles: [LEK-1 kernel JSON] + "\\n\\n" + [Probe] + "\\n\\n" + [LEK-1-Sig]
as a single user message, then generates assistant response via local model.

Usage:
    PYTHONUNBUFFERED=1 python3 scripts/distill_sandwich.py \\
        --model data/models/gemma3/4b \\
        --probes training/lem/eval/test-200.json \\
        --kernel data/kernels/lek-1-kernel.json \\
        --sig data/kernels/lek-1-sig.txt \\
        --output training/lem/ethics/p2-distilled.jsonl \\
        --valid-ratio 0.2
"""
import argparse
import json
import random
import sys
import time
from pathlib import Path


def load_model(model_path):
    """Load an MLX model and its tokenizer from *model_path*.

    Returns a ``(model, tokenizer)`` pair as produced by ``mlx_lm.load``.
    Progress is reported on stderr so stdout stays clean for piping.
    """
    # Imported lazily so --help works without mlx_lm installed.
    from mlx_lm import load

    print(f"Loading model from {model_path}...", file=sys.stderr)
    model, tokenizer = load(str(model_path))
    print("Model loaded.", file=sys.stderr)
    return model, tokenizer


def generate_response(model, tokenizer, messages, max_tokens=512):
    """Generate one assistant response for a chat *messages* list.

    Applies the tokenizer's chat template (with the generation prompt
    appended), runs greedy generation up to *max_tokens*, and returns the
    stripped response text.
    """
    from mlx_lm import generate

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    response = generate(
        model, tokenizer, prompt=prompt, max_tokens=max_tokens, verbose=False
    )
    return response.strip()


def main():
    """CLI entry point: distill sandwich prompts into a train/valid JSONL pair."""
    parser = argparse.ArgumentParser(description="Distill sandwich training data")
    parser.add_argument("--model", required=True, help="Path to MLX model")
    parser.add_argument("--probes", required=True, help="Path to probes JSON")
    parser.add_argument("--kernel", required=True, help="Path to LEK kernel JSON")
    parser.add_argument("--sig", required=True, help="Path to LEK sig text")
    parser.add_argument("--output", required=True, help="Output JSONL path")
    parser.add_argument("--valid-ratio", type=float, default=0.2,
                        help="Validation split ratio")
    parser.add_argument("--max-tokens", type=int, default=512,
                        help="Max response tokens")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()

    random.seed(args.seed)

    # Load kernel and sig. Explicit UTF-8: the default encoding is
    # locale-dependent and would corrupt non-ASCII kernel/sig text.
    kernel_text = Path(args.kernel).read_text(encoding="utf-8").strip()
    sig_text = Path(args.sig).read_text(encoding="utf-8").strip()

    # Load probes
    with open(args.probes, encoding="utf-8") as f:
        probes = json.load(f)
    print(f"Loaded {len(probes)} probes", file=sys.stderr)

    # Load model
    model, tokenizer = load_model(args.model)

    # Distill
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    results = []
    start = time.time()
    for i, probe in enumerate(probes):
        probe_text = probe["prompt"]
        probe_id = probe.get("id", f"P{i:03d}")

        # Assemble sandwich: kernel + \n\n + probe + \n\n + sig
        sandwich = f"{kernel_text}\n\n{probe_text}\n\n{sig_text}"
        messages = [{"role": "user", "content": sandwich}]

        response = generate_response(model, tokenizer, messages,
                                     max_tokens=args.max_tokens)

        example = {
            "messages": [
                {"role": "user", "content": sandwich},
                {"role": "assistant", "content": response}
            ]
        }
        results.append(example)

        elapsed = time.time() - start
        rate = (i + 1) / elapsed if elapsed > 0 else 0
        eta = (len(probes) - i - 1) / rate if rate > 0 else 0
        print(f"  [{i+1}/{len(probes)}] {probe_id} — {len(response)} chars "
              f"({rate:.2f}/s, ETA {eta:.0f}s)", file=sys.stderr)

    # Shuffle and split. At least one example always goes to validation.
    random.shuffle(results)
    valid_count = max(1, int(len(results) * args.valid_ratio))
    valid_set = results[:valid_count]
    train_set = results[valid_count:]

    # Write train
    train_path = output_path
    with open(train_path, "w", encoding="utf-8") as f:
        for ex in train_set:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    # Write valid: mirror a "train" stem as "valid", otherwise append "-valid".
    if "train" in output_path.stem:
        valid_path = output_path.with_name(
            output_path.stem.replace("train", "valid") + output_path.suffix)
    else:
        valid_path = output_path.with_name(
            output_path.stem + "-valid" + output_path.suffix)
    with open(valid_path, "w", encoding="utf-8") as f:
        for ex in valid_set:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"\nDone: {len(train_set)} train → {train_path}", file=sys.stderr)
    print(f"      {len(valid_set)} valid → {valid_path}", file=sys.stderr)


if __name__ == "__main__":
    main()