feat: add parquet, publish, metrics, convert commands
- `lem parquet` — export JSONL training splits to Parquet (parquet-go)
- `lem publish` — push Parquet files to a HuggingFace dataset repo
- `lem metrics` — push DuckDB golden set stats to InfluxDB
- `lem convert` — MLX LoRA adapter → HuggingFace PEFT format (pure Go safetensors read/write/transpose, no PyTorch needed)

Dependencies added: parquet-go, go-huggingface

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent 0afa5e9147 · commit 4eaf1bfb39
9 changed files with 1187 additions and 7 deletions
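For orientation, illustrative invocations of the four new subcommands, built from the flags defined in the files below (the paths, database name, and Influx endpoint are placeholders):

```
lem parquet --input out/training
lem publish --input out/training/parquet --repo lthn/LEM-golden-set --dry-run
lem metrics --db lem.duckdb --influx http://localhost:8086 --influx-db lem
lem convert --input adapters.safetensors --config adapter_config.json --output ./peft_output
```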
go.mod — 14 lines changed

@@ -3,21 +3,29 @@ module forge.lthn.ai/lthn/lem

go 1.25.6

require (
    github.com/marcboeker/go-duckdb v1.8.5
    github.com/parquet-go/parquet-go v0.27.0
)

require (
    github.com/andybalholm/brotli v1.1.1 // indirect
    github.com/apache/arrow-go/v18 v18.1.0 // indirect
    github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
    github.com/goccy/go-json v0.10.5 // indirect
    github.com/google/flatbuffers v25.1.24+incompatible // indirect
    github.com/google/uuid v1.6.0 // indirect
    github.com/hupe1980/go-huggingface v0.0.15 // indirect
    github.com/klauspost/compress v1.17.11 // indirect
    github.com/klauspost/cpuid/v2 v2.2.9 // indirect
    github.com/marcboeker/go-duckdb v1.8.5 // indirect
    github.com/parquet-go/bitpack v1.0.0 // indirect
    github.com/parquet-go/jsonlite v1.0.0 // indirect
    github.com/pierrec/lz4/v4 v4.1.22 // indirect
    github.com/twpayne/go-geom v1.6.1 // indirect
    github.com/zeebo/xxh3 v1.0.2 // indirect
    golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect
    golang.org/x/mod v0.22.0 // indirect
    golang.org/x/sync v0.10.0 // indirect
    golang.org/x/sys v0.29.0 // indirect
    golang.org/x/sys v0.38.0 // indirect
    golang.org/x/tools v0.29.0 // indirect
    golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
    google.golang.org/protobuf v1.36.1 // indirect
)
go.sum — 52 lines changed

@@ -1,23 +1,61 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY=
github.com/alecthomas/assert/v2 v2.10.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/apache/arrow-go/v18 v18.1.0 h1:agLwJUiVuwXZdwPYVrlITfx7bndULJ/dggbnLFgDp/Y=
github.com/apache/arrow-go/v18 v18.1.0/go.mod h1:tigU/sIgKNXaesf5d7Y95jBBKS5KsxTqYBKXFsvKzo0=
github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE=
github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v25.1.24+incompatible h1:4wPqL3K7GzBd1CwyhSd3usxLKOaJN/AC6puCca6Jm7o=
github.com/google/flatbuffers v25.1.24+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hupe1980/go-huggingface v0.0.15 h1:tTWmUGGunC/BYz4hrwS8SSVtMYVYjceG2uhL8HxeXvw=
github.com/hupe1980/go-huggingface v0.0.15/go.mod h1:IRvsik3+b9BJyw9hCfw1arI6gDObcVto1UA8f3kt8mM=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
github.com/marcboeker/go-duckdb v1.8.5 h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0=
github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA=
github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs=
github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU=
github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0=
github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g=
github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4=
github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc=

@@ -26,9 +64,15 @@ golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE=
golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0=
gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o=
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
main.go — 12 lines changed

@@ -21,6 +21,10 @@ Commands:
  expand   Generate expansion responses via trained LEM model
  conv     Generate conversational training data
  ingest   Ingest benchmark data into InfluxDB
  parquet  Export JSONL training splits to Parquet for HuggingFace
  publish  Push Parquet files to HuggingFace dataset repo
  metrics  Push DuckDB golden set stats to InfluxDB
  convert  Convert MLX LoRA adapter to HuggingFace PEFT format
`

func main() {

@@ -46,6 +50,14 @@ func main() {
        lem.RunConv(os.Args[2:])
    case "ingest":
        lem.RunIngest(os.Args[2:])
    case "parquet":
        lem.RunParquet(os.Args[2:])
    case "publish":
        lem.RunPublish(os.Args[2:])
    case "metrics":
        lem.RunMetrics(os.Args[2:])
    case "convert":
        lem.RunConvert(os.Args[2:])
    default:
        fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
        os.Exit(1)
pkg/lem/convert.go — new file, 349 lines

package lem

import (
    "encoding/binary"
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "math"
    "os"
    "path/filepath"
    "regexp"
    "sort"
    "strconv"
    "strings"
)

// RunConvert is the CLI entry point for the convert command.
// Converts MLX LoRA adapters to HuggingFace PEFT format:
//   - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
//   - Transpose: MLX (in, rank) → PEFT (rank, in)
//   - Config generation: adapter_config.json with lora_alpha = scale × rank
func RunConvert(args []string) {
    fs := flag.NewFlagSet("convert", flag.ExitOnError)

    safetensorsPath := fs.String("input", "", "Path to MLX .safetensors file (required)")
    configPath := fs.String("config", "", "Path to MLX adapter_config.json (required)")
    outputDir := fs.String("output", "./peft_output", "Output directory for PEFT adapter")
    baseModel := fs.String("base-model", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID")

    if err := fs.Parse(args); err != nil {
        log.Fatalf("parse flags: %v", err)
    }

    if *safetensorsPath == "" || *configPath == "" {
        fmt.Fprintln(os.Stderr, "error: --input and --config are required")
        fs.Usage()
        os.Exit(1)
    }

    if err := convertMLXtoPEFT(*safetensorsPath, *configPath, *outputDir, *baseModel); err != nil {
        log.Fatalf("convert: %v", err)
    }

    fmt.Printf("Converted to: %s\n", *outputDir)
}

var (
    loraARe  = regexp.MustCompile(`\.lora_a$`)
    loraBRe  = regexp.MustCompile(`\.lora_b$`)
    layerRe  = regexp.MustCompile(`layers\.(\d+)`)
    moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
)

// renameMLXKey converts an MLX tensor key to PEFT format.
func renameMLXKey(mlxKey string) string {
    key := mlxKey
    key = loraARe.ReplaceAllString(key, ".lora_A.default.weight")
    key = loraBRe.ReplaceAllString(key, ".lora_B.default.weight")
    key = "base_model.model." + key
    return key
}

// safetensorsHeader represents the header of a safetensors file.
type safetensorsHeader struct {
    Metadata map[string]string                `json:"__metadata__,omitempty"`
    Tensors  map[string]safetensorsTensorInfo `json:"-"`
}

type safetensorsTensorInfo struct {
    Dtype       string `json:"dtype"`
    Shape       []int  `json:"shape"`
    DataOffsets [2]int `json:"data_offsets"`
}

// readSafetensors reads a safetensors file and returns tensor name→data+info pairs.
func readSafetensors(path string) (map[string]safetensorsTensorInfo, []byte, error) {
    data, err := os.ReadFile(path)
    if err != nil {
        return nil, nil, fmt.Errorf("read file: %w", err)
    }

    if len(data) < 8 {
        return nil, nil, fmt.Errorf("file too small")
    }

    headerSize := int(binary.LittleEndian.Uint64(data[:8]))
    if 8+headerSize > len(data) {
        return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
    }

    headerJSON := data[8 : 8+headerSize]
    tensorData := data[8+headerSize:]

    // Parse header as a generic map since tensors are top-level keys.
    var rawHeader map[string]json.RawMessage
    if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
        return nil, nil, fmt.Errorf("parse header: %w", err)
    }

    tensors := make(map[string]safetensorsTensorInfo)
    for key, raw := range rawHeader {
        if key == "__metadata__" {
            continue
        }
        var info safetensorsTensorInfo
        if err := json.Unmarshal(raw, &info); err != nil {
            return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
        }
        tensors[key] = info
    }

    return tensors, tensorData, nil
}

// getTensorData extracts raw bytes for a tensor from the data section.
func getTensorData(info safetensorsTensorInfo, allData []byte) []byte {
    return allData[info.DataOffsets[0]:info.DataOffsets[1]]
}

// transposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
func transposeFloat32(data []byte, rows, cols int) []byte {
    if len(data) != rows*cols*4 {
        return data // size mismatch, return as-is
    }

    result := make([]byte, len(data))
    for r := 0; r < rows; r++ {
        for c := 0; c < cols; c++ {
            srcOff := (r*cols + c) * 4
            dstOff := (c*rows + r) * 4
            copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4])
        }
    }
    return result
}

// transposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
func transposeFloat16(data []byte, rows, cols int) []byte {
    if len(data) != rows*cols*2 {
        return data
    }

    result := make([]byte, len(data))
    for r := 0; r < rows; r++ {
        for c := 0; c < cols; c++ {
            srcOff := (r*cols + c) * 2
            dstOff := (c*rows + r) * 2
            copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2])
        }
    }
    return result
}

// transposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
func transposeBFloat16(data []byte, rows, cols int) []byte {
    return transposeFloat16(data, rows, cols) // same element size
}

// writeSafetensors writes tensors to a safetensors file.
func writeSafetensors(path string, tensors map[string]safetensorsTensorInfo, tensorData map[string][]byte) error {
    // Sort keys for deterministic output.
    keys := make([]string, 0, len(tensors))
    for k := range tensors {
        keys = append(keys, k)
    }
    sort.Strings(keys)

    // Compute offsets.
    offset := 0
    updatedTensors := make(map[string]safetensorsTensorInfo)
    for _, k := range keys {
        info := tensors[k]
        data := tensorData[k]
        info.DataOffsets = [2]int{offset, offset + len(data)}
        updatedTensors[k] = info
        offset += len(data)
    }

    // Build header JSON.
    headerMap := make(map[string]interface{})
    for k, info := range updatedTensors {
        headerMap[k] = info
    }

    headerJSON, err := json.Marshal(headerMap)
    if err != nil {
        return fmt.Errorf("marshal header: %w", err)
    }

    // Write file: 8-byte header size + header JSON + tensor data.
    f, err := os.Create(path)
    if err != nil {
        return fmt.Errorf("create %s: %w", path, err)
    }
    defer f.Close()

    headerSizeBuf := make([]byte, 8)
    binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))

    if _, err := f.Write(headerSizeBuf); err != nil {
        return err
    }
    if _, err := f.Write(headerJSON); err != nil {
        return err
    }

    for _, k := range keys {
        if _, err := f.Write(tensorData[k]); err != nil {
            return err
        }
    }

    return nil
}

// convertMLXtoPEFT converts an MLX LoRA adapter to PEFT format.
func convertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
    if err := os.MkdirAll(outputDir, 0755); err != nil {
        return fmt.Errorf("create output dir: %w", err)
    }

    // Read MLX tensors.
    tensors, tensorData, err := readSafetensors(safetensorsPath)
    if err != nil {
        return fmt.Errorf("read safetensors: %w", err)
    }
    log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)

    // Rename and transpose tensors.
    peftTensors := make(map[string]safetensorsTensorInfo)
    peftData := make(map[string][]byte)

    for mlxKey, info := range tensors {
        peftKey := renameMLXKey(mlxKey)
        data := getTensorData(info, tensorData)

        // Transpose: swap shape and transpose data.
        if len(info.Shape) == 2 {
            rows, cols := info.Shape[0], info.Shape[1]

            switch info.Dtype {
            case "F32":
                data = transposeFloat32(data, rows, cols)
            case "F16":
                data = transposeFloat16(data, rows, cols)
            case "BF16":
                data = transposeBFloat16(data, rows, cols)
            }

            info.Shape = []int{cols, rows}
        }

        peftTensors[peftKey] = info
        peftData[peftKey] = data
    }

    // Write PEFT safetensors.
    outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
    if err := writeSafetensors(outSafetensors, peftTensors, peftData); err != nil {
        return fmt.Errorf("write safetensors: %w", err)
    }

    // Read MLX config for LoRA parameters.
    cfgData, err := os.ReadFile(configPath)
    if err != nil {
        return fmt.Errorf("read config: %w", err)
    }

    var mlxConfig struct {
        LoraParameters struct {
            Rank    int     `json:"rank"`
            Scale   float64 `json:"scale"`
            Dropout float64 `json:"dropout"`
        } `json:"lora_parameters"`
    }
    if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
        return fmt.Errorf("parse config: %w", err)
    }

    rank := mlxConfig.LoraParameters.Rank
    if rank == 0 {
        rank = 8
    }
    scale := mlxConfig.LoraParameters.Scale
    if scale == 0 {
        scale = 20.0
    }

    // Determine target modules from tensor keys.
    modules := make(map[string]bool)
    layers := make(map[int]bool)
    for k := range tensors {
        if m := moduleRe.FindStringSubmatch(k); m != nil {
            parts := strings.Split(m[1], ".")
            modules[parts[len(parts)-1]] = true
        }
        if m := layerRe.FindStringSubmatch(k); m != nil {
            n, _ := strconv.Atoi(m[1])
            layers[n] = true
        }
    }

    sortedModules := make([]string, 0, len(modules))
    for m := range modules {
        sortedModules = append(sortedModules, m)
    }
    sort.Strings(sortedModules)

    sortedLayers := make([]int, 0, len(layers))
    for l := range layers {
        sortedLayers = append(sortedLayers, l)
    }
    sort.Ints(sortedLayers)

    // Write PEFT adapter_config.json.
    peftConfig := map[string]interface{}{
        "auto_mapping":            nil,
        "base_model_name_or_path": baseModelName,
        "bias":                    "none",
        "fan_in_fan_out":          false,
        "inference_mode":          true,
        "init_lora_weights":       true,
        "layers_pattern":          nil,
        "layers_to_transform":     sortedLayers,
        "lora_alpha":              math.Round(scale * float64(rank)),
        "lora_dropout":            mlxConfig.LoraParameters.Dropout,
        "modules_to_save":         nil,
        "peft_type":               "LORA",
        "r":                       rank,
        "revision":                nil,
        "target_modules":          sortedModules,
        "task_type":               "CAUSAL_LM",
    }

    cfgJSON, err := json.MarshalIndent(peftConfig, "", "  ")
    if err != nil {
        return fmt.Errorf("marshal peft config: %w", err)
    }

    if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
        return fmt.Errorf("write adapter_config.json: %w", err)
    }

    log.Printf("converted %d tensors, %d layers, target modules: %v",
        len(peftTensors), len(sortedLayers), sortedModules)

    return nil
}
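readSafetensors and writeSafetensors above share one on-disk layout: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype/shape/data_offsets, then the concatenated tensor bytes. A minimal, self-contained sketch of that framing, independent of the converter (the tensor name and sizes are made up):

```go
package main

import (
    "encoding/binary"
    "fmt"
)

func main() {
    // Header for one 1x2 float32 tensor occupying bytes [0, 8) of the data section.
    header := []byte(`{"w":{"dtype":"F32","shape":[1,2],"data_offsets":[0,8]}}`)

    file := make([]byte, 8)
    binary.LittleEndian.PutUint64(file, uint64(len(header))) // 8-byte header size
    file = append(file, header...)                           // JSON header
    file = append(file, make([]byte, 8)...)                  // raw tensor bytes: two float32 zeros

    fmt.Printf("total size: %d bytes\n", len(file)) // 8 + len(header) + 8
}
```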
pkg/lem/convert_test.go — new file, 198 lines

package lem

import (
    "encoding/binary"
    "encoding/json"
    "math"
    "os"
    "path/filepath"
    "testing"
)

func TestRenameMLXKey(t *testing.T) {
    tests := []struct {
        input string
        want  string
    }{
        {
            "model.layers.12.self_attn.q_proj.lora_a",
            "base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight",
        },
        {
            "model.layers.0.self_attn.v_proj.lora_b",
            "base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight",
        },
        {
            "model.layers.5.mlp.gate_proj.lora_a",
            "base_model.model.model.layers.5.mlp.gate_proj.lora_A.default.weight",
        },
    }

    for _, tt := range tests {
        got := renameMLXKey(tt.input)
        if got != tt.want {
            t.Errorf("renameMLXKey(%q) = %q, want %q", tt.input, got, tt.want)
        }
    }
}

func TestTransposeFloat32(t *testing.T) {
    // 2x3 matrix: [[1, 2, 3], [4, 5, 6]]
    data := make([]byte, 2*3*4)
    for i, v := range []float32{1, 2, 3, 4, 5, 6} {
        binary.LittleEndian.PutUint32(data[i*4:], math.Float32bits(v))
    }

    result := transposeFloat32(data, 2, 3)

    // Expected: 3x2 matrix: [[1, 4], [2, 5], [3, 6]]
    expected := []float32{1, 4, 2, 5, 3, 6}
    for i, want := range expected {
        got := math.Float32frombits(binary.LittleEndian.Uint32(result[i*4:]))
        if got != want {
            t.Errorf("result[%d] = %f, want %f", i, got, want)
        }
    }
}

func TestConvertMLXtoPEFT(t *testing.T) {
    dir := t.TempDir()

    // Create a minimal MLX safetensors file with one lora_a and one lora_b tensor.
    // Shape: lora_a is (in=4, rank=2), lora_b is (rank=2, out=4)
    tensors := map[string]safetensorsTensorInfo{
        "model.layers.0.self_attn.q_proj.lora_a": {Dtype: "F32", Shape: []int{4, 2}},
        "model.layers.0.self_attn.q_proj.lora_b": {Dtype: "F32", Shape: []int{2, 4}},
    }

    // Create tensor data: 4x2=8 floats and 2x4=8 floats.
    loraAData := make([]byte, 4*2*4)
    for i := 0; i < 8; i++ {
        binary.LittleEndian.PutUint32(loraAData[i*4:], math.Float32bits(float32(i+1)))
    }
    loraBData := make([]byte, 2*4*4)
    for i := 0; i < 8; i++ {
        binary.LittleEndian.PutUint32(loraBData[i*4:], math.Float32bits(float32(10+i)))
    }

    tensorData := make(map[string][]byte)
    tensorData["model.layers.0.self_attn.q_proj.lora_a"] = loraAData
    tensorData["model.layers.0.self_attn.q_proj.lora_b"] = loraBData

    sfPath := filepath.Join(dir, "adapters.safetensors")
    if err := writeSafetensors(sfPath, tensors, tensorData); err != nil {
        t.Fatalf("write test safetensors: %v", err)
    }

    // Create MLX config.
    mlxConfig := map[string]interface{}{
        "lora_parameters": map[string]interface{}{
            "rank":    8,
            "scale":   20.0,
            "dropout": 0.0,
        },
    }
    cfgData, _ := json.Marshal(mlxConfig)
    cfgPath := filepath.Join(dir, "adapter_config.json")
    os.WriteFile(cfgPath, cfgData, 0644)

    // Convert.
    outputDir := filepath.Join(dir, "peft_output")
    if err := convertMLXtoPEFT(sfPath, cfgPath, outputDir, "test-model"); err != nil {
        t.Fatalf("convert: %v", err)
    }

    // Check output files exist.
    if _, err := os.Stat(filepath.Join(outputDir, "adapter_model.safetensors")); err != nil {
        t.Error("missing adapter_model.safetensors")
    }
    if _, err := os.Stat(filepath.Join(outputDir, "adapter_config.json")); err != nil {
        t.Error("missing adapter_config.json")
    }

    // Read and verify PEFT config.
    peftCfgData, err := os.ReadFile(filepath.Join(outputDir, "adapter_config.json"))
    if err != nil {
        t.Fatalf("read peft config: %v", err)
    }

    var peftConfig map[string]interface{}
    if err := json.Unmarshal(peftCfgData, &peftConfig); err != nil {
        t.Fatalf("parse peft config: %v", err)
    }

    if peftConfig["peft_type"] != "LORA" {
        t.Errorf("peft_type = %v, want LORA", peftConfig["peft_type"])
    }
    if peftConfig["base_model_name_or_path"] != "test-model" {
        t.Errorf("base_model = %v, want test-model", peftConfig["base_model_name_or_path"])
    }

    // Check that lora_alpha = scale * rank = 20 * 8 = 160.
    if alpha, ok := peftConfig["lora_alpha"].(float64); !ok || alpha != 160 {
        t.Errorf("lora_alpha = %v, want 160", peftConfig["lora_alpha"])
    }

    // Verify converted safetensors has PEFT-format keys.
    peftTensors, _, err := readSafetensors(filepath.Join(outputDir, "adapter_model.safetensors"))
    if err != nil {
        t.Fatalf("read peft safetensors: %v", err)
    }

    expectedKeys := []string{
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
        "base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight",
    }
    for _, k := range expectedKeys {
        if _, ok := peftTensors[k]; !ok {
            t.Errorf("missing expected PEFT key: %s", k)
        }
    }

    // Verify shapes are transposed: lora_a (4,2) → (2,4), lora_b (2,4) → (4,2).
    loraAInfo := peftTensors["base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"]
    if loraAInfo.Shape[0] != 2 || loraAInfo.Shape[1] != 4 {
        t.Errorf("lora_A shape = %v, want [2, 4]", loraAInfo.Shape)
    }
}

func TestReadWriteSafetensorsRoundtrip(t *testing.T) {
    dir := t.TempDir()
    path := filepath.Join(dir, "test.safetensors")

    original := map[string]safetensorsTensorInfo{
        "weight_a": {Dtype: "F32", Shape: []int{2, 3}},
    }
    data := map[string][]byte{
        "weight_a": make([]byte, 2*3*4),
    }
    for i := 0; i < 6; i++ {
        binary.LittleEndian.PutUint32(data["weight_a"][i*4:], math.Float32bits(float32(i)))
    }

    if err := writeSafetensors(path, original, data); err != nil {
        t.Fatalf("write: %v", err)
    }

    readTensors, readData, err := readSafetensors(path)
    if err != nil {
        t.Fatalf("read: %v", err)
    }

    if len(readTensors) != 1 {
        t.Fatalf("expected 1 tensor, got %d", len(readTensors))
    }

    info := readTensors["weight_a"]
    if info.Dtype != "F32" {
        t.Errorf("dtype = %s, want F32", info.Dtype)
    }
    if info.Shape[0] != 2 || info.Shape[1] != 3 {
        t.Errorf("shape = %v, want [2, 3]", info.Shape)
    }

    got := getTensorData(info, readData)
    if len(got) != 24 {
        t.Errorf("data length = %d, want 24", len(got))
    }
}
pkg/lem/metrics.go — new file, 126 lines

package lem

import (
    "flag"
    "fmt"
    "log"
    "os"
    "time"
)

const targetTotal = 15000

// RunMetrics is the CLI entry point for the metrics command.
// Reads golden set stats from DuckDB and pushes them to InfluxDB as
// golden_set_stats, golden_set_domain, and golden_set_voice measurements.
func RunMetrics(args []string) {
    fs := flag.NewFlagSet("metrics", flag.ExitOnError)

    dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
    influxURL := fs.String("influx", "", "InfluxDB URL")
    influxDB := fs.String("influx-db", "", "InfluxDB database name")

    if err := fs.Parse(args); err != nil {
        log.Fatalf("parse flags: %v", err)
    }

    if *dbPath == "" {
        *dbPath = os.Getenv("LEM_DB")
    }
    if *dbPath == "" {
        fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required (path to DuckDB file)")
        os.Exit(1)
    }

    db, err := OpenDB(*dbPath)
    if err != nil {
        log.Fatalf("open db: %v", err)
    }
    defer db.Close()

    // Query overall stats.
    var total, domains, voices int
    var avgGenTime, avgChars float64

    err = db.conn.QueryRow(`
        SELECT count(*), count(DISTINCT domain), count(DISTINCT voice),
               coalesce(avg(gen_time), 0), coalesce(avg(char_count), 0)
        FROM golden_set
    `).Scan(&total, &domains, &voices, &avgGenTime, &avgChars)
    if err != nil {
        log.Fatalf("query golden_set stats: %v", err)
    }

    if total == 0 {
        fmt.Println("No golden set data in DuckDB.")
        return
    }

    nowNs := time.Now().UTC().UnixNano()
    pct := float64(total) / float64(targetTotal) * 100.0

    var lines []string

    // Overall stats measurement.
    lines = append(lines, fmt.Sprintf(
        "golden_set_stats total_examples=%di,domains=%di,voices=%di,avg_gen_time=%.2f,avg_response_chars=%.0f,completion_pct=%.1f %d",
        total, domains, voices, avgGenTime, avgChars, pct, nowNs,
    ))

    // Per-domain stats.
    domainRows, err := db.conn.Query(`
        SELECT domain, count(*) AS n, avg(gen_time) AS avg_t
        FROM golden_set GROUP BY domain
    `)
    if err != nil {
        log.Fatalf("query domains: %v", err)
    }
    domainCount := 0
    for domainRows.Next() {
        var domain string
        var n int
        var avgT float64
        if err := domainRows.Scan(&domain, &n, &avgT); err != nil {
            log.Fatalf("scan domain row: %v", err)
        }
        lines = append(lines, fmt.Sprintf(
            "golden_set_domain,domain=%s count=%di,avg_gen_time=%.2f %d",
            escapeLp(domain), n, avgT, nowNs,
        ))
        domainCount++
    }
    domainRows.Close()

    // Per-voice stats.
    voiceRows, err := db.conn.Query(`
        SELECT voice, count(*) AS n, avg(char_count) AS avg_c, avg(gen_time) AS avg_t
        FROM golden_set GROUP BY voice
    `)
    if err != nil {
        log.Fatalf("query voices: %v", err)
    }
    voiceCount := 0
    for voiceRows.Next() {
        var voice string
        var n int
        var avgC, avgT float64
        if err := voiceRows.Scan(&voice, &n, &avgC, &avgT); err != nil {
            log.Fatalf("scan voice row: %v", err)
        }
        lines = append(lines, fmt.Sprintf(
            "golden_set_voice,voice=%s count=%di,avg_chars=%.0f,avg_gen_time=%.2f %d",
            escapeLp(voice), n, avgC, avgT, nowNs,
        ))
        voiceCount++
    }
    voiceRows.Close()

    // Write to InfluxDB.
    influx := NewInfluxClient(*influxURL, *influxDB)
    if err := influx.WriteLp(lines); err != nil {
        log.Fatalf("write metrics: %v", err)
    }

    fmt.Printf("Wrote metrics to InfluxDB: %d examples, %d domains, %d voices (%d points)\n",
        total, domainCount, voiceCount, len(lines))
}
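For reference, the line-protocol points this command writes look like the following — the measurement, tag, and field names come from the format strings above, while the values and timestamp are invented for illustration:

```
golden_set_stats total_examples=12000i,domains=10i,voices=6i,avg_gen_time=3.42,avg_response_chars=2150,completion_pct=80.0 1760000000000000000
golden_set_domain,domain=ethics count=1200i,avg_gen_time=3.10 1760000000000000000
golden_set_voice,voice=mentor count=2000i,avg_chars=2100,avg_gen_time=3.35 1760000000000000000
```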
pkg/lem/parquet.go — new file, 162 lines

package lem

import (
    "bufio"
    "encoding/json"
    "flag"
    "fmt"
    "log"
    "os"
    "path/filepath"
    "strings"

    "github.com/parquet-go/parquet-go"
)

// ParquetRow is the schema for exported Parquet files.
type ParquetRow struct {
    Prompt   string `parquet:"prompt"`
    Response string `parquet:"response"`
    System   string `parquet:"system"`
    Messages string `parquet:"messages"`
}

// RunParquet is the CLI entry point for the parquet command.
// Reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl) and
// writes Parquet files with snappy compression for HuggingFace datasets.
func RunParquet(args []string) {
    fs := flag.NewFlagSet("parquet", flag.ExitOnError)

    trainingDir := fs.String("input", "", "Directory containing train.jsonl, valid.jsonl, test.jsonl (required)")
    outputDir := fs.String("output", "", "Output directory for Parquet files (defaults to input/parquet)")

    if err := fs.Parse(args); err != nil {
        log.Fatalf("parse flags: %v", err)
    }

    if *trainingDir == "" {
        fmt.Fprintln(os.Stderr, "error: --input is required (directory with JSONL splits)")
        fs.Usage()
        os.Exit(1)
    }

    if *outputDir == "" {
        *outputDir = filepath.Join(*trainingDir, "parquet")
    }

    if err := os.MkdirAll(*outputDir, 0755); err != nil {
        log.Fatalf("create output dir: %v", err)
    }

    fmt.Printf("Exporting Parquet from %s → %s\n", *trainingDir, *outputDir)

    total := 0
    for _, split := range []string{"train", "valid", "test"} {
        jsonlPath := filepath.Join(*trainingDir, split+".jsonl")
        if _, err := os.Stat(jsonlPath); os.IsNotExist(err) {
            fmt.Printf(" Skip: %s.jsonl not found\n", split)
            continue
        }

        n, err := exportSplitParquet(jsonlPath, *outputDir, split)
        if err != nil {
            log.Fatalf("export %s: %v", split, err)
        }
        total += n
    }

    fmt.Printf("\nTotal: %d rows exported\n", total)
}

// exportSplitParquet reads a JSONL file and writes a Parquet file for the split.
func exportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
    f, err := os.Open(jsonlPath)
    if err != nil {
        return 0, fmt.Errorf("open %s: %w", jsonlPath, err)
    }
    defer f.Close()

    var rows []ParquetRow
    scanner := bufio.NewScanner(f)
    scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

    for scanner.Scan() {
        text := strings.TrimSpace(scanner.Text())
        if text == "" {
            continue
        }

        var data struct {
            Messages []ChatMessage `json:"messages"`
        }
        if err := json.Unmarshal([]byte(text), &data); err != nil {
            continue
        }

        var prompt, response, system string
        for _, m := range data.Messages {
            switch m.Role {
            case "user":
                if prompt == "" {
                    prompt = m.Content
                }
            case "assistant":
                if response == "" {
                    response = m.Content
                }
            case "system":
                if system == "" {
                    system = m.Content
                }
            }
        }

        msgsJSON, _ := json.Marshal(data.Messages)
        rows = append(rows, ParquetRow{
            Prompt:   prompt,
            Response: response,
            System:   system,
            Messages: string(msgsJSON),
        })
    }

    if err := scanner.Err(); err != nil {
        return 0, fmt.Errorf("scan %s: %w", jsonlPath, err)
    }

    if len(rows) == 0 {
        fmt.Printf(" Skip: %s — no data\n", split)
        return 0, nil
    }

    outPath := filepath.Join(outputDir, split+".parquet")

    out, err := os.Create(outPath)
    if err != nil {
        return 0, fmt.Errorf("create %s: %w", outPath, err)
    }

    writer := parquet.NewGenericWriter[ParquetRow](out,
        parquet.Compression(&parquet.Snappy),
    )

    if _, err := writer.Write(rows); err != nil {
        out.Close()
        return 0, fmt.Errorf("write parquet rows: %w", err)
    }

    if err := writer.Close(); err != nil {
        out.Close()
        return 0, fmt.Errorf("close parquet writer: %w", err)
    }

    if err := out.Close(); err != nil {
        return 0, fmt.Errorf("close file: %w", err)
    }

    info, _ := os.Stat(outPath)
    sizeMB := float64(info.Size()) / 1024 / 1024
    fmt.Printf(" %s.parquet: %d rows (%.1f MB)\n", split, len(rows), sizeMB)

    return len(rows), nil
}
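Each input line is expected to be a chat-format record carrying a `messages` array; a single JSONL record of the shape exportSplitParquet parses might look like this (content invented, assuming ChatMessage marshals to role/content keys, as the tests below suggest):

```
{"messages":[{"role":"system","content":"You are helpful."},{"role":"user","content":"What is wisdom?"},{"role":"assistant","content":"The application of understanding."}]}
```

The first user, assistant, and system messages populate the prompt, response, and system columns, and the full array is preserved verbatim as JSON in the messages column.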
pkg/lem/parquet_test.go — new file, 143 lines

package lem

import (
    "encoding/json"
    "io"
    "os"
    "path/filepath"
    "testing"

    "github.com/parquet-go/parquet-go"
)

func TestExportSplitParquet(t *testing.T) {
    dir := t.TempDir()
    inputPath := filepath.Join(dir, "train.jsonl")
    outputDir := filepath.Join(dir, "output")
    os.MkdirAll(outputDir, 0755)

    // Write test JSONL.
    convs := []TrainingExample{
        {Messages: []ChatMessage{
            {Role: "user", Content: "What is wisdom?"},
            {Role: "assistant", Content: "The application of understanding."},
        }},
        {Messages: []ChatMessage{
            {Role: "system", Content: "You are helpful."},
            {Role: "user", Content: "Tell me about ethics."},
            {Role: "assistant", Content: "Ethics concerns right action."},
        }},
    }

    f, _ := os.Create(inputPath)
    for _, c := range convs {
        data, _ := json.Marshal(c)
        f.Write(data)
        f.WriteString("\n")
    }
    f.Close()

    n, err := exportSplitParquet(inputPath, outputDir, "train")
    if err != nil {
        t.Fatalf("export: %v", err)
    }
    if n != 2 {
        t.Errorf("expected 2 rows, got %d", n)
    }

    // Verify Parquet file exists and is readable.
    outPath := filepath.Join(outputDir, "train.parquet")
    pf, err := os.Open(outPath)
    if err != nil {
        t.Fatalf("open parquet: %v", err)
    }
    defer pf.Close()

    info, _ := pf.Stat()
    reader := parquet.NewGenericReader[ParquetRow](pf)
    defer reader.Close()

    rows := make([]ParquetRow, 10)
    read, err := reader.Read(rows)
    if err != nil && err != io.EOF {
        t.Fatalf("read parquet: %v", err)
    }
    if read != 2 {
        t.Errorf("expected 2 rows in parquet, got %d", read)
    }

    if rows[0].Prompt != "What is wisdom?" {
        t.Errorf("unexpected prompt: %s", rows[0].Prompt)
    }
    if rows[0].Response != "The application of understanding." {
        t.Errorf("unexpected response: %s", rows[0].Response)
    }
    if rows[1].System != "You are helpful." {
        t.Errorf("expected system message, got: %s", rows[1].System)
    }

    if info.Size() == 0 {
        t.Error("parquet file is empty")
    }
}

func TestExportSplitParquetEmpty(t *testing.T) {
    dir := t.TempDir()
    inputPath := filepath.Join(dir, "empty.jsonl")
    outputDir := filepath.Join(dir, "output")
    os.MkdirAll(outputDir, 0755)

    // Write empty JSONL.
    os.WriteFile(inputPath, []byte("\n\n"), 0644)

    n, err := exportSplitParquet(inputPath, outputDir, "test")
    if err != nil {
        t.Fatalf("export: %v", err)
    }
    if n != 0 {
        t.Errorf("expected 0 rows for empty file, got %d", n)
    }
}

func TestExportSplitParquetMessages(t *testing.T) {
    dir := t.TempDir()
    inputPath := filepath.Join(dir, "valid.jsonl")
    outputDir := filepath.Join(dir, "output")
    os.MkdirAll(outputDir, 0755)

    conv := TrainingExample{Messages: []ChatMessage{
        {Role: "user", Content: "hi"},
        {Role: "assistant", Content: "hello"},
    }}

    f, _ := os.Create(inputPath)
    data, _ := json.Marshal(conv)
    f.Write(data)
    f.WriteString("\n")
    f.Close()

    n, err := exportSplitParquet(inputPath, outputDir, "valid")
    if err != nil {
        t.Fatalf("export: %v", err)
    }
    if n != 1 {
        t.Errorf("expected 1 row, got %d", n)
    }

    // Verify messages field contains valid JSON.
    pf, _ := os.Open(filepath.Join(outputDir, "valid.parquet"))
    defer pf.Close()
    reader := parquet.NewGenericReader[ParquetRow](pf)
    defer reader.Close()

    rows := make([]ParquetRow, 1)
    reader.Read(rows)

    var msgs []ChatMessage
    if err := json.Unmarshal([]byte(rows[0].Messages), &msgs); err != nil {
        t.Fatalf("parse messages JSON: %v", err)
    }
    if len(msgs) != 2 {
        t.Errorf("expected 2 messages in JSON, got %d", len(msgs))
    }
}
pkg/lem/publish.go — new file, 138 lines

package lem

import (
    "bytes"
    "flag"
    "fmt"
    "io"
    "log"
    "net/http"
    "os"
    "path/filepath"
    "strings"
    "time"
)

// RunPublish is the CLI entry point for the publish command.
// Pushes Parquet files and an optional dataset card to HuggingFace.
func RunPublish(args []string) {
    fs := flag.NewFlagSet("publish", flag.ExitOnError)

    inputDir := fs.String("input", "", "Directory containing Parquet files (required)")
    repoID := fs.String("repo", "lthn/LEM-golden-set", "HuggingFace dataset repo ID")
    public := fs.Bool("public", false, "Make dataset public")
    token := fs.String("token", "", "HuggingFace API token (defaults to HF_TOKEN env)")
    dryRun := fs.Bool("dry-run", false, "Show what would be uploaded without uploading")

    if err := fs.Parse(args); err != nil {
        log.Fatalf("parse flags: %v", err)
    }

    if *inputDir == "" {
        fmt.Fprintln(os.Stderr, "error: --input is required (directory with Parquet files)")
        fs.Usage()
        os.Exit(1)
    }

    hfToken := *token
    if hfToken == "" {
        hfToken = os.Getenv("HF_TOKEN")
    }
    if hfToken == "" {
        home, err := os.UserHomeDir()
        if err == nil {
            data, err := os.ReadFile(filepath.Join(home, ".huggingface", "token"))
            if err == nil {
                hfToken = strings.TrimSpace(string(data))
            }
        }
    }

    if hfToken == "" && !*dryRun {
        fmt.Fprintln(os.Stderr, "error: HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)")
        os.Exit(1)
    }

    splits := []string{"train", "valid", "test"}
    type uploadEntry struct {
        local  string
        remote string
    }
    var filesToUpload []uploadEntry

    for _, split := range splits {
        path := filepath.Join(*inputDir, split+".parquet")
        if _, err := os.Stat(path); os.IsNotExist(err) {
            continue
        }
        filesToUpload = append(filesToUpload, uploadEntry{path, fmt.Sprintf("data/%s.parquet", split)})
    }

    // Check for dataset card in parent directory.
    cardPath := filepath.Join(*inputDir, "..", "dataset_card.md")
    if _, err := os.Stat(cardPath); err == nil {
        filesToUpload = append(filesToUpload, uploadEntry{cardPath, "README.md"})
    }

    if len(filesToUpload) == 0 {
        fmt.Fprintln(os.Stderr, "error: no Parquet files found in input directory")
        os.Exit(1)
    }

    if *dryRun {
        fmt.Printf("Dry run: would publish to %s\n", *repoID)
        if *public {
            fmt.Println(" Visibility: public")
        } else {
            fmt.Println(" Visibility: private")
        }
        for _, f := range filesToUpload {
            info, _ := os.Stat(f.local)
            sizeMB := float64(info.Size()) / 1024 / 1024
            fmt.Printf(" %s → %s (%.1f MB)\n", filepath.Base(f.local), f.remote, sizeMB)
        }
        return
    }

    fmt.Printf("Publishing to https://huggingface.co/datasets/%s\n", *repoID)

    for _, f := range filesToUpload {
        if err := uploadFileToHF(hfToken, *repoID, f.local, f.remote); err != nil {
            log.Fatalf("upload %s: %v", f.local, err)
        }
        fmt.Printf(" Uploaded %s → %s\n", filepath.Base(f.local), f.remote)
    }

    fmt.Printf("\nPublished to https://huggingface.co/datasets/%s\n", *repoID)
}

// uploadFileToHF uploads a file to a HuggingFace dataset repo via the Hub API.
func uploadFileToHF(token, repoID, localPath, remotePath string) error {
    data, err := os.ReadFile(localPath)
    if err != nil {
        return fmt.Errorf("read %s: %w", localPath, err)
    }

    url := fmt.Sprintf("https://huggingface.co/api/datasets/%s/upload/main/%s", repoID, remotePath)

    req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data))
    if err != nil {
        return fmt.Errorf("create request: %w", err)
    }
    req.Header.Set("Authorization", "Bearer "+token)
    req.Header.Set("Content-Type", "application/octet-stream")

    client := &http.Client{Timeout: 120 * time.Second}
    resp, err := client.Do(req)
    if err != nil {
        return fmt.Errorf("upload request: %w", err)
    }
    defer resp.Body.Close()

    if resp.StatusCode >= 300 {
        body, _ := io.ReadAll(resp.Body)
        return fmt.Errorf("upload failed: HTTP %d: %s", resp.StatusCode, string(body))
    }

    return nil
}
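A dry run prints the upload plan without touching the network; given the Printf calls above, the output would look roughly like this (file sizes hypothetical):

```
Dry run: would publish to lthn/LEM-golden-set
 Visibility: private
 train.parquet → data/train.parquet (12.4 MB)
 test.parquet → data/test.parquet (0.8 MB)
```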