304 lines
8 KiB
Go
304 lines
8 KiB
Go
|
|
package ml
|
||
|
|
|
||
|
|
import (
|
||
|
|
"encoding/binary"
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"log"
|
||
|
|
"math"
|
||
|
|
"os"
|
||
|
|
"path/filepath"
|
||
|
|
"regexp"
|
||
|
|
"sort"
|
||
|
|
"strconv"
|
||
|
|
"strings"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Regexes for translating MLX LoRA tensor key names into PEFT naming.
var (
	// loraARe matches the MLX ".lora_a" key suffix (the LoRA A matrix).
	loraARe = regexp.MustCompile(`\.lora_a$`)
	// loraBRe matches the MLX ".lora_b" key suffix (the LoRA B matrix).
	loraBRe = regexp.MustCompile(`\.lora_b$`)
	// layerRe extracts the numeric layer index from keys like "layers.12".
	layerRe = regexp.MustCompile(`layers\.(\d+)`)
	// moduleRe captures the module path (e.g. "self_attn.q_proj") between
	// the layer index and the trailing lora_a/lora_b marker.
	moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
)
|
||
|
|
|
||
|
|
// RenameMLXKey converts an MLX tensor key to PEFT format.
|
||
|
|
// RenameMLXKey converts an MLX tensor key to PEFT format.
//
// The MLX ".lora_a"/".lora_b" suffixes become PEFT's
// ".lora_A.default.weight"/".lora_B.default.weight", and every key is
// prefixed with "base_model.model.". Keys without a LoRA suffix only
// receive the prefix.
func RenameMLXKey(mlxKey string) string {
	renamed := mlxKey
	switch {
	case strings.HasSuffix(renamed, ".lora_a"):
		renamed = strings.TrimSuffix(renamed, ".lora_a") + ".lora_A.default.weight"
	case strings.HasSuffix(renamed, ".lora_b"):
		renamed = strings.TrimSuffix(renamed, ".lora_b") + ".lora_B.default.weight"
	}
	return "base_model.model." + renamed
}
|
||
|
|
|
||
|
|
// SafetensorsHeader represents the header of a safetensors file.
|
||
|
|
// SafetensorsHeader represents the header of a safetensors file.
type SafetensorsHeader struct {
	// Metadata holds the optional "__metadata__" entry of the header.
	Metadata map[string]string `json:"__metadata__,omitempty"`
	// Tensors maps tensor names to their descriptors. It is tagged "-"
	// because safetensors stores tensor entries as arbitrary top-level JSON
	// keys alongside "__metadata__", so they must be decoded manually
	// rather than by encoding/json struct decoding.
	Tensors map[string]SafetensorsTensorInfo `json:"-"`
}
|
||
|
|
|
||
|
|
// SafetensorsTensorInfo describes a tensor's dtype, shape, and data location.
|
||
|
|
// SafetensorsTensorInfo describes a tensor's dtype, shape, and data location.
type SafetensorsTensorInfo struct {
	// Dtype is the safetensors dtype string, e.g. "F32", "F16", "BF16".
	Dtype string `json:"dtype"`
	// Shape lists the tensor dimensions in row-major order.
	Shape []int `json:"shape"`
	// DataOffsets is the [begin, end) byte range of the tensor within the
	// file's data section (relative to the end of the header).
	DataOffsets [2]int `json:"data_offsets"`
}
|
||
|
|
|
||
|
|
// ReadSafetensors reads a safetensors file and returns tensor info and raw data.
|
||
|
|
func ReadSafetensors(path string) (map[string]SafetensorsTensorInfo, []byte, error) {
|
||
|
|
data, err := os.ReadFile(path)
|
||
|
|
if err != nil {
|
||
|
|
return nil, nil, fmt.Errorf("read file: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(data) < 8 {
|
||
|
|
return nil, nil, fmt.Errorf("file too small")
|
||
|
|
}
|
||
|
|
|
||
|
|
headerSize := int(binary.LittleEndian.Uint64(data[:8]))
|
||
|
|
if 8+headerSize > len(data) {
|
||
|
|
return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
|
||
|
|
}
|
||
|
|
|
||
|
|
headerJSON := data[8 : 8+headerSize]
|
||
|
|
tensorData := data[8+headerSize:]
|
||
|
|
|
||
|
|
var rawHeader map[string]json.RawMessage
|
||
|
|
if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
|
||
|
|
return nil, nil, fmt.Errorf("parse header: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
tensors := make(map[string]SafetensorsTensorInfo)
|
||
|
|
for key, raw := range rawHeader {
|
||
|
|
if key == "__metadata__" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
var info SafetensorsTensorInfo
|
||
|
|
if err := json.Unmarshal(raw, &info); err != nil {
|
||
|
|
return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
|
||
|
|
}
|
||
|
|
tensors[key] = info
|
||
|
|
}
|
||
|
|
|
||
|
|
return tensors, tensorData, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GetTensorData extracts raw bytes for a tensor from the data section.
|
||
|
|
func GetTensorData(info SafetensorsTensorInfo, allData []byte) []byte {
|
||
|
|
return allData[info.DataOffsets[0]:info.DataOffsets[1]]
|
||
|
|
}
|
||
|
|
|
||
|
|
// TransposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
|
||
|
|
// TransposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
// The input is raw little-endian bytes, 4 per element; elements are moved
// whole so byte order within each value is preserved. If len(data) does not
// match rows*cols*4, data is returned unchanged.
func TransposeFloat32(data []byte, rows, cols int) []byte {
	const elem = 4
	if len(data) != rows*cols*elem {
		return data
	}
	transposed := make([]byte, len(data))
	// Walk the output column-major: destination element (c, r) takes its
	// bytes from source element (r, c).
	for c := 0; c < cols; c++ {
		for r := 0; r < rows; r++ {
			src := (r*cols + c) * elem
			dst := (c*rows + r) * elem
			copy(transposed[dst:dst+elem], data[src:src+elem])
		}
	}
	return transposed
}
|
||
|
|
|
||
|
|
// TransposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
|
||
|
|
// TransposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
// The input is raw bytes, 2 per element; elements are moved whole so byte
// order within each value is preserved. If len(data) does not match
// rows*cols*2, data is returned unchanged.
func TransposeFloat16(data []byte, rows, cols int) []byte {
	const elem = 2
	if len(data) != rows*cols*elem {
		return data
	}
	transposed := make([]byte, len(data))
	// Walk the output column-major: destination element (c, r) takes its
	// bytes from source element (r, c).
	for c := 0; c < cols; c++ {
		for r := 0; r < rows; r++ {
			src := (r*cols + c) * elem
			dst := (c*rows + r) * elem
			copy(transposed[dst:dst+elem], data[src:src+elem])
		}
	}
	return transposed
}
|
||
|
|
|
||
|
|
// TransposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
|
||
|
|
// TransposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
// bfloat16 values are 2 bytes each — the same element width as float16 — and
// a byte-level transpose never inspects the element contents, so the float16
// routine applies unchanged.
func TransposeBFloat16(data []byte, rows, cols int) []byte {
	return TransposeFloat16(data, rows, cols)
}
|
||
|
|
|
||
|
|
// WriteSafetensors writes tensors to a safetensors file.
|
||
|
|
func WriteSafetensors(path string, tensors map[string]SafetensorsTensorInfo, tensorData map[string][]byte) error {
|
||
|
|
keys := make([]string, 0, len(tensors))
|
||
|
|
for k := range tensors {
|
||
|
|
keys = append(keys, k)
|
||
|
|
}
|
||
|
|
sort.Strings(keys)
|
||
|
|
|
||
|
|
offset := 0
|
||
|
|
updatedTensors := make(map[string]SafetensorsTensorInfo)
|
||
|
|
for _, k := range keys {
|
||
|
|
info := tensors[k]
|
||
|
|
data := tensorData[k]
|
||
|
|
info.DataOffsets = [2]int{offset, offset + len(data)}
|
||
|
|
updatedTensors[k] = info
|
||
|
|
offset += len(data)
|
||
|
|
}
|
||
|
|
|
||
|
|
headerMap := make(map[string]interface{})
|
||
|
|
for k, info := range updatedTensors {
|
||
|
|
headerMap[k] = info
|
||
|
|
}
|
||
|
|
|
||
|
|
headerJSON, err := json.Marshal(headerMap)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("marshal header: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
f, err := os.Create(path)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("create %s: %w", path, err)
|
||
|
|
}
|
||
|
|
defer f.Close()
|
||
|
|
|
||
|
|
headerSizeBuf := make([]byte, 8)
|
||
|
|
binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))
|
||
|
|
|
||
|
|
if _, err := f.Write(headerSizeBuf); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if _, err := f.Write(headerJSON); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, k := range keys {
|
||
|
|
if _, err := f.Write(tensorData[k]); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ConvertMLXtoPEFT converts an MLX LoRA adapter to HuggingFace PEFT format.
|
||
|
|
func ConvertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
|
||
|
|
if err := os.MkdirAll(outputDir, 0755); err != nil {
|
||
|
|
return fmt.Errorf("create output dir: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
tensors, tensorData, err := ReadSafetensors(safetensorsPath)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("read safetensors: %w", err)
|
||
|
|
}
|
||
|
|
log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
|
||
|
|
|
||
|
|
peftTensors := make(map[string]SafetensorsTensorInfo)
|
||
|
|
peftData := make(map[string][]byte)
|
||
|
|
|
||
|
|
for mlxKey, info := range tensors {
|
||
|
|
peftKey := RenameMLXKey(mlxKey)
|
||
|
|
data := GetTensorData(info, tensorData)
|
||
|
|
|
||
|
|
if len(info.Shape) == 2 {
|
||
|
|
rows, cols := info.Shape[0], info.Shape[1]
|
||
|
|
switch info.Dtype {
|
||
|
|
case "F32":
|
||
|
|
data = TransposeFloat32(data, rows, cols)
|
||
|
|
case "F16":
|
||
|
|
data = TransposeFloat16(data, rows, cols)
|
||
|
|
case "BF16":
|
||
|
|
data = TransposeBFloat16(data, rows, cols)
|
||
|
|
}
|
||
|
|
info.Shape = []int{cols, rows}
|
||
|
|
}
|
||
|
|
|
||
|
|
peftTensors[peftKey] = info
|
||
|
|
peftData[peftKey] = data
|
||
|
|
}
|
||
|
|
|
||
|
|
outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
|
||
|
|
if err := WriteSafetensors(outSafetensors, peftTensors, peftData); err != nil {
|
||
|
|
return fmt.Errorf("write safetensors: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
cfgData, err := os.ReadFile(configPath)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("read config: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
var mlxConfig struct {
|
||
|
|
LoraParameters struct {
|
||
|
|
Rank int `json:"rank"`
|
||
|
|
Scale float64 `json:"scale"`
|
||
|
|
Dropout float64 `json:"dropout"`
|
||
|
|
} `json:"lora_parameters"`
|
||
|
|
}
|
||
|
|
if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
|
||
|
|
return fmt.Errorf("parse config: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
rank := mlxConfig.LoraParameters.Rank
|
||
|
|
if rank == 0 {
|
||
|
|
rank = 8
|
||
|
|
}
|
||
|
|
scale := mlxConfig.LoraParameters.Scale
|
||
|
|
if scale == 0 {
|
||
|
|
scale = 20.0
|
||
|
|
}
|
||
|
|
|
||
|
|
modules := make(map[string]bool)
|
||
|
|
layers := make(map[int]bool)
|
||
|
|
for k := range tensors {
|
||
|
|
if m := moduleRe.FindStringSubmatch(k); m != nil {
|
||
|
|
parts := strings.Split(m[1], ".")
|
||
|
|
modules[parts[len(parts)-1]] = true
|
||
|
|
}
|
||
|
|
if m := layerRe.FindStringSubmatch(k); m != nil {
|
||
|
|
n, _ := strconv.Atoi(m[1])
|
||
|
|
layers[n] = true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
sortedModules := make([]string, 0, len(modules))
|
||
|
|
for m := range modules {
|
||
|
|
sortedModules = append(sortedModules, m)
|
||
|
|
}
|
||
|
|
sort.Strings(sortedModules)
|
||
|
|
|
||
|
|
sortedLayers := make([]int, 0, len(layers))
|
||
|
|
for l := range layers {
|
||
|
|
sortedLayers = append(sortedLayers, l)
|
||
|
|
}
|
||
|
|
sort.Ints(sortedLayers)
|
||
|
|
|
||
|
|
peftConfig := map[string]interface{}{
|
||
|
|
"auto_mapping": nil,
|
||
|
|
"base_model_name_or_path": baseModelName,
|
||
|
|
"bias": "none",
|
||
|
|
"fan_in_fan_out": false,
|
||
|
|
"inference_mode": true,
|
||
|
|
"init_lora_weights": true,
|
||
|
|
"layers_pattern": nil,
|
||
|
|
"layers_to_transform": sortedLayers,
|
||
|
|
"lora_alpha": math.Round(scale * float64(rank)),
|
||
|
|
"lora_dropout": mlxConfig.LoraParameters.Dropout,
|
||
|
|
"modules_to_save": nil,
|
||
|
|
"peft_type": "LORA",
|
||
|
|
"r": rank,
|
||
|
|
"revision": nil,
|
||
|
|
"target_modules": sortedModules,
|
||
|
|
"task_type": "CAUSAL_LM",
|
||
|
|
}
|
||
|
|
|
||
|
|
cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("marshal peft config: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
|
||
|
|
return fmt.Errorf("write adapter_config.json: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
log.Printf("converted %d tensors, %d layers, target modules: %v",
|
||
|
|
len(peftTensors), len(sortedLayers), sortedModules)
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|