LEM/pkg/lem/convert.go
Snider 56eda1a081 refactor: migrate all 25 commands from passthrough to cobra framework
Replace passthrough() + stdlib flag.FlagSet anti-pattern with proper
cobra integration. Every Run* function now takes a typed *Opts struct
and returns error. Flags registered via cli.StringFlag/IntFlag/etc.
Commands participate in Core lifecycle with full cobra flag parsing.

- 6 command groups: gen, score, data, export, infra, mon
- 25 commands converted, 0 passthrough() calls remain
- Delete passthrough() helper from lem.go
- Update export_test.go to use ExportOpts struct

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 03:32:53 +00:00

344 lines
9.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package lem
import (
"encoding/binary"
"encoding/json"
"fmt"
"log"
"math"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
)
// ConvertOpts holds configuration for the MLX-to-PEFT conversion command.
type ConvertOpts struct {
	Input     string // Path to MLX .safetensors file (required)
	Config    string // Path to MLX adapter_config.json (required)
	Output    string // Output directory for the generated PEFT adapter files
	BaseModel string // HuggingFace base model ID, recorded as base_model_name_or_path in the PEFT config
}
// RunConvert is the CLI entry point for the convert command.
// Converts MLX LoRA adapters to HuggingFace PEFT format:
//   - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
//   - Transpose: MLX (in, rank) → PEFT (rank, in)
//   - Config generation: adapter_config.json with lora_alpha = scale × rank
func RunConvert(cfg ConvertOpts) error {
	// Both source paths are mandatory; bail out before touching the filesystem.
	if cfg.Input == "" || cfg.Config == "" {
		return fmt.Errorf("--input and --config are required")
	}
	err := convertMLXtoPEFT(cfg.Input, cfg.Config, cfg.Output, cfg.BaseModel)
	if err != nil {
		return fmt.Errorf("convert: %w", err)
	}
	fmt.Printf("Converted to: %s\n", cfg.Output)
	return nil
}
var (
	// loraARe / loraBRe match the trailing MLX tensor-name suffixes for the
	// two LoRA factor matrices.
	loraARe = regexp.MustCompile(`\.lora_a$`)
	loraBRe = regexp.MustCompile(`\.lora_b$`)
	// layerRe captures the transformer layer index from keys such as
	// "model.layers.12...".
	layerRe = regexp.MustCompile(`layers\.(\d+)`)
	// moduleRe captures the module path (e.g. "self_attn.q_proj") between the
	// layer index and the trailing .lora_a / .lora_b suffix.
	moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
)
// renameMLXKey converts an MLX tensor key to PEFT format.
// The two rewrites are anchored suffix replacements, so plain string
// suffix handling is equivalent to the regex form: a key can end in at
// most one of ".lora_a" / ".lora_b".
func renameMLXKey(mlxKey string) string {
	key := mlxKey
	switch {
	case strings.HasSuffix(key, ".lora_a"):
		key = strings.TrimSuffix(key, ".lora_a") + ".lora_A.default.weight"
	case strings.HasSuffix(key, ".lora_b"):
		key = strings.TrimSuffix(key, ".lora_b") + ".lora_B.default.weight"
	}
	return "base_model.model." + key
}
// safetensorsHeader represents the header of a safetensors file.
type safetensorsHeader struct {
Metadata map[string]string `json:"__metadata__,omitempty"`
Tensors map[string]safetensorsTensorInfo `json:"-"`
}
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// readSafetensors reads a safetensors file and returns tensor name→data+info pairs.
func readSafetensors(path string) (map[string]safetensorsTensorInfo, []byte, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, nil, fmt.Errorf("read file: %w", err)
}
if len(data) < 8 {
return nil, nil, fmt.Errorf("file too small")
}
headerSize := int(binary.LittleEndian.Uint64(data[:8]))
if 8+headerSize > len(data) {
return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
}
headerJSON := data[8 : 8+headerSize]
tensorData := data[8+headerSize:]
// Parse header as a generic map since tensors are top-level keys.
var rawHeader map[string]json.RawMessage
if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
return nil, nil, fmt.Errorf("parse header: %w", err)
}
tensors := make(map[string]safetensorsTensorInfo)
for key, raw := range rawHeader {
if key == "__metadata__" {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
}
tensors[key] = info
}
return tensors, tensorData, nil
}
// getTensorData extracts raw bytes for a tensor from the data section.
// Offsets are relative to the start of the data section (header excluded),
// matching how readSafetensors returns it.
func getTensorData(info safetensorsTensorInfo, allData []byte) []byte {
	start, end := info.DataOffsets[0], info.DataOffsets[1]
	return allData[start:end]
}
// transposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
// The bytes are treated as opaque 4-byte elements; on a length mismatch the
// input is returned unchanged.
func transposeFloat32(data []byte, rows, cols int) []byte {
	const elemSize = 4
	if len(data) != rows*cols*elemSize {
		return data // size mismatch, return as-is
	}
	out := make([]byte, len(data))
	for row := 0; row < rows; row++ {
		for col := 0; col < cols; col++ {
			src := (row*cols + col) * elemSize
			dst := (col*rows + row) * elemSize
			copy(out[dst:dst+elemSize], data[src:src+elemSize])
		}
	}
	return out
}
// transposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
// Elements are opaque 2-byte values; a length mismatch returns the input
// unchanged.
func transposeFloat16(data []byte, rows, cols int) []byte {
	const width = 2
	if len(data) != rows*cols*width {
		return data
	}
	out := make([]byte, len(data))
	total := rows * cols
	// Walk elements in row-major order; compute the column-major target.
	for idx := 0; idx < total; idx++ {
		r, c := idx/cols, idx%cols
		dst := (c*rows + r) * width
		copy(out[dst:dst+width], data[idx*width:idx*width+width])
	}
	return out
}
// transposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
// bfloat16 elements are 2 bytes wide, so this is the same byte-level shuffle
// as the float16 case; a length mismatch returns the input unchanged.
func transposeBFloat16(data []byte, rows, cols int) []byte {
	const elemSize = 2
	if len(data) != rows*cols*elemSize {
		return data
	}
	out := make([]byte, len(data))
	for row := 0; row < rows; row++ {
		for col := 0; col < cols; col++ {
			src := (row*cols + col) * elemSize
			dst := (col*rows + row) * elemSize
			out[dst] = data[src]
			out[dst+1] = data[src+1]
		}
	}
	return out
}
// writeSafetensors writes tensors to a safetensors file at path.
//
// Layout: 8-byte little-endian header length, JSON header, then each
// tensor's bytes concatenated in sorted-key order. DataOffsets in the
// header are recomputed from the actual data lengths; any offsets already
// present in tensors are ignored.
//
// Fixes over the previous version: the file Close error is now checked
// (a deferred, ignored Close can silently drop buffered write failures),
// and all write errors carry context.
func writeSafetensors(path string, tensors map[string]safetensorsTensorInfo, tensorData map[string][]byte) error {
	// Sort keys for deterministic data layout.
	keys := make([]string, 0, len(tensors))
	for k := range tensors {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Assign contiguous data offsets in key order and build the header map.
	offset := 0
	headerMap := make(map[string]any, len(tensors))
	for _, k := range keys {
		info := tensors[k]
		n := len(tensorData[k])
		info.DataOffsets = [2]int{offset, offset + n}
		headerMap[k] = info
		offset += n
	}
	headerJSON, err := json.Marshal(headerMap)
	if err != nil {
		return fmt.Errorf("marshal header: %w", err)
	}

	// Write file: 8-byte header size + header JSON + tensor data.
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	headerSizeBuf := make([]byte, 8)
	binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))
	werr := func() error {
		if _, err := f.Write(headerSizeBuf); err != nil {
			return fmt.Errorf("write header size: %w", err)
		}
		if _, err := f.Write(headerJSON); err != nil {
			return fmt.Errorf("write header: %w", err)
		}
		for _, k := range keys {
			if _, err := f.Write(tensorData[k]); err != nil {
				return fmt.Errorf("write tensor %s: %w", k, err)
			}
		}
		return nil
	}()
	cerr := f.Close()
	if werr != nil {
		return werr
	}
	if cerr != nil {
		return fmt.Errorf("close %s: %w", path, cerr)
	}
	return nil
}
// convertMLXtoPEFT converts an MLX LoRA adapter to PEFT format.
//
// Pipeline, in order:
//  1. Read every tensor from safetensorsPath.
//  2. Rename keys to the PEFT convention and transpose 2-D weights.
//  3. Write <outputDir>/adapter_model.safetensors.
//  4. Derive LoRA hyper-parameters from configPath and the tensor keys,
//     then write <outputDir>/adapter_config.json.
func convertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return fmt.Errorf("create output dir: %w", err)
	}
	// Read MLX tensors.
	tensors, tensorData, err := readSafetensors(safetensorsPath)
	if err != nil {
		return fmt.Errorf("read safetensors: %w", err)
	}
	log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
	// Rename and transpose tensors.
	peftTensors := make(map[string]safetensorsTensorInfo)
	peftData := make(map[string][]byte)
	for mlxKey, info := range tensors {
		peftKey := renameMLXKey(mlxKey)
		data := getTensorData(info, tensorData)
		// Transpose 2-D weights and swap the recorded shape.
		// NOTE(review): dtypes other than F32/F16/BF16 fall through with
		// un-transposed bytes but a swapped shape — confirm that only these
		// three dtypes occur in MLX adapters.
		if len(info.Shape) == 2 {
			rows, cols := info.Shape[0], info.Shape[1]
			switch info.Dtype {
			case "F32":
				data = transposeFloat32(data, rows, cols)
			case "F16":
				data = transposeFloat16(data, rows, cols)
			case "BF16":
				data = transposeBFloat16(data, rows, cols)
			}
			info.Shape = []int{cols, rows}
		}
		peftTensors[peftKey] = info
		peftData[peftKey] = data
	}
	// Write PEFT safetensors.
	outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
	if err := writeSafetensors(outSafetensors, peftTensors, peftData); err != nil {
		return fmt.Errorf("write safetensors: %w", err)
	}
	// Read MLX config for LoRA parameters.
	cfgData, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	var mlxConfig struct {
		LoraParameters struct {
			Rank    int     `json:"rank"`
			Scale   float64 `json:"scale"`
			Dropout float64 `json:"dropout"`
		} `json:"lora_parameters"`
	}
	if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}
	// Fall back to defaults when the MLX config omits the values
	// (presumably matching mlx-lm defaults — TODO confirm).
	rank := mlxConfig.LoraParameters.Rank
	if rank == 0 {
		rank = 8
	}
	scale := mlxConfig.LoraParameters.Scale
	if scale == 0 {
		scale = 20.0
	}
	// Determine target modules from tensor keys: modules collects the final
	// path component (e.g. "q_proj"); layers collects distinct layer indices.
	modules := make(map[string]bool)
	layers := make(map[int]bool)
	for k := range tensors {
		if m := moduleRe.FindStringSubmatch(k); m != nil {
			parts := strings.Split(m[1], ".")
			modules[parts[len(parts)-1]] = true
		}
		if m := layerRe.FindStringSubmatch(k); m != nil {
			n, _ := strconv.Atoi(m[1]) // regex guarantees digits; error cannot occur
			layers[n] = true
		}
	}
	// Sort both sets so the emitted config is deterministic.
	sortedModules := make([]string, 0, len(modules))
	for m := range modules {
		sortedModules = append(sortedModules, m)
	}
	sort.Strings(sortedModules)
	sortedLayers := make([]int, 0, len(layers))
	for l := range layers {
		sortedLayers = append(sortedLayers, l)
	}
	sort.Ints(sortedLayers)
	// Write PEFT adapter_config.json. lora_alpha = scale × rank reproduces
	// the MLX scaling under PEFT's alpha/r convention (see file header doc).
	peftConfig := map[string]any{
		"auto_mapping":            nil,
		"base_model_name_or_path": baseModelName,
		"bias":                    "none",
		"fan_in_fan_out":          false,
		"inference_mode":          true,
		"init_lora_weights":       true,
		"layers_pattern":          nil,
		"layers_to_transform":     sortedLayers,
		"lora_alpha":              math.Round(scale * float64(rank)),
		"lora_dropout":            mlxConfig.LoraParameters.Dropout,
		"modules_to_save":         nil,
		"peft_type":               "LORA",
		"r":                       rank,
		"revision":                nil,
		"target_modules":          sortedModules,
		"task_type":               "CAUSAL_LM",
	}
	cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
	if err != nil {
		return fmt.Errorf("marshal peft config: %w", err)
	}
	if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
		return fmt.Errorf("write adapter_config.json: %w", err)
	}
	log.Printf("converted %d tensors, %d layers, target modules: %v",
		len(peftTensors), len(sortedLayers), sortedModules)
	return nil
}