LEM/scripts/distill_sandwich.py
Snider f75458bce6 refactor: apply go fix modernizers for Go 1.26
Automated fixes: interface{} → any, range-over-int, t.Context(),
wg.Go(), strings.SplitSeq, strings.Builder, slices.Contains,
maps helpers, min/max builtins.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-22 21:00:17 +00:00

122 lines
4.3 KiB
Python

#!/usr/bin/env python3
"""
Distill sandwich training data for ethics phases (P0, P2).
Assembles: [LEK-1 kernel JSON] + \n\n + [Probe] + \n\n + [LEK-1-Sig]
as a single user message, then generates assistant response via local model.
Usage:
PYTHONUNBUFFERED=1 python3 scripts/distill_sandwich.py \
--model data/models/gemma3/4b \
--probes training/lem/eval/test-200.json \
--kernel data/kernels/lek-1-kernel.json \
--sig data/kernels/lek-1-sig.txt \
--output training/lem/ethics/p2-distilled.jsonl \
--valid-ratio 0.2
"""
import argparse
import json
import random
import sys
import time
from pathlib import Path
def load_model(model_path):
    """Load an MLX model + tokenizer from *model_path*.

    The mlx_lm import is deferred so that --help and argument errors
    stay fast and do not require the heavy dependency.

    Returns:
        (model, tokenizer) tuple as produced by ``mlx_lm.load``.
    """
    from mlx_lm import load

    print(f"Loading model from {model_path}...", file=sys.stderr)
    loaded_model, loaded_tokenizer = load(str(model_path))
    print("Model loaded.", file=sys.stderr)
    return loaded_model, loaded_tokenizer
def generate_response(model, tokenizer, messages, max_tokens=512):
    """Render *messages* via the chat template and sample one completion.

    Args:
        model: loaded MLX model.
        tokenizer: matching tokenizer (must support ``apply_chat_template``).
        messages: chat-format list of {"role", "content"} dicts.
        max_tokens: generation cap passed through to ``mlx_lm.generate``.

    Returns:
        The generated text with surrounding whitespace stripped.
    """
    from mlx_lm import generate

    chat_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    completion = generate(
        model,
        tokenizer,
        prompt=chat_prompt,
        max_tokens=max_tokens,
        verbose=False,
    )
    return completion.strip()
def main():
    """CLI entry point: distill sandwich prompts into train/valid JSONL files.

    For each probe, assembles [kernel]\n\n[probe]\n\n[sig] as a single user
    message, generates an assistant reply with the local model, then shuffles
    and splits the examples into a train file (--output) and a sibling valid
    file (derived by swapping "train" -> "valid" in the stem, or appending
    "-valid" when the stem contains no "train").
    """
    parser = argparse.ArgumentParser(description="Distill sandwich training data")
    parser.add_argument("--model", required=True, help="Path to MLX model")
    parser.add_argument("--probes", required=True, help="Path to probes JSON")
    parser.add_argument("--kernel", required=True, help="Path to LEK kernel JSON")
    parser.add_argument("--sig", required=True, help="Path to LEK sig text")
    parser.add_argument("--output", required=True, help="Output JSONL path")
    parser.add_argument("--valid-ratio", type=float, default=0.2, help="Validation split ratio")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max response tokens")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    args = parser.parse_args()
    random.seed(args.seed)

    # Load kernel and sig. Explicit UTF-8: the output is written with
    # ensure_ascii=False, so inputs must not depend on the locale encoding.
    kernel_text = Path(args.kernel).read_text(encoding="utf-8").strip()
    sig_text = Path(args.sig).read_text(encoding="utf-8").strip()

    # Load probes (expected: a JSON list of {"prompt": ..., "id": ...} dicts;
    # "id" is optional and falls back to a positional P### label).
    with open(args.probes, encoding="utf-8") as f:
        probes = json.load(f)
    print(f"Loaded {len(probes)} probes", file=sys.stderr)

    # Load model
    model, tokenizer = load_model(args.model)

    # Distill
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    results = []
    start = time.time()
    for i, probe in enumerate(probes):
        probe_text = probe["prompt"]
        probe_id = probe.get("id", f"P{i:03d}")
        # Assemble sandwich: kernel + \n\n + probe + \n\n + sig
        sandwich = f"{kernel_text}\n\n{probe_text}\n\n{sig_text}"
        messages = [{"role": "user", "content": sandwich}]
        response = generate_response(model, tokenizer, messages, max_tokens=args.max_tokens)
        example = {
            "messages": [
                {"role": "user", "content": sandwich},
                {"role": "assistant", "content": response},
            ]
        }
        results.append(example)
        elapsed = time.time() - start
        rate = (i + 1) / elapsed if elapsed > 0 else 0
        eta = (len(probes) - i - 1) / rate if rate > 0 else 0
        # Fix: separate probe id from the char count (they were fused in the
        # original f-string: "{probe_id}{len(response)} chars").
        print(
            f" [{i+1}/{len(probes)}] {probe_id}: {len(response)} chars "
            f"({rate:.2f}/s, ETA {eta:.0f}s)",
            file=sys.stderr,
        )

    # Shuffle and split; at least one example is always held out for validation.
    random.shuffle(results)
    valid_count = max(1, int(len(results) * args.valid_ratio))
    valid_set = results[:valid_count]
    train_set = results[valid_count:]

    # Write train
    train_path = output_path
    with open(train_path, "w", encoding="utf-8") as f:
        for ex in train_set:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    # Write valid: mirror the train filename, swapping "train" for "valid";
    # if the stem has no "train", append "-valid" instead.
    valid_path = output_path.with_name(output_path.stem.replace("train", "valid") + output_path.suffix)
    if "train" not in output_path.stem:
        valid_path = output_path.with_name(output_path.stem + "-valid" + output_path.suffix)
    with open(valid_path, "w", encoding="utf-8") as f:
        for ex in valid_set:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")

    print(f"\nDone: {len(train_set)} train → {train_path}", file=sys.stderr)
    print(f" {len(valid_set)} valid → {valid_path}", file=sys.stderr)


if __name__ == "__main__":
    main()