LEM/pkg/lem/convert.go

package lem
import (
"encoding/binary"
"encoding/json"
"flag"
"fmt"
"log"
"math"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
)
// RunConvert is the CLI entry point for the convert command. It converts an
// MLX LoRA adapter to HuggingFace PEFT format:
// - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
// - Transpose: 2-D weights are transposed, e.g. lora_a (in, rank) → lora_A (rank, in)
// - Config generation: adapter_config.json with lora_alpha = scale × rank
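//
// Typical invocation (paths and values are illustrative):
//
//	convert --input adapters.safetensors --config adapter_config.json \
//	    --output ./peft_output --base-model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B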
func RunConvert(args []string) {
fs := flag.NewFlagSet("convert", flag.ExitOnError)
safetensorsPath := fs.String("input", "", "Path to MLX .safetensors file (required)")
configPath := fs.String("config", "", "Path to MLX adapter_config.json (required)")
outputDir := fs.String("output", "./peft_output", "Output directory for PEFT adapter")
baseModel := fs.String("base-model", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID")
if err := fs.Parse(args); err != nil {
log.Fatalf("parse flags: %v", err)
}
if *safetensorsPath == "" || *configPath == "" {
fmt.Fprintln(os.Stderr, "error: --input and --config are required")
fs.Usage()
os.Exit(1)
}
if err := convertMLXtoPEFT(*safetensorsPath, *configPath, *outputDir, *baseModel); err != nil {
log.Fatalf("convert: %v", err)
}
fmt.Printf("Converted to: %s\n", *outputDir)
}
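// Regexes used to rewrite MLX tensor keys and to extract layer indices and
// module names when deriving the PEFT config.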
var (
loraARe = regexp.MustCompile(`\.lora_a$`)
loraBRe = regexp.MustCompile(`\.lora_b$`)
layerRe = regexp.MustCompile(`layers\.(\d+)`)
moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
)
// renameMLXKey converts an MLX tensor key to PEFT format.
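// For example (hypothetical key):
//
//	"model.layers.3.self_attn.q_proj.lora_a"
//	→ "base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight"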
func renameMLXKey(mlxKey string) string {
key := mlxKey
key = loraARe.ReplaceAllString(key, ".lora_A.default.weight")
key = loraBRe.ReplaceAllString(key, ".lora_B.default.weight")
key = "base_model.model." + key
return key
}
// safetensorsHeader documents the header layout of a safetensors file. It is
// kept for reference only; readSafetensors parses the header into a raw map
// instead, because tensor names appear as top-level JSON keys next to the
// optional __metadata__ entry.
type safetensorsHeader struct {
Metadata map[string]string `json:"__metadata__,omitempty"`
Tensors map[string]safetensorsTensorInfo `json:"-"`
}
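// safetensorsTensorInfo describes one tensor in the header: element dtype,
// shape, and [begin, end) byte offsets into the data section.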
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// readSafetensors parses a safetensors file, returning the per-tensor header
// info and the raw tensor-data section.
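// Layout: an 8-byte little-endian header length, the JSON header, then the
// concatenated raw tensor data.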
func readSafetensors(path string) (map[string]safetensorsTensorInfo, []byte, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, nil, fmt.Errorf("read file: %w", err)
}
if len(data) < 8 {
return nil, nil, fmt.Errorf("file too small")
}
headerSize64 := binary.LittleEndian.Uint64(data[:8])
// Check against the uint64 value before converting, so a corrupt header
// cannot overflow int and slip past the bounds check.
if headerSize64 > uint64(len(data)-8) {
return nil, nil, fmt.Errorf("invalid header size %d", headerSize64)
}
headerSize := int(headerSize64)
headerJSON := data[8 : 8+headerSize]
tensorData := data[8+headerSize:]
// Parse header as a generic map since tensors are top-level keys.
var rawHeader map[string]json.RawMessage
if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
return nil, nil, fmt.Errorf("parse header: %w", err)
}
tensors := make(map[string]safetensorsTensorInfo)
for key, raw := range rawHeader {
if key == "__metadata__" {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
}
// Reject offsets that would let getTensorData slice out of bounds.
if info.DataOffsets[0] < 0 || info.DataOffsets[1] < info.DataOffsets[0] || info.DataOffsets[1] > len(tensorData) {
return nil, nil, fmt.Errorf("tensor %s: data offsets %v out of range", key, info.DataOffsets)
}
tensors[key] = info
}
return tensors, tensorData, nil
}
// getTensorData extracts a tensor's raw bytes from the data section. The
// offsets were bounds-checked when the header was parsed.
func getTensorData(info safetensorsTensorInfo, allData []byte) []byte {
return allData[info.DataOffsets[0]:info.DataOffsets[1]]
}
// transposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
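// For example, the 2×3 matrix [[a b c] [d e f]] becomes the 3×2 matrix
// [[a d] [b e] [c f]].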
func transposeFloat32(data []byte, rows, cols int) []byte {
if len(data) != rows*cols*4 {
return data // shape disagrees with the buffer size; leave the data untransposed
}
result := make([]byte, len(data))
for r := 0; r < rows; r++ {
for c := 0; c < cols; c++ {
srcOff := (r*cols + c) * 4
dstOff := (c*rows + r) * 4
copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4])
}
}
return result
}
// transposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
func transposeFloat16(data []byte, rows, cols int) []byte {
if len(data) != rows*cols*2 {
return data
}
result := make([]byte, len(data))
for r := 0; r < rows; r++ {
for c := 0; c < cols; c++ {
srcOff := (r*cols + c) * 2
dstOff := (c*rows + r) * 2
copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2])
}
}
return result
}
// transposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
func transposeBFloat16(data []byte, rows, cols int) []byte {
return transposeFloat16(data, rows, cols) // bfloat16 and float16 are both 2 bytes/element, so the byte shuffle is identical
}
// writeSafetensors writes tensors to a safetensors file.
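// Each header entry it emits looks like (illustrative):
//
//	{"dtype": "F32", "shape": [8, 4096], "data_offsets": [0, 131072]}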
func writeSafetensors(path string, tensors map[string]safetensorsTensorInfo, tensorData map[string][]byte) error {
// Sort keys for deterministic output.
keys := make([]string, 0, len(tensors))
for k := range tensors {
keys = append(keys, k)
}
sort.Strings(keys)
// Compute offsets.
offset := 0
updatedTensors := make(map[string]safetensorsTensorInfo)
for _, k := range keys {
info := tensors[k]
data := tensorData[k]
info.DataOffsets = [2]int{offset, offset + len(data)}
updatedTensors[k] = info
offset += len(data)
}
// Build header JSON.
headerMap := make(map[string]interface{})
for k, info := range updatedTensors {
headerMap[k] = info
}
headerJSON, err := json.Marshal(headerMap)
if err != nil {
return fmt.Errorf("marshal header: %w", err)
}
// Write file: 8-byte header size + header JSON + tensor data.
f, err := os.Create(path)
if err != nil {
return fmt.Errorf("create %s: %w", path, err)
}
defer f.Close()
headerSizeBuf := make([]byte, 8)
binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))
if _, err := f.Write(headerSizeBuf); err != nil {
return err
}
if _, err := f.Write(headerJSON); err != nil {
return err
}
for _, k := range keys {
if _, err := f.Write(tensorData[k]); err != nil {
return err
}
}
return nil
}
// convertMLXtoPEFT converts an MLX LoRA adapter to PEFT format.
func convertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
if err := os.MkdirAll(outputDir, 0755); err != nil {
return fmt.Errorf("create output dir: %w", err)
}
// Read MLX tensors.
tensors, tensorData, err := readSafetensors(safetensorsPath)
if err != nil {
return fmt.Errorf("read safetensors: %w", err)
}
log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
// Rename and transpose tensors.
peftTensors := make(map[string]safetensorsTensorInfo)
peftData := make(map[string][]byte)
for mlxKey, info := range tensors {
peftKey := renameMLXKey(mlxKey)
data := getTensorData(info, tensorData)
// Transpose: swap shape and transpose data.
if len(info.Shape) == 2 {
rows, cols := info.Shape[0], info.Shape[1]
switch info.Dtype {
case "F32":
data = transposeFloat32(data, rows, cols)
case "F16":
data = transposeFloat16(data, rows, cols)
case "BF16":
data = transposeBFloat16(data, rows, cols)
}
info.Shape = []int{cols, rows}
}
peftTensors[peftKey] = info
peftData[peftKey] = data
}
// Write PEFT safetensors.
outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
if err := writeSafetensors(outSafetensors, peftTensors, peftData); err != nil {
return fmt.Errorf("write safetensors: %w", err)
}
// Read MLX config for LoRA parameters.
cfgData, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
var mlxConfig struct {
LoraParameters struct {
Rank int `json:"rank"`
Scale float64 `json:"scale"`
Dropout float64 `json:"dropout"`
} `json:"lora_parameters"`
}
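// A minimal MLX config looks like (values illustrative):
//   {"lora_parameters": {"rank": 8, "scale": 20.0, "dropout": 0.0}}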
if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
return fmt.Errorf("parse config: %w", err)
}
rank := mlxConfig.LoraParameters.Rank
if rank == 0 {
rank = 8 // fall back to a common MLX LoRA default
}
scale := mlxConfig.LoraParameters.Scale
if scale == 0 {
scale = 20.0 // fall back to a common MLX LoRA default
}
// Determine target modules from tensor keys.
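// e.g. "model.layers.3.self_attn.q_proj.lora_a" → module "q_proj", layer 3.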
modules := make(map[string]bool)
layers := make(map[int]bool)
for k := range tensors {
if m := moduleRe.FindStringSubmatch(k); m != nil {
parts := strings.Split(m[1], ".")
modules[parts[len(parts)-1]] = true
}
if m := layerRe.FindStringSubmatch(k); m != nil {
n, _ := strconv.Atoi(m[1])
layers[n] = true
}
}
sortedModules := make([]string, 0, len(modules))
for m := range modules {
sortedModules = append(sortedModules, m)
}
sort.Strings(sortedModules)
sortedLayers := make([]int, 0, len(layers))
for l := range layers {
sortedLayers = append(sortedLayers, l)
}
sort.Ints(sortedLayers)
// Write PEFT adapter_config.json.
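// PEFT applies scaling = lora_alpha / r at load time, whereas MLX applies
// scale directly, so lora_alpha = scale × rank reproduces the MLX scaling.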
peftConfig := map[string]interface{}{
"auto_mapping": nil,
"base_model_name_or_path": baseModelName,
"bias": "none",
"fan_in_fan_out": false,
"inference_mode": true,
"init_lora_weights": true,
"layers_pattern": nil,
"layers_to_transform": sortedLayers,
"lora_alpha": math.Round(scale * float64(rank)),
"lora_dropout": mlxConfig.LoraParameters.Dropout,
"modules_to_save": nil,
"peft_type": "LORA",
"r": rank,
"revision": nil,
"target_modules": sortedModules,
"task_type": "CAUSAL_LM",
}
cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
if err != nil {
return fmt.Errorf("marshal peft config: %w", err)
}
if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
return fmt.Errorf("write adapter_config.json: %w", err)
}
log.Printf("converted %d tensors, %d layers, target modules: %v",
len(peftTensors), len(sortedLayers), sortedModules)
return nil
}