1
0
Fork 0
forked from lthn/LEM
LEM/pkg/lem/convert_test.go
Claude 4eaf1bfb39
feat: add parquet, publish, metrics, convert commands
- `lem parquet` — export JSONL training splits to Parquet (parquet-go)
- `lem publish` — push Parquet files to HuggingFace dataset repo
- `lem metrics` — push DuckDB golden set stats to InfluxDB
- `lem convert` — MLX LoRA adapter → HuggingFace PEFT format
  (pure Go safetensors read/write/transpose, no PyTorch needed)

Dependencies added: parquet-go, go-huggingface

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:05:08 +00:00

198 lines
5.7 KiB
Go

package lem
import (
"encoding/binary"
"encoding/json"
"math"
"os"
"path/filepath"
"testing"
)
// TestRenameMLXKey verifies that MLX adapter tensor names are rewritten into
// the PEFT naming scheme: the "base_model.model." prefix is prepended,
// lora_a/lora_b become lora_A/lora_B, and ".default.weight" is appended.
func TestRenameMLXKey(t *testing.T) {
	cases := map[string]string{
		"model.layers.12.self_attn.q_proj.lora_a": "base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight",
		"model.layers.0.self_attn.v_proj.lora_b":  "base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight",
		"model.layers.5.mlp.gate_proj.lora_a":     "base_model.model.model.layers.5.mlp.gate_proj.lora_A.default.weight",
	}
	for in, want := range cases {
		if got := renameMLXKey(in); got != want {
			t.Errorf("renameMLXKey(%q) = %q, want %q", in, got, want)
		}
	}
}
// TestTransposeFloat32 checks the raw-byte transpose of a row-major float32
// matrix: a 2x3 input must come back as its 3x2 transpose, still encoded as
// little-endian float32 bytes.
func TestTransposeFloat32(t *testing.T) {
	// Source 2x3 matrix [[1, 2, 3], [4, 5, 6]], flattened row-major.
	src := []float32{1, 2, 3, 4, 5, 6}
	raw := make([]byte, len(src)*4)
	for i, v := range src {
		binary.LittleEndian.PutUint32(raw[i*4:], math.Float32bits(v))
	}
	out := transposeFloat32(raw, 2, 3)
	// Expected 3x2 matrix [[1, 4], [2, 5], [3, 6]], flattened row-major.
	for i, want := range []float32{1, 4, 2, 5, 3, 6} {
		got := math.Float32frombits(binary.LittleEndian.Uint32(out[i*4:]))
		if got != want {
			t.Errorf("result[%d] = %f, want %f", i, got, want)
		}
	}
}
// TestConvertMLXtoPEFT exercises the end-to-end MLX -> PEFT adapter
// conversion: it writes a minimal MLX safetensors file and config into a
// temp dir, runs convertMLXtoPEFT, and verifies the emitted PEFT config
// values, the renamed tensor keys, and the transposed tensor shapes.
func TestConvertMLXtoPEFT(t *testing.T) {
	dir := t.TempDir()
	// Minimal MLX adapter with one lora_a and one lora_b tensor.
	// Shape: lora_a is (in=4, rank=2), lora_b is (rank=2, out=4).
	tensors := map[string]safetensorsTensorInfo{
		"model.layers.0.self_attn.q_proj.lora_a": {Dtype: "F32", Shape: []int{4, 2}},
		"model.layers.0.self_attn.q_proj.lora_b": {Dtype: "F32", Shape: []int{2, 4}},
	}
	// Fill each tensor (4x2=8 and 2x4=8 floats) with distinct values so a
	// mixed-up tensor mapping would be detectable.
	loraAData := make([]byte, 4*2*4)
	for i := 0; i < 8; i++ {
		binary.LittleEndian.PutUint32(loraAData[i*4:], math.Float32bits(float32(i+1)))
	}
	loraBData := make([]byte, 2*4*4)
	for i := 0; i < 8; i++ {
		binary.LittleEndian.PutUint32(loraBData[i*4:], math.Float32bits(float32(10+i)))
	}
	tensorData := map[string][]byte{
		"model.layers.0.self_attn.q_proj.lora_a": loraAData,
		"model.layers.0.self_attn.q_proj.lora_b": loraBData,
	}
	sfPath := filepath.Join(dir, "adapters.safetensors")
	if err := writeSafetensors(sfPath, tensors, tensorData); err != nil {
		t.Fatalf("write test safetensors: %v", err)
	}
	// MLX config; lora_alpha is expected to be derived as scale * rank.
	mlxConfig := map[string]interface{}{
		"lora_parameters": map[string]interface{}{
			"rank":    8,
			"scale":   20.0,
			"dropout": 0.0,
		},
	}
	cfgData, err := json.Marshal(mlxConfig)
	if err != nil {
		t.Fatalf("marshal mlx config: %v", err)
	}
	cfgPath := filepath.Join(dir, "adapter_config.json")
	// Check the write error: a silently missing config would surface as a
	// confusing convert failure instead.
	if err := os.WriteFile(cfgPath, cfgData, 0644); err != nil {
		t.Fatalf("write mlx config: %v", err)
	}
	// Convert.
	outputDir := filepath.Join(dir, "peft_output")
	if err := convertMLXtoPEFT(sfPath, cfgPath, outputDir, "test-model"); err != nil {
		t.Fatalf("convert: %v", err)
	}
	// Both output files must exist.
	if _, err := os.Stat(filepath.Join(outputDir, "adapter_model.safetensors")); err != nil {
		t.Error("missing adapter_model.safetensors")
	}
	if _, err := os.Stat(filepath.Join(outputDir, "adapter_config.json")); err != nil {
		t.Error("missing adapter_config.json")
	}
	// Read and verify the PEFT config.
	peftCfgData, err := os.ReadFile(filepath.Join(outputDir, "adapter_config.json"))
	if err != nil {
		t.Fatalf("read peft config: %v", err)
	}
	var peftConfig map[string]interface{}
	if err := json.Unmarshal(peftCfgData, &peftConfig); err != nil {
		t.Fatalf("parse peft config: %v", err)
	}
	if peftConfig["peft_type"] != "LORA" {
		t.Errorf("peft_type = %v, want LORA", peftConfig["peft_type"])
	}
	if peftConfig["base_model_name_or_path"] != "test-model" {
		t.Errorf("base_model = %v, want test-model", peftConfig["base_model_name_or_path"])
	}
	// lora_alpha = scale * rank = 20 * 8 = 160.
	if alpha, ok := peftConfig["lora_alpha"].(float64); !ok || alpha != 160 {
		t.Errorf("lora_alpha = %v, want 160", peftConfig["lora_alpha"])
	}
	// The converted safetensors must use PEFT-format keys.
	peftTensors, _, err := readSafetensors(filepath.Join(outputDir, "adapter_model.safetensors"))
	if err != nil {
		t.Fatalf("read peft safetensors: %v", err)
	}
	expectedKeys := []string{
		"base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
		"base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight",
	}
	for _, k := range expectedKeys {
		if _, ok := peftTensors[k]; !ok {
			t.Errorf("missing expected PEFT key: %s", k)
		}
	}
	// Shapes must be transposed: lora_a (4,2) -> (2,4), lora_b (2,4) -> (4,2).
	// Guard the lookup and rank so a malformed output fails cleanly instead
	// of panicking with an index-out-of-range.
	loraAInfo, ok := peftTensors["base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"]
	if !ok || len(loraAInfo.Shape) != 2 {
		t.Fatalf("lora_A shape = %v, want [2, 4]", loraAInfo.Shape)
	}
	if loraAInfo.Shape[0] != 2 || loraAInfo.Shape[1] != 4 {
		t.Errorf("lora_A shape = %v, want [2, 4]", loraAInfo.Shape)
	}
}
// TestReadWriteSafetensorsRoundtrip writes a single F32 tensor with
// writeSafetensors and reads it back with readSafetensors, checking that the
// dtype, shape, and payload length survive the round trip.
func TestReadWriteSafetensorsRoundtrip(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "test.safetensors")
	original := map[string]safetensorsTensorInfo{
		"weight_a": {Dtype: "F32", Shape: []int{2, 3}},
	}
	data := map[string][]byte{
		"weight_a": make([]byte, 2*3*4),
	}
	// Distinct little-endian float32 values 0..5 so corruption is detectable.
	for i := 0; i < 6; i++ {
		binary.LittleEndian.PutUint32(data["weight_a"][i*4:], math.Float32bits(float32(i)))
	}
	if err := writeSafetensors(path, original, data); err != nil {
		t.Fatalf("write: %v", err)
	}
	readTensors, readData, err := readSafetensors(path)
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if len(readTensors) != 1 {
		t.Fatalf("expected 1 tensor, got %d", len(readTensors))
	}
	// Guard the lookup and rank before indexing Shape so a missing key or a
	// malformed shape fails cleanly instead of panicking.
	info, ok := readTensors["weight_a"]
	if !ok {
		t.Fatal("missing tensor weight_a")
	}
	if info.Dtype != "F32" {
		t.Errorf("dtype = %s, want F32", info.Dtype)
	}
	if len(info.Shape) != 2 {
		t.Fatalf("shape = %v, want [2, 3]", info.Shape)
	}
	if info.Shape[0] != 2 || info.Shape[1] != 3 {
		t.Errorf("shape = %v, want [2, 3]", info.Shape)
	}
	got := getTensorData(info, readData)
	if len(got) != 24 {
		t.Errorf("data length = %d, want 24", len(got))
	}
}