#!/usr/bin/env python3
"""
Distill sandwich training data for ethics phases (P0, P2).

Assembles: [LEK-1 kernel JSON] + \n\n + [Probe] + \n\n + [LEK-1-Sig]
as a single user message, then generates assistant response via local model.

Usage:
    PYTHONUNBUFFERED=1 python3 scripts/distill_sandwich.py \
        --model data/models/gemma3/4b \
        --probes training/lem/eval/test-200.json \
        --kernel data/kernels/lek-1-kernel.json \
        --sig data/kernels/lek-1-sig.txt \
        --output training/lem/ethics/p2-distilled.jsonl \
        --valid-ratio 0.2
"""
import argparse
import json
import random
import sys
import time
from pathlib import Path


def load_model(model_path):
    """Load an MLX model and its tokenizer from *model_path*.

    Progress is reported on stderr so stdout stays clean for piped data.
    Returns the (model, tokenizer) pair produced by mlx_lm.
    """
    from mlx_lm import load  # deferred: heavy import, only needed here

    log = sys.stderr
    print(f"Loading model from {model_path}...", file=log)
    pair = load(str(model_path))
    print("Model loaded.", file=log)
    return pair


def generate_response(model, tokenizer, messages, max_tokens=512):
    """Run *messages* through the chat template and generate a reply.

    The reply is returned with surrounding whitespace stripped.
    """
    from mlx_lm import generate  # deferred: heavy import, only needed here

    templated = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    raw = generate(
        model,
        tokenizer,
        prompt=templated,
        max_tokens=max_tokens,
        verbose=False,
    )
    return raw.strip()


def _build_parser():
    """Build the CLI argument parser (kept separate so main() stays short)."""
    parser = argparse.ArgumentParser(description="Distill sandwich training data")
    parser.add_argument("--model", required=True, help="Path to MLX model")
    parser.add_argument("--probes", required=True, help="Path to probes JSON")
    parser.add_argument("--kernel", required=True, help="Path to LEK kernel JSON")
    parser.add_argument("--sig", required=True, help="Path to LEK sig text")
    parser.add_argument("--output", required=True, help="Output JSONL path")
    parser.add_argument("--valid-ratio", type=float, default=0.2, help="Validation split ratio")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max response tokens")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser


def _distill(model, tokenizer, probes, kernel_text, sig_text, max_tokens):
    """Generate one chat example per probe and return the list of examples.

    Each example's user turn is the "sandwich": kernel + \n\n + probe +
    \n\n + sig. Per-probe progress (rate and ETA) is logged to stderr.
    """
    results = []
    # monotonic(): wall-clock time.time() can jump (NTP/DST) and skew rate/ETA
    start = time.monotonic()
    for i, probe in enumerate(probes):
        probe_text = probe["prompt"]
        probe_id = probe.get("id", f"P{i:03d}")

        # Assemble sandwich: kernel + \n\n + probe + \n\n + sig
        sandwich = f"{kernel_text}\n\n{probe_text}\n\n{sig_text}"
        messages = [{"role": "user", "content": sandwich}]
        response = generate_response(model, tokenizer, messages, max_tokens=max_tokens)

        results.append({
            "messages": [
                {"role": "user", "content": sandwich},
                {"role": "assistant", "content": response},
            ]
        })

        elapsed = time.monotonic() - start
        rate = (i + 1) / elapsed if elapsed > 0 else 0
        eta = (len(probes) - i - 1) / rate if rate > 0 else 0
        print(f" [{i+1}/{len(probes)}] {probe_id} — {len(response)} chars ({rate:.2f}/s, ETA {eta:.0f}s)", file=sys.stderr)
    return results


def _valid_path(output_path):
    """Derive the validation-set path from the train output path.

    If the stem contains "train", substitute "valid" (p2-train.jsonl ->
    p2-valid.jsonl); otherwise append "-valid" to the stem.
    """
    if "train" in output_path.stem:
        return output_path.with_name(
            output_path.stem.replace("train", "valid") + output_path.suffix
        )
    return output_path.with_name(output_path.stem + "-valid" + output_path.suffix)


def _write_jsonl(path, examples):
    """Write *examples* to *path* as UTF-8 JSON Lines (one object per line)."""
    # Explicit UTF-8: with ensure_ascii=False the payload may contain
    # non-ASCII, and the platform default codec (e.g. cp1252) would fail.
    with open(path, "w", encoding="utf-8") as f:
        for ex in examples:
            f.write(json.dumps(ex, ensure_ascii=False) + "\n")


def main():
    """Entry point: load inputs, distill responses, write train/valid splits."""
    args = _build_parser().parse_args()
    random.seed(args.seed)

    # Load kernel and sig (UTF-8 for the same portability reason as output)
    kernel_text = Path(args.kernel).read_text(encoding="utf-8").strip()
    sig_text = Path(args.sig).read_text(encoding="utf-8").strip()

    # Load probes
    with open(args.probes, encoding="utf-8") as f:
        probes = json.load(f)
    print(f"Loaded {len(probes)} probes", file=sys.stderr)

    # Load model
    model, tokenizer = load_model(args.model)

    # Distill
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    results = _distill(model, tokenizer, probes, kernel_text, sig_text, args.max_tokens)

    # Shuffle, then carve off at least one validation example
    random.shuffle(results)
    valid_count = max(1, int(len(results) * args.valid_ratio))
    valid_set = results[:valid_count]
    train_set = results[valid_count:]

    # Write train
    train_path = output_path
    _write_jsonl(train_path, train_set)

    # Write valid
    valid_path = _valid_path(output_path)
    _write_jsonl(valid_path, valid_set)

    print(f"\nDone: {len(train_set)} train → {train_path}", file=sys.stderr)
    print(f" {len(valid_set)} valid → {valid_path}", file=sys.stderr)


# Standard entry-point guard: keeps the module importable without side effects.
if __name__ == "__main__":
    main()