LEM/scripts/eval_adapter.py
Snider f75458bce6 refactor: apply go fix modernizers for Go 1.26
Automated fixes: interface{} → any, range-over-int, t.Context(),
wg.Go(), strings.SplitSeq, strings.Builder, slices.Contains,
maps helpers, min/max builtins.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-22 21:00:17 +00:00

186 lines
5.8 KiB
Python

#!/usr/bin/env python3
"""Evaluate a LoRA adapter by generating fresh responses on training probes.
Loads the base model + adapter, runs each probe in sandwich format,
and outputs scorer-compatible JSONL (training format with probe-only user content).
Usage:
python3 scripts/eval_adapter.py \
--model data/models/gemma3/4b \
--adapter /Volumes/Data/lem/adapters/gemma3-4b-v2 \
--data training/lem/model/gemma3/4b/train.jsonl \
--data training/lem/model/gemma3/4b/valid.jsonl \
--kernel data/kernels/lek-1-kernel.json \
--output /tmp/eval-p0-adapter.jsonl
"""
import argparse
import json
import re
import sys
import time
import mlx.core as mx
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
# Closing signature appended to every sandwich prompt (and searched for
# when stripping probes back out in extract_probe).
SIG = (
"Dream lofty dreams,\n"
"and as you dream,\n"
"so shall you become,\n"
"Dreams are the seedlings of reality.\n"
"- James Allen"
)
def extract_probe(user_content: str) -> str:
    """Strip the kernel JSON prefix and the Allen signature suffix from a
    sandwich user message, returning the raw probe text.

    The sandwich layout is: kernel_json + "\\n\\n" + probe + "\\n\\n" + sig.
    If no signature or no JSON object is present, the input is treated as
    an already-plain probe and returned stripped.
    """
    # Strip the trailing signature. rfind returns -1 when absent; the
    # original `> 0` check silently skipped a sig at index 0.
    sig_idx = user_content.rfind("Dream lofty dreams,")
    if sig_idx != -1:
        without_sig = user_content[:sig_idx].rstrip()
    else:
        without_sig = user_content

    # Scan for the closing brace of the top-level kernel JSON object,
    # ignoring braces that occur inside JSON string literals.
    brace_depth = 0
    json_end = -1
    in_string = False
    escape_next = False
    for i, ch in enumerate(without_sig):
        if escape_next:
            # This character is escaped (e.g. \" inside a string) — skip it.
            escape_next = False
            continue
        if ch == '\\' and in_string:
            escape_next = True
            continue
        if ch == '"':
            # escape_next is always False here (handled above), so a bare
            # quote always toggles string state.
            in_string = not in_string
            continue
        if in_string:
            continue
        if ch == '{':
            brace_depth += 1
        elif ch == '}':
            brace_depth -= 1
            if brace_depth == 0:
                json_end = i
                break

    if json_end > 0:
        return without_sig[json_end + 1:].strip()
    # No JSON found — might already be a plain probe.
    return without_sig.strip()
def build_sandwich(kernel_json: str, probe: str) -> str:
    """Assemble a LEK-1 sandwich prompt: kernel JSON, then the probe,
    then the James Allen signature, separated by blank lines."""
    parts = (kernel_json, probe, SIG)
    return "\n\n".join(parts)
def _load_probes(data_files):
    """Read training JSONL file(s) and return a list of probe dicts
    ({"probe", "source", "line"}), skipping blank lines and records
    with fewer than two messages."""
    probes = []
    for data_file in data_files:
        with open(data_file) as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                rec = json.loads(line)
                # Records carry messages either at the top level or nested
                # under "training".
                msgs = rec.get("messages", rec.get("training", {}).get("messages", []))
                if len(msgs) < 2:
                    continue
                # NOTE(review): assumes msgs[0] is the user turn — confirm
                # the training files never lead with a system message.
                probe_text = extract_probe(msgs[0]["content"])
                probes.append({
                    "probe": probe_text,
                    "source": data_file,
                    "line": line_num,
                })
    return probes


def main():
    """Generate fresh adapter responses for every training probe and write
    scorer-compatible JSONL (probe-only user content, no sandwich)."""
    parser = argparse.ArgumentParser(description="Evaluate LoRA adapter on training probes")
    parser.add_argument("--model", required=True, help="Path to base model")
    parser.add_argument("--adapter", required=True, help="Path to adapter weights")
    parser.add_argument("--data", action="append", required=True, help="Training JSONL file(s)")
    parser.add_argument("--kernel", required=True, help="Path to LEK-1 kernel JSON file")
    parser.add_argument("--output", required=True, help="Output JSONL path")
    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens per response")
    parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature")
    args = parser.parse_args()

    # Load kernel JSON (kept as raw text; it is re-embedded verbatim).
    with open(args.kernel) as f:
        kernel_json = f.read().strip()

    # Load all probes from the data files.
    probes = _load_probes(args.data)
    print(f"Loaded {len(probes)} probes from {len(args.data)} file(s)")

    # Load model + adapter.
    print(f"Loading model: {args.model}")
    print(f"Loading adapter: {args.adapter}")
    model, tokenizer = load(args.model, adapter_path=args.adapter)
    print("Model loaded.")

    # The sampler depends only on --temp; build it once, not per probe.
    sampler = make_sampler(temp=args.temp)

    # Generate responses.
    results = []
    t0 = time.time()
    for i, p in enumerate(probes):
        sandwich = build_sandwich(kernel_json, p["probe"])
        # Apply chat template.
        prompt = tokenizer.apply_chat_template(
            [{"role": "user", "content": sandwich}],
            tokenize=False,
            add_generation_prompt=True,
        )
        response = generate(
            model,
            tokenizer,
            prompt=prompt,
            max_tokens=args.max_tokens,
            sampler=sampler,
        )
        # Build scorer-compatible record (probe only, no sandwich).
        record = {
            "type": "training",
            "training": {
                "messages": [
                    {"role": "user", "content": p["probe"]},
                    {"role": "assistant", "content": response},
                ]
            },
            "meta": {
                "probe_id": f"P0-{i:03d}",
                "category": "ethics",
                "lek_score": 0,
            }
        }
        results.append(record)
        elapsed = time.time() - t0
        # Bug fix: rate was computed per second but labelled probes/min;
        # scale to per-minute so the label is honest.
        rate = (i + 1) / elapsed * 60 if elapsed > 0 else 0
        print(f"  [{i+1}/{len(probes)}] {len(response)} chars | {rate:.1f} probes/min | {p['probe'][:60]}...")

    # Write output.
    with open(args.output, "w") as f:
        for rec in results:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    elapsed = time.time() - t0
    print(f"\nDone. {len(results)} responses in {elapsed:.0f}s → {args.output}")
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
main()