package lem

import (
	"encoding/binary"
	"encoding/json"
	"flag"
	"fmt"
	"log"
	"math"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"
)

// RunConvert is the CLI entry point for the convert command.
// It converts MLX LoRA adapters to HuggingFace PEFT format:
//   - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
//   - Transpose: MLX (in, rank) → PEFT (rank, in)
//   - Config generation: adapter_config.json with lora_alpha = scale × rank
func RunConvert(args []string) {
	fs := flag.NewFlagSet("convert", flag.ExitOnError)

	safetensorsPath := fs.String("input", "", "Path to MLX .safetensors file (required)")
	configPath := fs.String("config", "", "Path to MLX adapter_config.json (required)")
	outputDir := fs.String("output", "./peft_output", "Output directory for PEFT adapter")
	baseModel := fs.String("base-model", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID")

	if err := fs.Parse(args); err != nil {
		log.Fatalf("parse flags: %v", err)
	}

	if *safetensorsPath == "" || *configPath == "" {
		fmt.Fprintln(os.Stderr, "error: --input and --config are required")
		fs.Usage()
		os.Exit(1)
	}

	if err := convertMLXtoPEFT(*safetensorsPath, *configPath, *outputDir, *baseModel); err != nil {
		log.Fatalf("convert: %v", err)
	}

	fmt.Printf("Converted to: %s\n", *outputDir)
}

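// Example invocation (illustrative: the "lem" binary name is assumed from the
// package name and the paths are placeholders; the flag names come from the
// FlagSet above):
//
//	lem convert \
//	    --input adapters/adapters.safetensors \
//	    --config adapters/adapter_config.json \
//	    --output ./peft_output \
//	    --base-model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
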
var (
	loraARe  = regexp.MustCompile(`\.lora_a$`)
	loraBRe  = regexp.MustCompile(`\.lora_b$`)
	layerRe  = regexp.MustCompile(`layers\.(\d+)`)
	moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
)

// renameMLXKey converts an MLX tensor key to PEFT format.
func renameMLXKey(mlxKey string) string {
	key := mlxKey
	key = loraARe.ReplaceAllString(key, ".lora_A.default.weight")
	key = loraBRe.ReplaceAllString(key, ".lora_B.default.weight")
	key = "base_model.model." + key
	return key
}

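// As a concrete illustration (the q_proj module name is hypothetical),
// renameMLXKey maps
//
//	model.layers.0.self_attn.q_proj.lora_a
//
// to
//
//	base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
//
// which is the renaming rule documented on RunConvert.
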
// safetensorsHeader represents the header of a safetensors file.
type safetensorsHeader struct {
	Metadata map[string]string                `json:"__metadata__,omitempty"`
	Tensors  map[string]safetensorsTensorInfo `json:"-"`
}

type safetensorsTensorInfo struct {
	Dtype       string `json:"dtype"`
	Shape       []int  `json:"shape"`
	DataOffsets [2]int `json:"data_offsets"`
}

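// For reference, a header entry decoded into safetensorsTensorInfo looks like
// this in the header JSON (illustrative key and sizes; 3584×16 BF16 values
// occupy 114688 bytes):
//
//	"model.layers.0.self_attn.q_proj.lora_a": {
//	    "dtype": "BF16", "shape": [3584, 16], "data_offsets": [0, 114688]
//	}
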
// readSafetensors reads a safetensors file and returns the tensor header
// entries along with the raw data section that follows the header.
func readSafetensors(path string) (map[string]safetensorsTensorInfo, []byte, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, nil, fmt.Errorf("read file: %w", err)
	}

	if len(data) < 8 {
		return nil, nil, fmt.Errorf("file too small")
	}

	headerSize := int(binary.LittleEndian.Uint64(data[:8]))
	if 8+headerSize > len(data) {
		return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
	}

	headerJSON := data[8 : 8+headerSize]
	tensorData := data[8+headerSize:]

	// Parse header as a generic map since tensors are top-level keys.
	var rawHeader map[string]json.RawMessage
	if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
		return nil, nil, fmt.Errorf("parse header: %w", err)
	}

	tensors := make(map[string]safetensorsTensorInfo)
	for key, raw := range rawHeader {
		if key == "__metadata__" {
			continue
		}
		var info safetensorsTensorInfo
		if err := json.Unmarshal(raw, &info); err != nil {
			return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
		}
		tensors[key] = info
	}

	return tensors, tensorData, nil
}

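// Concrete framing example: a 512-byte header is preceded by the little-endian
// length bytes 00 02 00 00 00 00 00 00, so the header JSON occupies bytes
// [8, 520) and the tensor data section starts at byte 520.
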
// getTensorData extracts raw bytes for a tensor from the data section.
func getTensorData(info safetensorsTensorInfo, allData []byte) []byte {
	return allData[info.DataOffsets[0]:info.DataOffsets[1]]
}

// transposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
func transposeFloat32(data []byte, rows, cols int) []byte {
	if len(data) != rows*cols*4 {
		return data // size mismatch, return as-is
	}

	result := make([]byte, len(data))
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			srcOff := (r*cols + c) * 4
			dstOff := (c*rows + r) * 4
			copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4])
		}
	}
	return result
}

// transposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
func transposeFloat16(data []byte, rows, cols int) []byte {
	if len(data) != rows*cols*2 {
		return data
	}

	result := make([]byte, len(data))
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			srcOff := (r*cols + c) * 2
			dstOff := (c*rows + r) * 2
			copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2])
		}
	}
	return result
}

// transposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
func transposeBFloat16(data []byte, rows, cols int) []byte {
	return transposeFloat16(data, rows, cols) // same element size
}

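// Worked example for the transpose helpers (element size 4 bytes for F32,
// 2 bytes for F16/BF16): a 2×3 row-major matrix [a b c; d e f] becomes the
// 3×2 matrix [a d; b e; c f]. In the F32 case, element c at (r=0, c=2) moves
// from byte offset (0*3+2)*4 = 8 to (2*2+0)*4 = 16, matching srcOff/dstOff.
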
// writeSafetensors writes tensors to a safetensors file.
func writeSafetensors(path string, tensors map[string]safetensorsTensorInfo, tensorData map[string][]byte) error {
	// Sort keys for deterministic output.
	keys := make([]string, 0, len(tensors))
	for k := range tensors {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Compute offsets.
	offset := 0
	updatedTensors := make(map[string]safetensorsTensorInfo)
	for _, k := range keys {
		info := tensors[k]
		data := tensorData[k]
		info.DataOffsets = [2]int{offset, offset + len(data)}
		updatedTensors[k] = info
		offset += len(data)
	}

	// Build header JSON.
	headerMap := make(map[string]interface{})
	for k, info := range updatedTensors {
		headerMap[k] = info
	}

	headerJSON, err := json.Marshal(headerMap)
	if err != nil {
		return fmt.Errorf("marshal header: %w", err)
	}

	// Write file: 8-byte header size + header JSON + tensor data.
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	defer f.Close()

	headerSizeBuf := make([]byte, 8)
	binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))

	if _, err := f.Write(headerSizeBuf); err != nil {
		return err
	}
	if _, err := f.Write(headerJSON); err != nil {
		return err
	}

	for _, k := range keys {
		if _, err := f.Write(tensorData[k]); err != nil {
			return err
		}
	}

	return nil
}

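// Note that writeSafetensors rewrites DataOffsets relative to the start of the
// data section and lays tensors out in sorted-key order: for two hypothetical
// tensors of 64 and 128 bytes, the first gets data_offsets [0, 64] and the
// second [64, 192].
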
// convertMLXtoPEFT converts an MLX LoRA adapter to PEFT format.
func convertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return fmt.Errorf("create output dir: %w", err)
	}

	// Read MLX tensors.
	tensors, tensorData, err := readSafetensors(safetensorsPath)
	if err != nil {
		return fmt.Errorf("read safetensors: %w", err)
	}
	log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)

	// Rename and transpose tensors.
	peftTensors := make(map[string]safetensorsTensorInfo)
	peftData := make(map[string][]byte)

	for mlxKey, info := range tensors {
		peftKey := renameMLXKey(mlxKey)
		data := getTensorData(info, tensorData)

		// Transpose: swap shape and transpose data.
		if len(info.Shape) == 2 {
			rows, cols := info.Shape[0], info.Shape[1]

			switch info.Dtype {
			case "F32":
				data = transposeFloat32(data, rows, cols)
			case "F16":
				data = transposeFloat16(data, rows, cols)
			case "BF16":
				data = transposeBFloat16(data, rows, cols)
			}

			info.Shape = []int{cols, rows}
		}

		peftTensors[peftKey] = info
		peftData[peftKey] = data
	}

	// Write PEFT safetensors.
	outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
	if err := writeSafetensors(outSafetensors, peftTensors, peftData); err != nil {
		return fmt.Errorf("write safetensors: %w", err)
	}

	// Read MLX config for LoRA parameters.
	cfgData, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}

	var mlxConfig struct {
		LoraParameters struct {
			Rank    int     `json:"rank"`
			Scale   float64 `json:"scale"`
			Dropout float64 `json:"dropout"`
		} `json:"lora_parameters"`
	}
	if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}

	rank := mlxConfig.LoraParameters.Rank
	if rank == 0 {
		rank = 8
	}
	scale := mlxConfig.LoraParameters.Scale
	if scale == 0 {
		scale = 20.0
	}

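	// With the fallback values above (rank 8, scale 20.0), the generated
	// adapter_config.json below gets lora_alpha = 20.0 × 8 = 160, so a PEFT
	// runtime using the usual lora_alpha / r scaling recovers the MLX scale
	// of 20.
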
	// Determine target modules from tensor keys.
	modules := make(map[string]bool)
	layers := make(map[int]bool)
	for k := range tensors {
		if m := moduleRe.FindStringSubmatch(k); m != nil {
			parts := strings.Split(m[1], ".")
			modules[parts[len(parts)-1]] = true
		}
		if m := layerRe.FindStringSubmatch(k); m != nil {
			n, _ := strconv.Atoi(m[1])
			layers[n] = true
		}
	}

	sortedModules := make([]string, 0, len(modules))
	for m := range modules {
		sortedModules = append(sortedModules, m)
	}
	sort.Strings(sortedModules)

	sortedLayers := make([]int, 0, len(layers))
	for l := range layers {
		sortedLayers = append(sortedLayers, l)
	}
	sort.Ints(sortedLayers)

	// Write PEFT adapter_config.json.
	peftConfig := map[string]interface{}{
		"auto_mapping":            nil,
		"base_model_name_or_path": baseModelName,
		"bias":                    "none",
		"fan_in_fan_out":          false,
		"inference_mode":          true,
		"init_lora_weights":       true,
		"layers_pattern":          nil,
		"layers_to_transform":     sortedLayers,
		"lora_alpha":              math.Round(scale * float64(rank)),
		"lora_dropout":            mlxConfig.LoraParameters.Dropout,
		"modules_to_save":         nil,
		"peft_type":               "LORA",
		"r":                       rank,
		"revision":                nil,
		"target_modules":          sortedModules,
		"task_type":               "CAUSAL_LM",
	}

	cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
	if err != nil {
		return fmt.Errorf("marshal peft config: %w", err)
	}

	if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
		return fmt.Errorf("write adapter_config.json: %w", err)
	}

	log.Printf("converted %d tensors, %d layers, target modules: %v",
		len(peftTensors), len(sortedLayers), sortedModules)

	return nil
}