- `lem parquet` — export JSONL training splits to Parquet (parquet-go)
- `lem publish` — push Parquet files to HuggingFace dataset repo
- `lem metrics` — push DuckDB golden set stats to InfluxDB
- `lem convert` — MLX LoRA adapter → HuggingFace PEFT format (pure Go safetensors read/write/transpose, no PyTorch needed)

Dependencies added: parquet-go, go-huggingface, go-rocm, go-pytorch, gotch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
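The tests below exercise the `lem convert` helpers. The implementations under test (`renameMLXKey`, `transposeFloat32`, `writeSafetensors`, and friends) live elsewhere in the package and are not shown here; as a rough illustration only, a minimal sketch consistent with the test cases — with hypothetical names `renameMLXKeySketch` and `transposeFloat32Sketch`, not the actual code — might look like:

```go
package lem

import "strings"

// renameMLXKeySketch maps an MLX adapter tensor name onto the PEFT layout,
// e.g. "model.layers.0.self_attn.q_proj.lora_a" ->
// "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight".
// Hypothetical sketch; the real renameMLXKey may differ.
func renameMLXKeySketch(key string) string {
	key = strings.Replace(key, ".lora_a", ".lora_A", 1)
	key = strings.Replace(key, ".lora_b", ".lora_B", 1)
	return "base_model.model." + key + ".default.weight"
}

// transposeFloat32Sketch transposes a rows×cols row-major float32 buffer
// (4 bytes per element) into cols×rows order. Hypothetical sketch.
func transposeFloat32Sketch(data []byte, rows, cols int) []byte {
	out := make([]byte, len(data))
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			copy(out[(c*rows+r)*4:(c*rows+r)*4+4], data[(r*cols+c)*4:])
		}
	}
	return out
}
```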
package lem

import (
	"encoding/binary"
	"encoding/json"
	"math"
	"os"
	"path/filepath"
	"testing"
)

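// TestRenameMLXKey verifies that MLX adapter tensor names are rewritten into
// the PEFT naming scheme: lora_a/lora_b become lora_A/lora_B under the
// base_model.model prefix, with a .default.weight suffix.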
func TestRenameMLXKey(t *testing.T) {
	tests := []struct {
		input string
		want  string
	}{
		{
			"model.layers.12.self_attn.q_proj.lora_a",
			"base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight",
		},
		{
			"model.layers.0.self_attn.v_proj.lora_b",
			"base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight",
		},
		{
			"model.layers.5.mlp.gate_proj.lora_a",
			"base_model.model.model.layers.5.mlp.gate_proj.lora_A.default.weight",
		},
	}

	for _, tt := range tests {
		got := renameMLXKey(tt.input)
		if got != tt.want {
			t.Errorf("renameMLXKey(%q) = %q, want %q", tt.input, got, tt.want)
		}
	}
}

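// TestTransposeFloat32 verifies that a row-major 2x3 float32 buffer is
// transposed into the expected 3x2 layout.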
func TestTransposeFloat32(t *testing.T) {
	// 2x3 matrix: [[1, 2, 3], [4, 5, 6]]
	data := make([]byte, 2*3*4)
	for i, v := range []float32{1, 2, 3, 4, 5, 6} {
		binary.LittleEndian.PutUint32(data[i*4:], math.Float32bits(v))
	}

	result := transposeFloat32(data, 2, 3)

	// Expected: 3x2 matrix: [[1, 4], [2, 5], [3, 6]]
	expected := []float32{1, 4, 2, 5, 3, 6}
	for i, want := range expected {
		got := math.Float32frombits(binary.LittleEndian.Uint32(result[i*4:]))
		if got != want {
			t.Errorf("result[%d] = %f, want %f", i, got, want)
		}
	}
}

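// TestConvertMLXtoPEFT runs the full conversion on a minimal MLX adapter (one
// lora_a/lora_b pair plus an MLX config) and checks the emitted PEFT config
// fields and the renamed, transposed tensors.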
func TestConvertMLXtoPEFT(t *testing.T) {
	dir := t.TempDir()

	// Create a minimal MLX safetensors file with one lora_a and one lora_b tensor.
	// Shape: lora_a is (in=4, rank=2), lora_b is (rank=2, out=4)
	tensors := map[string]safetensorsTensorInfo{
		"model.layers.0.self_attn.q_proj.lora_a": {Dtype: "F32", Shape: []int{4, 2}},
		"model.layers.0.self_attn.q_proj.lora_b": {Dtype: "F32", Shape: []int{2, 4}},
	}

	// Create tensor data: 4x2=8 floats and 2x4=8 floats.
	loraAData := make([]byte, 4*2*4)
	for i := 0; i < 8; i++ {
		binary.LittleEndian.PutUint32(loraAData[i*4:], math.Float32bits(float32(i+1)))
	}
	loraBData := make([]byte, 2*4*4)
	for i := 0; i < 8; i++ {
		binary.LittleEndian.PutUint32(loraBData[i*4:], math.Float32bits(float32(10+i)))
	}

	tensorData := make(map[string][]byte)
	tensorData["model.layers.0.self_attn.q_proj.lora_a"] = loraAData
	tensorData["model.layers.0.self_attn.q_proj.lora_b"] = loraBData

	sfPath := filepath.Join(dir, "adapters.safetensors")
	if err := writeSafetensors(sfPath, tensors, tensorData); err != nil {
		t.Fatalf("write test safetensors: %v", err)
	}

	// Create MLX config.
	mlxConfig := map[string]interface{}{
		"lora_parameters": map[string]interface{}{
			"rank":    8,
			"scale":   20.0,
			"dropout": 0.0,
		},
	}
	cfgData, err := json.Marshal(mlxConfig)
	if err != nil {
		t.Fatalf("marshal mlx config: %v", err)
	}
	cfgPath := filepath.Join(dir, "adapter_config.json")
	if err := os.WriteFile(cfgPath, cfgData, 0644); err != nil {
		t.Fatalf("write mlx config: %v", err)
	}

	// Convert.
	outputDir := filepath.Join(dir, "peft_output")
	if err := convertMLXtoPEFT(sfPath, cfgPath, outputDir, "test-model"); err != nil {
		t.Fatalf("convert: %v", err)
	}

	// Check output files exist.
	if _, err := os.Stat(filepath.Join(outputDir, "adapter_model.safetensors")); err != nil {
		t.Error("missing adapter_model.safetensors")
	}
	if _, err := os.Stat(filepath.Join(outputDir, "adapter_config.json")); err != nil {
		t.Error("missing adapter_config.json")
	}

	// Read and verify PEFT config.
	peftCfgData, err := os.ReadFile(filepath.Join(outputDir, "adapter_config.json"))
	if err != nil {
		t.Fatalf("read peft config: %v", err)
	}

	var peftConfig map[string]interface{}
	if err := json.Unmarshal(peftCfgData, &peftConfig); err != nil {
		t.Fatalf("parse peft config: %v", err)
	}

	if peftConfig["peft_type"] != "LORA" {
		t.Errorf("peft_type = %v, want LORA", peftConfig["peft_type"])
	}
	if peftConfig["base_model_name_or_path"] != "test-model" {
		t.Errorf("base_model = %v, want test-model", peftConfig["base_model_name_or_path"])
	}

	// Check that lora_alpha = scale * rank = 20 * 8 = 160.
	if alpha, ok := peftConfig["lora_alpha"].(float64); !ok || alpha != 160 {
		t.Errorf("lora_alpha = %v, want 160", peftConfig["lora_alpha"])
	}

	// Verify converted safetensors has PEFT-format keys.
	peftTensors, _, err := readSafetensors(filepath.Join(outputDir, "adapter_model.safetensors"))
	if err != nil {
		t.Fatalf("read peft safetensors: %v", err)
	}

	expectedKeys := []string{
		"base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
		"base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight",
	}
	for _, k := range expectedKeys {
		if _, ok := peftTensors[k]; !ok {
			t.Errorf("missing expected PEFT key: %s", k)
		}
	}

	// Verify shapes are transposed: lora_a (4,2) → (2,4), lora_b (2,4) → (4,2).
	loraAInfo := peftTensors["base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"]
	if loraAInfo.Shape[0] != 2 || loraAInfo.Shape[1] != 4 {
		t.Errorf("lora_A shape = %v, want [2, 4]", loraAInfo.Shape)
	}
}

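// TestReadWriteSafetensorsRoundtrip writes a single-tensor safetensors file
// and reads it back, checking dtype, shape, and payload length.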
func TestReadWriteSafetensorsRoundtrip(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "test.safetensors")

	original := map[string]safetensorsTensorInfo{
		"weight_a": {Dtype: "F32", Shape: []int{2, 3}},
	}
	data := map[string][]byte{
		"weight_a": make([]byte, 2*3*4),
	}
	for i := 0; i < 6; i++ {
		binary.LittleEndian.PutUint32(data["weight_a"][i*4:], math.Float32bits(float32(i)))
	}

	if err := writeSafetensors(path, original, data); err != nil {
		t.Fatalf("write: %v", err)
	}

	readTensors, readData, err := readSafetensors(path)
	if err != nil {
		t.Fatalf("read: %v", err)
	}

	if len(readTensors) != 1 {
		t.Fatalf("expected 1 tensor, got %d", len(readTensors))
	}

	info := readTensors["weight_a"]
	if info.Dtype != "F32" {
		t.Errorf("dtype = %s, want F32", info.Dtype)
	}
	if info.Shape[0] != 2 || info.Shape[1] != 3 {
		t.Errorf("shape = %v, want [2, 3]", info.Shape)
	}

	got := getTensorData(info, readData)
	if len(got) != 24 {
		t.Errorf("data length = %d, want 24", len(got))
	}
}