# Commit message / file metadata (scrape residue, preserved as comments):
#
# Golem CLI (stdlib + DuckDB) replaces scattered Python scripts:
# - score: heuristic regex + LLM-as-judge scoring
# - probe: generate responses then score
# - compare: diff two score files
# - status: InfluxDB training/generation progress
# - export: golden set to training JSONL splits
# - expand: distributed expansion via API + InfluxDB coordination
# New scripts from Feb 14 creative session:
# - scoring_agent.py: ROCm daemon that auto-scores checkpoints
# - probes.py: 23 binary pass/fail capability probes
# - convert_adapter.py: MLX to PEFT adapter conversion
# - score_r1_capability.py: DeepSeek R1 checkpoint scoring
# - lek_content_scorer.py: 6-dimension ethics content scorer
# - lem_train_15k.py: InfluxDB-coordinated training script
# - pipeline.py: DuckDB pipeline (seeds, golden set, expansion)
# Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
#
# 160 lines · 5.2 KiB · Python
"""MLX → PEFT adapter conversion.
|
||
|
||
MLX LoRA adapters use different tensor naming and layout than HuggingFace PEFT.
|
||
This module handles the three required transformations:
|
||
|
||
1. Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
|
||
2. Transpose: MLX stores (in, rank) / (rank, out), PEFT expects (rank, in) / (out, rank)
|
||
3. Config generation: adapter_config.json with lora_alpha = scale × rank
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import tempfile
|
||
|
||
import torch
|
||
from safetensors.torch import load_file, save_file
|
||
|
||
|
||
def rename_key(mlx_key: str) -> str:
    """Rename an MLX tensor key to PEFT format.

    model.layers.12.self_attn.q_proj.lora_a
    → base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight

    Args:
        mlx_key: Tensor key as stored in an MLX adapter .safetensors file.

    Returns:
        The equivalent PEFT-style key. Keys without a ``.lora_a``/``.lora_b``
        suffix only get the ``base_model.model.`` prefix prepended.
    """
    # lora_a → lora_A.default.weight, lora_b → lora_B.default.weight
    key = re.sub(r'\.lora_a$', '.lora_A.default.weight', mlx_key)
    key = re.sub(r'\.lora_b$', '.lora_B.default.weight', key)
    # PEFT wraps the base model, so every key is nested one level deeper.
    return "base_model.model." + key
|
||
def convert_mlx_to_peft(
    mlx_safetensors_path: str,
    mlx_config_path: str,
    output_dir: str,
    base_model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
) -> str:
    """Convert an MLX LoRA adapter to PEFT format.

    Writes ``adapter_model.safetensors`` (renamed + transposed tensors) and
    ``adapter_config.json`` (PEFT LoraConfig fields) into *output_dir*.

    Args:
        mlx_safetensors_path: Path to MLX .safetensors file
        mlx_config_path: Path to MLX adapter_config.json
        output_dir: Directory to write PEFT adapter files
        base_model_name: HuggingFace model ID for config

    Returns:
        Path to the output directory
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load MLX tensors
    mlx_tensors = load_file(mlx_safetensors_path)

    # Rename keys and transpose every tensor:
    # MLX stores (in, rank) / (rank, out); PEFT expects (rank, in) / (out, rank).
    peft_tensors = {
        rename_key(mlx_key): tensor.T.contiguous()
        for mlx_key, tensor in mlx_tensors.items()
    }

    # Save PEFT safetensors
    save_file(peft_tensors, os.path.join(output_dir, "adapter_model.safetensors"))

    # Read MLX config for LoRA hyperparameters (defaults match common MLX values)
    with open(mlx_config_path) as f:
        mlx_config = json.load(f)

    lora_params = mlx_config.get("lora_parameters", {})
    rank = lora_params.get("rank", 8)
    scale = lora_params.get("scale", 20.0)
    dropout = lora_params.get("dropout", 0.0)

    # Derive target modules (leaf names, e.g. "self_attn.q_proj" → "q_proj")
    # and the set of transformed layer indices from the tensor keys in one pass.
    module_re = re.compile(r'model\.layers\.\d+\.(.*?)\.lora_[ab]$')
    layer_re = re.compile(r'layers\.(\d+)')
    modules = set()
    layers = set()
    for k in mlx_tensors:
        m = module_re.match(k)
        if m:
            modules.add(m.group(1).split(".")[-1])
        m = layer_re.search(k)
        if m:
            layers.add(int(m.group(1)))

    # Write PEFT adapter_config.json
    peft_config = {
        "auto_mapping": None,
        "base_model_name_or_path": base_model_name,
        "bias": "none",
        "fan_in_fan_out": False,
        "inference_mode": True,
        "init_lora_weights": True,
        "layers_pattern": None,
        "layers_to_transform": sorted(layers),
        "lora_alpha": scale * rank,  # MLX scale × rank = PEFT alpha
        "lora_dropout": dropout,
        "modules_to_save": None,
        "peft_type": "LORA",
        "r": rank,
        "revision": None,
        "target_modules": sorted(modules),
        "task_type": "CAUSAL_LM",
    }

    with open(os.path.join(output_dir, "adapter_config.json"), "w") as f:
        json.dump(peft_config, f, indent=2)

    return output_dir
|
||
def convert_checkpoint(
    adapter_dir: str,
    checkpoint_file: str,
    work_dir: str,
    base_model_name: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
) -> str:
    """Convenience: convert a single MLX checkpoint file to a PEFT adapter directory.

    Args:
        adapter_dir: Directory containing MLX adapter files and config
        checkpoint_file: Filename like "0000050_adapters.safetensors"
        work_dir: Working directory for temporary PEFT output
        base_model_name: HuggingFace model ID

    Returns:
        Path to PEFT adapter directory
    """
    src_tensors = os.path.join(adapter_dir, checkpoint_file)
    src_config = os.path.join(adapter_dir, "adapter_config.json")

    # Name the output subdirectory after the checkpoint's iteration number
    # (the first digit run in the filename), falling back to "unknown".
    digits = re.search(r'(\d+)', checkpoint_file)
    label = digits.group(1) if digits else "unknown"
    dest = os.path.join(work_dir, f"peft_{label}")

    return convert_mlx_to_peft(src_tensors, src_config, dest, base_model_name)
|
||
if __name__ == "__main__":
|
||
import sys
|
||
if len(sys.argv) < 3:
|
||
print("Usage: python convert_adapter.py <mlx_safetensors> <mlx_config> [output_dir]")
|
||
print("Example: python convert_adapter.py 0000050_adapters.safetensors adapter_config.json ./peft_out")
|
||
sys.exit(1)
|
||
|
||
sf_path = sys.argv[1]
|
||
cfg_path = sys.argv[2]
|
||
out_dir = sys.argv[3] if len(sys.argv) > 3 else "./peft_output"
|
||
|
||
result = convert_mlx_to_peft(sf_path, cfg_path, out_dir)
|
||
print(f"Converted to: {result}")
|
||
print(f"Files: {os.listdir(result)}")
|