
feat: add parquet, publish, metrics, convert commands

- `lem parquet` — export JSONL training splits to Parquet (parquet-go)
- `lem publish` — push Parquet files to HuggingFace dataset repo
- `lem metrics` — push DuckDB golden set stats to InfluxDB
- `lem convert` — MLX LoRA adapter → HuggingFace PEFT format
  (pure Go safetensors read/write/transpose, no PyTorch needed)

Dependencies added: parquet-go, go-huggingface (plus transitive Arrow, Thrift, and compression modules)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Claude committed 2026-02-15 17:05:08 +00:00
parent 0afa5e9147, commit 4eaf1bfb39
GPG key ID: AF404715446AEB41 (no known key found for this signature in database)
9 changed files with 1187 additions and 7 deletions

go.mod (14 changes)

@@ -3,21 +3,29 @@ module forge.lthn.ai/lthn/lem
go 1.25.6
require (
github.com/marcboeker/go-duckdb v1.8.5
github.com/parquet-go/parquet-go v0.27.0
)
require (
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/apache/arrow-go/v18 v18.1.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/google/flatbuffers v25.1.24+incompatible // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hupe1980/go-huggingface v0.0.15 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/marcboeker/go-duckdb v1.8.5 // indirect
github.com/parquet-go/bitpack v1.0.0 // indirect
github.com/parquet-go/jsonlite v1.0.0 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
github.com/twpayne/go-geom v1.6.1 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect
golang.org/x/mod v0.22.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/tools v0.29.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
google.golang.org/protobuf v1.36.1 // indirect
)

go.sum (52 changes)

@@ -1,23 +1,61 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY=
github.com/alecthomas/assert/v2 v2.10.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/apache/arrow-go/v18 v18.1.0 h1:agLwJUiVuwXZdwPYVrlITfx7bndULJ/dggbnLFgDp/Y=
github.com/apache/arrow-go/v18 v18.1.0/go.mod h1:tigU/sIgKNXaesf5d7Y95jBBKS5KsxTqYBKXFsvKzo0=
github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE=
github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v25.1.24+incompatible h1:4wPqL3K7GzBd1CwyhSd3usxLKOaJN/AC6puCca6Jm7o=
github.com/google/flatbuffers v25.1.24+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hupe1980/go-huggingface v0.0.15 h1:tTWmUGGunC/BYz4hrwS8SSVtMYVYjceG2uhL8HxeXvw=
github.com/hupe1980/go-huggingface v0.0.15/go.mod h1:IRvsik3+b9BJyw9hCfw1arI6gDObcVto1UA8f3kt8mM=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
github.com/marcboeker/go-duckdb v1.8.5 h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0=
github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA=
github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs=
github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU=
github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0=
github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g=
github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4=
github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc=
@@ -26,9 +64,15 @@ golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE=
golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0=
gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o=
google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

main.go (12 changes)

@@ -21,6 +21,10 @@ Commands:
expand Generate expansion responses via trained LEM model
conv Generate conversational training data
ingest Ingest benchmark data into InfluxDB
parquet Export JSONL training splits to Parquet for HuggingFace
publish Push Parquet files to HuggingFace dataset repo
metrics Push DuckDB golden set stats to InfluxDB
convert Convert MLX LoRA adapter to HuggingFace PEFT format
`
func main() {
@@ -46,6 +50,14 @@ func main() {
lem.RunConv(os.Args[2:])
case "ingest":
lem.RunIngest(os.Args[2:])
case "parquet":
lem.RunParquet(os.Args[2:])
case "publish":
lem.RunPublish(os.Args[2:])
case "metrics":
lem.RunMetrics(os.Args[2:])
case "convert":
lem.RunConvert(os.Args[2:])
default:
fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
os.Exit(1)

pkg/lem/convert.go (new file, 349 lines)

@@ -0,0 +1,349 @@
package lem
import (
"encoding/binary"
"encoding/json"
"flag"
"fmt"
"log"
"math"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
)
// RunConvert is the CLI entry point for the convert command.
// Converts MLX LoRA adapters to HuggingFace PEFT format:
// - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
// - Transpose: MLX (in, rank) → PEFT (rank, in)
// - Config generation: adapter_config.json with lora_alpha = scale × rank
func RunConvert(args []string) {
fs := flag.NewFlagSet("convert", flag.ExitOnError)
safetensorsPath := fs.String("input", "", "Path to MLX .safetensors file (required)")
configPath := fs.String("config", "", "Path to MLX adapter_config.json (required)")
outputDir := fs.String("output", "./peft_output", "Output directory for PEFT adapter")
baseModel := fs.String("base-model", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID")
if err := fs.Parse(args); err != nil {
log.Fatalf("parse flags: %v", err)
}
if *safetensorsPath == "" || *configPath == "" {
fmt.Fprintln(os.Stderr, "error: --input and --config are required")
fs.Usage()
os.Exit(1)
}
if err := convertMLXtoPEFT(*safetensorsPath, *configPath, *outputDir, *baseModel); err != nil {
log.Fatalf("convert: %v", err)
}
fmt.Printf("Converted to: %s\n", *outputDir)
}
var (
loraARe = regexp.MustCompile(`\.lora_a$`)
loraBRe = regexp.MustCompile(`\.lora_b$`)
layerRe = regexp.MustCompile(`layers\.(\d+)`)
moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
)
// renameMLXKey converts an MLX tensor key to PEFT format.
func renameMLXKey(mlxKey string) string {
key := mlxKey
key = loraARe.ReplaceAllString(key, ".lora_A.default.weight")
key = loraBRe.ReplaceAllString(key, ".lora_B.default.weight")
key = "base_model.model." + key
return key
}
// safetensorsHeader represents the header of a safetensors file.
type safetensorsHeader struct {
Metadata map[string]string `json:"__metadata__,omitempty"`
Tensors map[string]safetensorsTensorInfo `json:"-"`
}
type safetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// readSafetensors reads a safetensors file and returns tensor name→data+info pairs.
func readSafetensors(path string) (map[string]safetensorsTensorInfo, []byte, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, nil, fmt.Errorf("read file: %w", err)
}
if len(data) < 8 {
return nil, nil, fmt.Errorf("file too small")
}
headerSize := int(binary.LittleEndian.Uint64(data[:8]))
if 8+headerSize > len(data) {
return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
}
headerJSON := data[8 : 8+headerSize]
tensorData := data[8+headerSize:]
// Parse header as a generic map since tensors are top-level keys.
var rawHeader map[string]json.RawMessage
if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
return nil, nil, fmt.Errorf("parse header: %w", err)
}
tensors := make(map[string]safetensorsTensorInfo)
for key, raw := range rawHeader {
if key == "__metadata__" {
continue
}
var info safetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
}
tensors[key] = info
}
return tensors, tensorData, nil
}
// getTensorData extracts raw bytes for a tensor from the data section.
func getTensorData(info safetensorsTensorInfo, allData []byte) []byte {
return allData[info.DataOffsets[0]:info.DataOffsets[1]]
}
// transposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
func transposeFloat32(data []byte, rows, cols int) []byte {
if len(data) != rows*cols*4 {
return data // size mismatch, return as-is
}
result := make([]byte, len(data))
for r := 0; r < rows; r++ {
for c := 0; c < cols; c++ {
srcOff := (r*cols + c) * 4
dstOff := (c*rows + r) * 4
copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4])
}
}
return result
}
// transposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
func transposeFloat16(data []byte, rows, cols int) []byte {
if len(data) != rows*cols*2 {
return data
}
result := make([]byte, len(data))
for r := 0; r < rows; r++ {
for c := 0; c < cols; c++ {
srcOff := (r*cols + c) * 2
dstOff := (c*rows + r) * 2
copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2])
}
}
return result
}
// transposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
func transposeBFloat16(data []byte, rows, cols int) []byte {
return transposeFloat16(data, rows, cols) // same element size
}
// writeSafetensors writes tensors to a safetensors file.
func writeSafetensors(path string, tensors map[string]safetensorsTensorInfo, tensorData map[string][]byte) error {
// Sort keys for deterministic output.
keys := make([]string, 0, len(tensors))
for k := range tensors {
keys = append(keys, k)
}
sort.Strings(keys)
// Compute offsets.
offset := 0
updatedTensors := make(map[string]safetensorsTensorInfo)
for _, k := range keys {
info := tensors[k]
data := tensorData[k]
info.DataOffsets = [2]int{offset, offset + len(data)}
updatedTensors[k] = info
offset += len(data)
}
// Build header JSON.
headerMap := make(map[string]interface{})
for k, info := range updatedTensors {
headerMap[k] = info
}
headerJSON, err := json.Marshal(headerMap)
if err != nil {
return fmt.Errorf("marshal header: %w", err)
}
// Write file: 8-byte header size + header JSON + tensor data.
f, err := os.Create(path)
if err != nil {
return fmt.Errorf("create %s: %w", path, err)
}
defer f.Close()
headerSizeBuf := make([]byte, 8)
binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))
if _, err := f.Write(headerSizeBuf); err != nil {
return err
}
if _, err := f.Write(headerJSON); err != nil {
return err
}
for _, k := range keys {
if _, err := f.Write(tensorData[k]); err != nil {
return err
}
}
return nil
}
// convertMLXtoPEFT converts an MLX LoRA adapter to PEFT format.
func convertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
if err := os.MkdirAll(outputDir, 0755); err != nil {
return fmt.Errorf("create output dir: %w", err)
}
// Read MLX tensors.
tensors, tensorData, err := readSafetensors(safetensorsPath)
if err != nil {
return fmt.Errorf("read safetensors: %w", err)
}
log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
// Rename and transpose tensors.
peftTensors := make(map[string]safetensorsTensorInfo)
peftData := make(map[string][]byte)
for mlxKey, info := range tensors {
peftKey := renameMLXKey(mlxKey)
data := getTensorData(info, tensorData)
// Transpose: swap shape and transpose data.
if len(info.Shape) == 2 {
rows, cols := info.Shape[0], info.Shape[1]
switch info.Dtype {
case "F32":
data = transposeFloat32(data, rows, cols)
case "F16":
data = transposeFloat16(data, rows, cols)
case "BF16":
data = transposeBFloat16(data, rows, cols)
}
info.Shape = []int{cols, rows}
}
peftTensors[peftKey] = info
peftData[peftKey] = data
}
// Write PEFT safetensors.
outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
if err := writeSafetensors(outSafetensors, peftTensors, peftData); err != nil {
return fmt.Errorf("write safetensors: %w", err)
}
// Read MLX config for LoRA parameters.
cfgData, err := os.ReadFile(configPath)
if err != nil {
return fmt.Errorf("read config: %w", err)
}
var mlxConfig struct {
LoraParameters struct {
Rank int `json:"rank"`
Scale float64 `json:"scale"`
Dropout float64 `json:"dropout"`
} `json:"lora_parameters"`
}
if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
return fmt.Errorf("parse config: %w", err)
}
rank := mlxConfig.LoraParameters.Rank
if rank == 0 {
rank = 8
}
scale := mlxConfig.LoraParameters.Scale
if scale == 0 {
scale = 20.0
}
// Determine target modules from tensor keys.
modules := make(map[string]bool)
layers := make(map[int]bool)
for k := range tensors {
if m := moduleRe.FindStringSubmatch(k); m != nil {
parts := strings.Split(m[1], ".")
modules[parts[len(parts)-1]] = true
}
if m := layerRe.FindStringSubmatch(k); m != nil {
n, _ := strconv.Atoi(m[1])
layers[n] = true
}
}
sortedModules := make([]string, 0, len(modules))
for m := range modules {
sortedModules = append(sortedModules, m)
}
sort.Strings(sortedModules)
sortedLayers := make([]int, 0, len(layers))
for l := range layers {
sortedLayers = append(sortedLayers, l)
}
sort.Ints(sortedLayers)
// Write PEFT adapter_config.json.
peftConfig := map[string]interface{}{
"auto_mapping": nil,
"base_model_name_or_path": baseModelName,
"bias": "none",
"fan_in_fan_out": false,
"inference_mode": true,
"init_lora_weights": true,
"layers_pattern": nil,
"layers_to_transform": sortedLayers,
"lora_alpha": math.Round(scale * float64(rank)),
"lora_dropout": mlxConfig.LoraParameters.Dropout,
"modules_to_save": nil,
"peft_type": "LORA",
"r": rank,
"revision": nil,
"target_modules": sortedModules,
"task_type": "CAUSAL_LM",
}
cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
if err != nil {
return fmt.Errorf("marshal peft config: %w", err)
}
if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
return fmt.Errorf("write adapter_config.json: %w", err)
}
log.Printf("converted %d tensors, %d layers, target modules: %v",
len(peftTensors), len(sortedLayers), sortedModules)
return nil
}
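
A minimal sketch of the safetensors layout the reader and writer above assume: an 8-byte little-endian header length, a JSON header, then raw tensor bytes. The input path here is hypothetical.

```go
package main

import (
	"encoding/binary"
	"fmt"
	"os"
)

// Print the JSON header of a safetensors file, using the same
// 8-byte little-endian length prefix that readSafetensors parses.
func main() {
	data, err := os.ReadFile(os.Args[1]) // e.g. adapters.safetensors
	if err != nil || len(data) < 8 {
		panic("unreadable or truncated safetensors file")
	}
	n := binary.LittleEndian.Uint64(data[:8])
	fmt.Println(string(data[8 : 8+n]))
}
```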

pkg/lem/convert_test.go (new file, 198 lines)

@@ -0,0 +1,198 @@
package lem
import (
"encoding/binary"
"encoding/json"
"math"
"os"
"path/filepath"
"testing"
)
func TestRenameMLXKey(t *testing.T) {
tests := []struct {
input string
want string
}{
{
"model.layers.12.self_attn.q_proj.lora_a",
"base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight",
},
{
"model.layers.0.self_attn.v_proj.lora_b",
"base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight",
},
{
"model.layers.5.mlp.gate_proj.lora_a",
"base_model.model.model.layers.5.mlp.gate_proj.lora_A.default.weight",
},
}
for _, tt := range tests {
got := renameMLXKey(tt.input)
if got != tt.want {
t.Errorf("renameMLXKey(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestTransposeFloat32(t *testing.T) {
// 2x3 matrix: [[1, 2, 3], [4, 5, 6]]
data := make([]byte, 2*3*4)
for i, v := range []float32{1, 2, 3, 4, 5, 6} {
binary.LittleEndian.PutUint32(data[i*4:], math.Float32bits(v))
}
result := transposeFloat32(data, 2, 3)
// Expected: 3x2 matrix: [[1, 4], [2, 5], [3, 6]]
expected := []float32{1, 4, 2, 5, 3, 6}
for i, want := range expected {
got := math.Float32frombits(binary.LittleEndian.Uint32(result[i*4:]))
if got != want {
t.Errorf("result[%d] = %f, want %f", i, got, want)
}
}
}
func TestConvertMLXtoPEFT(t *testing.T) {
dir := t.TempDir()
// Create a minimal MLX safetensors file with one lora_a and one lora_b tensor.
// Shape: lora_a is (in=4, rank=2), lora_b is (rank=2, out=4)
tensors := map[string]safetensorsTensorInfo{
"model.layers.0.self_attn.q_proj.lora_a": {Dtype: "F32", Shape: []int{4, 2}},
"model.layers.0.self_attn.q_proj.lora_b": {Dtype: "F32", Shape: []int{2, 4}},
}
// Create tensor data: 4x2=8 floats and 2x4=8 floats.
loraAData := make([]byte, 4*2*4)
for i := 0; i < 8; i++ {
binary.LittleEndian.PutUint32(loraAData[i*4:], math.Float32bits(float32(i+1)))
}
loraBData := make([]byte, 2*4*4)
for i := 0; i < 8; i++ {
binary.LittleEndian.PutUint32(loraBData[i*4:], math.Float32bits(float32(10+i)))
}
tensorData := make(map[string][]byte)
tensorData["model.layers.0.self_attn.q_proj.lora_a"] = loraAData
tensorData["model.layers.0.self_attn.q_proj.lora_b"] = loraBData
sfPath := filepath.Join(dir, "adapters.safetensors")
if err := writeSafetensors(sfPath, tensors, tensorData); err != nil {
t.Fatalf("write test safetensors: %v", err)
}
// Create MLX config.
mlxConfig := map[string]interface{}{
"lora_parameters": map[string]interface{}{
"rank": 8,
"scale": 20.0,
"dropout": 0.0,
},
}
cfgData, _ := json.Marshal(mlxConfig)
cfgPath := filepath.Join(dir, "adapter_config.json")
os.WriteFile(cfgPath, cfgData, 0644)
// Convert.
outputDir := filepath.Join(dir, "peft_output")
if err := convertMLXtoPEFT(sfPath, cfgPath, outputDir, "test-model"); err != nil {
t.Fatalf("convert: %v", err)
}
// Check output files exist.
if _, err := os.Stat(filepath.Join(outputDir, "adapter_model.safetensors")); err != nil {
t.Error("missing adapter_model.safetensors")
}
if _, err := os.Stat(filepath.Join(outputDir, "adapter_config.json")); err != nil {
t.Error("missing adapter_config.json")
}
// Read and verify PEFT config.
peftCfgData, err := os.ReadFile(filepath.Join(outputDir, "adapter_config.json"))
if err != nil {
t.Fatalf("read peft config: %v", err)
}
var peftConfig map[string]interface{}
if err := json.Unmarshal(peftCfgData, &peftConfig); err != nil {
t.Fatalf("parse peft config: %v", err)
}
if peftConfig["peft_type"] != "LORA" {
t.Errorf("peft_type = %v, want LORA", peftConfig["peft_type"])
}
if peftConfig["base_model_name_or_path"] != "test-model" {
t.Errorf("base_model = %v, want test-model", peftConfig["base_model_name_or_path"])
}
// Check that lora_alpha = scale * rank = 20 * 8 = 160.
if alpha, ok := peftConfig["lora_alpha"].(float64); !ok || alpha != 160 {
t.Errorf("lora_alpha = %v, want 160", peftConfig["lora_alpha"])
}
// Verify converted safetensors has PEFT-format keys.
peftTensors, _, err := readSafetensors(filepath.Join(outputDir, "adapter_model.safetensors"))
if err != nil {
t.Fatalf("read peft safetensors: %v", err)
}
expectedKeys := []string{
"base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
"base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight",
}
for _, k := range expectedKeys {
if _, ok := peftTensors[k]; !ok {
t.Errorf("missing expected PEFT key: %s", k)
}
}
// Verify shapes are transposed: lora_a (4,2) → (2,4), lora_b (2,4) → (4,2).
loraAInfo := peftTensors["base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"]
if loraAInfo.Shape[0] != 2 || loraAInfo.Shape[1] != 4 {
t.Errorf("lora_A shape = %v, want [2, 4]", loraAInfo.Shape)
}
}
func TestReadWriteSafetensorsRoundtrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.safetensors")
original := map[string]safetensorsTensorInfo{
"weight_a": {Dtype: "F32", Shape: []int{2, 3}},
}
data := map[string][]byte{
"weight_a": make([]byte, 2*3*4),
}
for i := 0; i < 6; i++ {
binary.LittleEndian.PutUint32(data["weight_a"][i*4:], math.Float32bits(float32(i)))
}
if err := writeSafetensors(path, original, data); err != nil {
t.Fatalf("write: %v", err)
}
readTensors, readData, err := readSafetensors(path)
if err != nil {
t.Fatalf("read: %v", err)
}
if len(readTensors) != 1 {
t.Fatalf("expected 1 tensor, got %d", len(readTensors))
}
info := readTensors["weight_a"]
if info.Dtype != "F32" {
t.Errorf("dtype = %s, want F32", info.Dtype)
}
if info.Shape[0] != 2 || info.Shape[1] != 3 {
t.Errorf("shape = %v, want [2, 3]", info.Shape)
}
got := getTensorData(info, readData)
if len(got) != 24 {
t.Errorf("data length = %d, want 24", len(got))
}
}

pkg/lem/metrics.go (new file, 126 lines)

@@ -0,0 +1,126 @@
package lem
import (
"flag"
"fmt"
"log"
"os"
"time"
)
const targetTotal = 15000
// RunMetrics is the CLI entry point for the metrics command.
// Reads golden set stats from DuckDB and pushes them to InfluxDB as
// golden_set_stats, golden_set_domain, and golden_set_voice measurements.
func RunMetrics(args []string) {
fs := flag.NewFlagSet("metrics", flag.ExitOnError)
dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
influxURL := fs.String("influx", "", "InfluxDB URL")
influxDB := fs.String("influx-db", "", "InfluxDB database name")
if err := fs.Parse(args); err != nil {
log.Fatalf("parse flags: %v", err)
}
if *dbPath == "" {
*dbPath = os.Getenv("LEM_DB")
}
if *dbPath == "" {
fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required (path to DuckDB file)")
os.Exit(1)
}
db, err := OpenDB(*dbPath)
if err != nil {
log.Fatalf("open db: %v", err)
}
defer db.Close()
// Query overall stats.
var total, domains, voices int
var avgGenTime, avgChars float64
err = db.conn.QueryRow(`
SELECT count(*), count(DISTINCT domain), count(DISTINCT voice),
coalesce(avg(gen_time), 0), coalesce(avg(char_count), 0)
FROM golden_set
`).Scan(&total, &domains, &voices, &avgGenTime, &avgChars)
if err != nil {
log.Fatalf("query golden_set stats: %v", err)
}
if total == 0 {
fmt.Println("No golden set data in DuckDB.")
return
}
nowNs := time.Now().UTC().UnixNano()
pct := float64(total) / float64(targetTotal) * 100.0
var lines []string
// Overall stats measurement.
lines = append(lines, fmt.Sprintf(
"golden_set_stats total_examples=%di,domains=%di,voices=%di,avg_gen_time=%.2f,avg_response_chars=%.0f,completion_pct=%.1f %d",
total, domains, voices, avgGenTime, avgChars, pct, nowNs,
))
// Per-domain stats.
domainRows, err := db.conn.Query(`
SELECT domain, count(*) AS n, avg(gen_time) AS avg_t
FROM golden_set GROUP BY domain
`)
if err != nil {
log.Fatalf("query domains: %v", err)
}
domainCount := 0
for domainRows.Next() {
var domain string
var n int
var avgT float64
if err := domainRows.Scan(&domain, &n, &avgT); err != nil {
log.Fatalf("scan domain row: %v", err)
}
lines = append(lines, fmt.Sprintf(
"golden_set_domain,domain=%s count=%di,avg_gen_time=%.2f %d",
escapeLp(domain), n, avgT, nowNs,
))
domainCount++
}
domainRows.Close()
// Per-voice stats.
voiceRows, err := db.conn.Query(`
SELECT voice, count(*) AS n, avg(char_count) AS avg_c, avg(gen_time) AS avg_t
FROM golden_set GROUP BY voice
`)
if err != nil {
log.Fatalf("query voices: %v", err)
}
voiceCount := 0
for voiceRows.Next() {
var voice string
var n int
var avgC, avgT float64
if err := voiceRows.Scan(&voice, &n, &avgC, &avgT); err != nil {
log.Fatalf("scan voice row: %v", err)
}
lines = append(lines, fmt.Sprintf(
"golden_set_voice,voice=%s count=%di,avg_chars=%.0f,avg_gen_time=%.2f %d",
escapeLp(voice), n, avgC, avgT, nowNs,
))
voiceCount++
}
voiceRows.Close()
// Write to InfluxDB.
influx := NewInfluxClient(*influxURL, *influxDB)
if err := influx.WriteLp(lines); err != nil {
log.Fatalf("write metrics: %v", err)
}
fmt.Printf("Wrote metrics to InfluxDB: %d examples, %d domains, %d voices (%d points)\n",
total, domainCount, voiceCount, len(lines))
}
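
metrics.go leans on an escapeLp helper defined elsewhere in the package. A hedged sketch of what such a helper presumably does, since InfluxDB line protocol requires commas, spaces, and equals signs in tag values to be backslash-escaped:

```go
package main

import (
	"fmt"
	"strings"
)

// Assumed behaviour of an escapeLp-style helper (the real implementation
// lives elsewhere in pkg/lem): escape the characters that would otherwise
// terminate a tag value in InfluxDB line protocol.
func escapeLp(s string) string {
	return strings.NewReplacer(",", `\,`, " ", `\ `, "=", `\=`).Replace(s)
}

func main() {
	fmt.Println(escapeLp("sci fi,long=form")) // sci\ fi\,long\=form
}
```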

pkg/lem/parquet.go (new file, 162 lines)

@@ -0,0 +1,162 @@
package lem
import (
"bufio"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"strings"
"github.com/parquet-go/parquet-go"
)
// ParquetRow is the schema for exported Parquet files.
type ParquetRow struct {
Prompt string `parquet:"prompt"`
Response string `parquet:"response"`
System string `parquet:"system"`
Messages string `parquet:"messages"`
}
// RunParquet is the CLI entry point for the parquet command.
// Reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl) and
// writes Parquet files with snappy compression for HuggingFace datasets.
func RunParquet(args []string) {
fs := flag.NewFlagSet("parquet", flag.ExitOnError)
trainingDir := fs.String("input", "", "Directory containing train.jsonl, valid.jsonl, test.jsonl (required)")
outputDir := fs.String("output", "", "Output directory for Parquet files (defaults to input/parquet)")
if err := fs.Parse(args); err != nil {
log.Fatalf("parse flags: %v", err)
}
if *trainingDir == "" {
fmt.Fprintln(os.Stderr, "error: --input is required (directory with JSONL splits)")
fs.Usage()
os.Exit(1)
}
if *outputDir == "" {
*outputDir = filepath.Join(*trainingDir, "parquet")
}
if err := os.MkdirAll(*outputDir, 0755); err != nil {
log.Fatalf("create output dir: %v", err)
}
fmt.Printf("Exporting Parquet from %s → %s\n", *trainingDir, *outputDir)
total := 0
for _, split := range []string{"train", "valid", "test"} {
jsonlPath := filepath.Join(*trainingDir, split+".jsonl")
if _, err := os.Stat(jsonlPath); os.IsNotExist(err) {
fmt.Printf(" Skip: %s.jsonl not found\n", split)
continue
}
n, err := exportSplitParquet(jsonlPath, *outputDir, split)
if err != nil {
log.Fatalf("export %s: %v", split, err)
}
total += n
}
fmt.Printf("\nTotal: %d rows exported\n", total)
}
// exportSplitParquet reads a JSONL file and writes a Parquet file for the split.
func exportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
f, err := os.Open(jsonlPath)
if err != nil {
return 0, fmt.Errorf("open %s: %w", jsonlPath, err)
}
defer f.Close()
var rows []ParquetRow
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
for scanner.Scan() {
text := strings.TrimSpace(scanner.Text())
if text == "" {
continue
}
var data struct {
Messages []ChatMessage `json:"messages"`
}
if err := json.Unmarshal([]byte(text), &data); err != nil {
continue
}
var prompt, response, system string
for _, m := range data.Messages {
switch m.Role {
case "user":
if prompt == "" {
prompt = m.Content
}
case "assistant":
if response == "" {
response = m.Content
}
case "system":
if system == "" {
system = m.Content
}
}
}
msgsJSON, _ := json.Marshal(data.Messages)
rows = append(rows, ParquetRow{
Prompt: prompt,
Response: response,
System: system,
Messages: string(msgsJSON),
})
}
if err := scanner.Err(); err != nil {
return 0, fmt.Errorf("scan %s: %w", jsonlPath, err)
}
if len(rows) == 0 {
fmt.Printf(" Skip: %s — no data\n", split)
return 0, nil
}
outPath := filepath.Join(outputDir, split+".parquet")
out, err := os.Create(outPath)
if err != nil {
return 0, fmt.Errorf("create %s: %w", outPath, err)
}
writer := parquet.NewGenericWriter[ParquetRow](out,
parquet.Compression(&parquet.Snappy),
)
if _, err := writer.Write(rows); err != nil {
out.Close()
return 0, fmt.Errorf("write parquet rows: %w", err)
}
if err := writer.Close(); err != nil {
out.Close()
return 0, fmt.Errorf("close parquet writer: %w", err)
}
if err := out.Close(); err != nil {
return 0, fmt.Errorf("close file: %w", err)
}
info, _ := os.Stat(outPath)
sizeMB := float64(info.Size()) / 1024 / 1024
fmt.Printf(" %s.parquet: %d rows (%.1f MB)\n", split, len(rows), sizeMB)
return len(rows), nil
}
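
For reference, the per-line JSONL shape exportSplitParquet consumes is a single "messages" array of role/content pairs. A self-contained sketch with an illustrative sample line (field names mirror the ChatMessage struct the exporter uses):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Parse one illustrative JSONL line the same way exportSplitParquet does.
func main() {
	line := `{"messages":[{"role":"user","content":"What is wisdom?"},{"role":"assistant","content":"The application of understanding."}]}`
	var data struct {
		Messages []struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"messages"`
	}
	if err := json.Unmarshal([]byte(line), &data); err != nil {
		panic(err)
	}
	fmt.Println(len(data.Messages), data.Messages[0].Role) // 2 user
}
```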

pkg/lem/parquet_test.go (new file, 143 lines)

@@ -0,0 +1,143 @@
package lem
import (
"encoding/json"
"io"
"os"
"path/filepath"
"testing"
"github.com/parquet-go/parquet-go"
)
func TestExportSplitParquet(t *testing.T) {
dir := t.TempDir()
inputPath := filepath.Join(dir, "train.jsonl")
outputDir := filepath.Join(dir, "output")
os.MkdirAll(outputDir, 0755)
// Write test JSONL.
convs := []TrainingExample{
{Messages: []ChatMessage{
{Role: "user", Content: "What is wisdom?"},
{Role: "assistant", Content: "The application of understanding."},
}},
{Messages: []ChatMessage{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Tell me about ethics."},
{Role: "assistant", Content: "Ethics concerns right action."},
}},
}
f, _ := os.Create(inputPath)
for _, c := range convs {
data, _ := json.Marshal(c)
f.Write(data)
f.WriteString("\n")
}
f.Close()
n, err := exportSplitParquet(inputPath, outputDir, "train")
if err != nil {
t.Fatalf("export: %v", err)
}
if n != 2 {
t.Errorf("expected 2 rows, got %d", n)
}
// Verify Parquet file exists and is readable.
outPath := filepath.Join(outputDir, "train.parquet")
pf, err := os.Open(outPath)
if err != nil {
t.Fatalf("open parquet: %v", err)
}
defer pf.Close()
info, _ := pf.Stat()
reader := parquet.NewGenericReader[ParquetRow](pf)
defer reader.Close()
rows := make([]ParquetRow, 10)
read, err := reader.Read(rows)
if err != nil && err != io.EOF {
t.Fatalf("read parquet: %v", err)
}
if read != 2 {
t.Errorf("expected 2 rows in parquet, got %d", read)
}
if rows[0].Prompt != "What is wisdom?" {
t.Errorf("unexpected prompt: %s", rows[0].Prompt)
}
if rows[0].Response != "The application of understanding." {
t.Errorf("unexpected response: %s", rows[0].Response)
}
if rows[1].System != "You are helpful." {
t.Errorf("expected system message, got: %s", rows[1].System)
}
if info.Size() == 0 {
t.Error("parquet file is empty")
}
}
func TestExportSplitParquetEmpty(t *testing.T) {
dir := t.TempDir()
inputPath := filepath.Join(dir, "empty.jsonl")
outputDir := filepath.Join(dir, "output")
os.MkdirAll(outputDir, 0755)
// Write empty JSONL.
os.WriteFile(inputPath, []byte("\n\n"), 0644)
n, err := exportSplitParquet(inputPath, outputDir, "test")
if err != nil {
t.Fatalf("export: %v", err)
}
if n != 0 {
t.Errorf("expected 0 rows for empty file, got %d", n)
}
}
func TestExportSplitParquetMessages(t *testing.T) {
dir := t.TempDir()
inputPath := filepath.Join(dir, "valid.jsonl")
outputDir := filepath.Join(dir, "output")
os.MkdirAll(outputDir, 0755)
conv := TrainingExample{Messages: []ChatMessage{
{Role: "user", Content: "hi"},
{Role: "assistant", Content: "hello"},
}}
f, _ := os.Create(inputPath)
data, _ := json.Marshal(conv)
f.Write(data)
f.WriteString("\n")
f.Close()
n, err := exportSplitParquet(inputPath, outputDir, "valid")
if err != nil {
t.Fatalf("export: %v", err)
}
if n != 1 {
t.Errorf("expected 1 row, got %d", n)
}
// Verify messages field contains valid JSON.
pf, _ := os.Open(filepath.Join(outputDir, "valid.parquet"))
defer pf.Close()
reader := parquet.NewGenericReader[ParquetRow](pf)
defer reader.Close()
rows := make([]ParquetRow, 1)
reader.Read(rows)
var msgs []ChatMessage
if err := json.Unmarshal([]byte(rows[0].Messages), &msgs); err != nil {
t.Fatalf("parse messages JSON: %v", err)
}
if len(msgs) != 2 {
t.Errorf("expected 2 messages in JSON, got %d", len(msgs))
}
}

pkg/lem/publish.go (new file, 138 lines)

@@ -0,0 +1,138 @@
package lem
import (
"bytes"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// RunPublish is the CLI entry point for the publish command.
// Pushes Parquet files and an optional dataset card to HuggingFace.
func RunPublish(args []string) {
fs := flag.NewFlagSet("publish", flag.ExitOnError)
inputDir := fs.String("input", "", "Directory containing Parquet files (required)")
repoID := fs.String("repo", "lthn/LEM-golden-set", "HuggingFace dataset repo ID")
public := fs.Bool("public", false, "Make dataset public")
token := fs.String("token", "", "HuggingFace API token (defaults to HF_TOKEN env)")
dryRun := fs.Bool("dry-run", false, "Show what would be uploaded without uploading")
if err := fs.Parse(args); err != nil {
log.Fatalf("parse flags: %v", err)
}
if *inputDir == "" {
fmt.Fprintln(os.Stderr, "error: --input is required (directory with Parquet files)")
fs.Usage()
os.Exit(1)
}
hfToken := *token
if hfToken == "" {
hfToken = os.Getenv("HF_TOKEN")
}
if hfToken == "" {
home, err := os.UserHomeDir()
if err == nil {
data, err := os.ReadFile(filepath.Join(home, ".huggingface", "token"))
if err == nil {
hfToken = strings.TrimSpace(string(data))
}
}
}
if hfToken == "" && !*dryRun {
fmt.Fprintln(os.Stderr, "error: HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)")
os.Exit(1)
}
splits := []string{"train", "valid", "test"}
type uploadEntry struct {
local string
remote string
}
var filesToUpload []uploadEntry
for _, split := range splits {
path := filepath.Join(*inputDir, split+".parquet")
if _, err := os.Stat(path); os.IsNotExist(err) {
continue
}
filesToUpload = append(filesToUpload, uploadEntry{path, fmt.Sprintf("data/%s.parquet", split)})
}
// Check for dataset card in parent directory.
cardPath := filepath.Join(*inputDir, "..", "dataset_card.md")
if _, err := os.Stat(cardPath); err == nil {
filesToUpload = append(filesToUpload, uploadEntry{cardPath, "README.md"})
}
if len(filesToUpload) == 0 {
fmt.Fprintln(os.Stderr, "error: no Parquet files found in input directory")
os.Exit(1)
}
if *dryRun {
fmt.Printf("Dry run: would publish to %s\n", *repoID)
if *public {
fmt.Println(" Visibility: public")
} else {
fmt.Println(" Visibility: private")
}
for _, f := range filesToUpload {
info, _ := os.Stat(f.local)
sizeMB := float64(info.Size()) / 1024 / 1024
fmt.Printf(" %s → %s (%.1f MB)\n", filepath.Base(f.local), f.remote, sizeMB)
}
return
}
fmt.Printf("Publishing to https://huggingface.co/datasets/%s\n", *repoID)
for _, f := range filesToUpload {
if err := uploadFileToHF(hfToken, *repoID, f.local, f.remote); err != nil {
log.Fatalf("upload %s: %v", f.local, err)
}
fmt.Printf(" Uploaded %s → %s\n", filepath.Base(f.local), f.remote)
}
fmt.Printf("\nPublished to https://huggingface.co/datasets/%s\n", *repoID)
}
// uploadFileToHF uploads a file to a HuggingFace dataset repo via the Hub API.
func uploadFileToHF(token, repoID, localPath, remotePath string) error {
data, err := os.ReadFile(localPath)
if err != nil {
return fmt.Errorf("read %s: %w", localPath, err)
}
url := fmt.Sprintf("https://huggingface.co/api/datasets/%s/upload/main/%s", repoID, remotePath)
req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data))
if err != nil {
return fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/octet-stream")
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("upload request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("upload failed: HTTP %d: %s", resp.StatusCode, string(body))
}
return nil
}
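
A minimal sketch of exercising the dry-run path, assuming the module path from go.mod; the input directory is hypothetical:

```go
package main

import "forge.lthn.ai/lthn/lem/pkg/lem"

// With --dry-run, RunPublish lists the would-be uploads and exits
// without requiring a HuggingFace token.
func main() {
	lem.RunPublish([]string{"--input", "./training/parquet", "--dry-run"})
}
```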