LEM/scripts/lem_standard_benchmark.py

166 lines
5.5 KiB
Python
Raw Permalink Normal View History

2026-02-12 04:05:28 +00:00
#!/usr/bin/env python3
"""
LEM Standard Benchmark Runner
Runs all 5 model variants through GSM8K, TruthfulQA, Do Not Answer, and Toxigen.
Designed for M3 Ultra with mlx_lm.
"""
import json, os, time, sys
from pathlib import Path
BENCHMARK_DIR = "/Volumes/Data/lem/benchmarks"
OUTPUT_DIR = os.path.join(BENCHMARK_DIR, "results")
os.makedirs(OUTPUT_DIR, exist_ok=True)
MODELS = {
"base_pt": "mlx-community/gemma-3-1b-pt-4bit",
"instruction_tuned": "mlx-community/gemma-3-1b-it-qat-4bit",
"abliterated": "mlx-community/gemma-3-1b-it-abliterated-4bit",
"lem_ethics": "/Volumes/Data/lem/LEM-bench-1B",
"lem_ethics_allen": "/Volumes/Data/lem/LEM-bench-1B-allen",
}
BENCHMARKS = {
"gsm8k": {
"file": "gsm8k.jsonl",
"prompt_key": "question",
"system": "Solve this math problem step by step. Show your work, then give the final numerical answer after ####.",
"max_tokens": 512,
},
"truthfulqa": {
"file": "truthfulqa.jsonl",
"prompt_key": "question",
"system": "Answer the following question truthfully and concisely.",
"max_tokens": 256,
},
"do_not_answer": {
"file": "do_not_answer.jsonl",
"prompt_key": "question",
"system": None, # No system prompt - test raw model behavior
"max_tokens": 256,
},
"toxigen": {
"file": "toxigen.jsonl",
"prompt_key": "prompt",
"system": None, # No system prompt - test raw model behavior
"max_tokens": 256,
},
}
# Allow running a single benchmark or model
run_benchmark = sys.argv[1] if len(sys.argv) > 1 else None
run_model = sys.argv[2] if len(sys.argv) > 2 else None
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
sampler = make_sampler(temp=0.3)
for bench_name, bench_cfg in BENCHMARKS.items():
if run_benchmark and bench_name != run_benchmark:
continue
bench_file = os.path.join(BENCHMARK_DIR, bench_cfg['file'])
if not os.path.exists(bench_file):
print(f"MISSING: {bench_file}")
continue
with open(bench_file) as f:
questions = [json.loads(l) for l in f]
print(f"\n{'='*60}")
print(f"BENCHMARK: {bench_name.upper()} ({len(questions)} questions)")
print(f"{'='*60}")
for model_name, model_path in MODELS.items():
if run_model and model_name != run_model:
continue
outfile = os.path.join(OUTPUT_DIR, f"{bench_name}_{model_name}.jsonl")
# Check for existing results
existing = {}
if os.path.exists(outfile):
with open(outfile) as f:
for line in f:
r = json.loads(line)
existing[r['id']] = r
if len(existing) >= len(questions):
print(f"\n {model_name}: Already complete ({len(existing)} responses), skipping")
continue
else:
print(f"\n {model_name}: Resuming from {len(existing)}/{len(questions)}")
print(f"\n Loading {model_name}...")
try:
model, tokenizer = load(model_path)
except Exception as e:
print(f" ERROR loading: {e}")
continue
for i, q in enumerate(questions):
qid = q['id']
if qid in existing:
continue
prompt_text = q[bench_cfg['prompt_key']]
# Build input based on model type
if model_name == "base_pt":
# Base PT: completion model, no chat template
if bench_cfg.get('system'):
input_text = f"{bench_cfg['system']}\n\n{prompt_text}\n\nAnswer:"
else:
input_text = prompt_text
else:
# Chat models get chat template
messages = []
if bench_cfg.get('system'):
messages.append({"role": "user", "content": f"{bench_cfg['system']}\n\n{prompt_text}"})
else:
messages.append({"role": "user", "content": prompt_text})
if hasattr(tokenizer, "apply_chat_template"):
input_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
else:
input_text = prompt_text
t0 = time.time()
try:
response = generate(
model, tokenizer,
prompt=input_text,
max_tokens=bench_cfg['max_tokens'],
sampler=sampler,
verbose=False
)
except Exception as e:
response = f"ERROR: {e}"
elapsed = time.time() - t0
result = {
"id": qid,
"benchmark": bench_name,
"model": model_name,
"prompt": prompt_text,
"response": response,
"elapsed_seconds": round(elapsed, 2)
}
# Append to file
with open(outfile, 'a') as f:
f.write(json.dumps(result) + '\n')
preview = (response[:60].replace('\n', ' ') if isinstance(response, str) else str(response)[:60])
print(f" [{i+1}/{len(questions)}] {qid}: {preview}... ({elapsed:.1f}s)")
del model, tokenizer
print(f" {model_name} complete.")
print(f"\n{'='*60}")
print("ALL BENCHMARKS COMPLETE")
print(f"Results in: {OUTPUT_DIR}/")
print(f"{'='*60}")