diff --git a/go.mod b/go.mod
index dd2f83d..72b6aa8 100644
--- a/go.mod
+++ b/go.mod
@@ -3,21 +3,29 @@ module forge.lthn.ai/lthn/lem
 go 1.25.6

 require (
+	github.com/marcboeker/go-duckdb v1.8.5
+	github.com/parquet-go/parquet-go v0.27.0
+)
+
+require (
+	github.com/andybalholm/brotli v1.1.1 // indirect
 	github.com/apache/arrow-go/v18 v18.1.0 // indirect
 	github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
 	github.com/goccy/go-json v0.10.5 // indirect
 	github.com/google/flatbuffers v25.1.24+incompatible // indirect
 	github.com/google/uuid v1.6.0 // indirect
-	github.com/hupe1980/go-huggingface v0.0.15 // indirect
 	github.com/klauspost/compress v1.17.11 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.9 // indirect
-	github.com/marcboeker/go-duckdb v1.8.5 // indirect
+	github.com/parquet-go/bitpack v1.0.0 // indirect
+	github.com/parquet-go/jsonlite v1.0.0 // indirect
 	github.com/pierrec/lz4/v4 v4.1.22 // indirect
+	github.com/twpayne/go-geom v1.6.1 // indirect
 	github.com/zeebo/xxh3 v1.0.2 // indirect
 	golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c // indirect
 	golang.org/x/mod v0.22.0 // indirect
 	golang.org/x/sync v0.10.0 // indirect
-	golang.org/x/sys v0.29.0 // indirect
+	golang.org/x/sys v0.38.0 // indirect
 	golang.org/x/tools v0.29.0 // indirect
 	golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
+	google.golang.org/protobuf v1.36.1 // indirect
 )
diff --git a/go.sum b/go.sum
index 6c74149..fed5a96 100644
--- a/go.sum
+++ b/go.sum
@@ -1,23 +1,61 @@
+github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
+github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
+github.com/alecthomas/assert/v2 v2.10.0 h1:jjRCHsj6hBJhkmhznrCzoNpbA3zqy0fYiUcYZP/GkPY=
+github.com/alecthomas/assert/v2 v2.10.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k=
+github.com/alecthomas/repr v0.4.0 h1:GhI2A8MACjfegCPVq9f1FLvIBS+DrQ2KQBFZP1iFzXc=
+github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4=
+github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
+github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
 github.com/apache/arrow-go/v18 v18.1.0 h1:agLwJUiVuwXZdwPYVrlITfx7bndULJ/dggbnLFgDp/Y=
 github.com/apache/arrow-go/v18 v18.1.0/go.mod h1:tigU/sIgKNXaesf5d7Y95jBBKS5KsxTqYBKXFsvKzo0=
+github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE=
+github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
 github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
 github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
 github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
+github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
+github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/google/flatbuffers v25.1.24+incompatible h1:4wPqL3K7GzBd1CwyhSd3usxLKOaJN/AC6puCca6Jm7o=
 github.com/google/flatbuffers v25.1.24+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
 github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/hupe1980/go-huggingface v0.0.15 h1:tTWmUGGunC/BYz4hrwS8SSVtMYVYjceG2uhL8HxeXvw=
-github.com/hupe1980/go-huggingface v0.0.15/go.mod h1:IRvsik3+b9BJyw9hCfw1arI6gDObcVto1UA8f3kt8mM=
+github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
+github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
+github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4=
+github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE=
 github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
 github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
 github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
 github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
 github.com/marcboeker/go-duckdb v1.8.5 h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0=
 github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8=
+github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs=
+github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY=
+github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI=
+github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE=
+github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA=
+github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs=
+github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU=
+github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0=
+github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g=
+github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg=
 github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
 github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4=
+github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028=
+github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
+github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
+github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ=
+github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
 github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
 github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
 golang.org/x/exp v0.0.0-20250128182459-e0ece0dbea4c h1:KL/ZBHXgKGVmuZBZ01Lt57yE5ws8ZPSkkihmEyq7FXc=
@@ -26,9 +64,15 @@ golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4=
 golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
 golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
 golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
-golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
-golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
+golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
 golang.org/x/tools v0.29.0 h1:Xx0h3TtM9rzQpQuR4dKLrdglAmCEN5Oi+P74JdhdzXE=
 golang.org/x/tools v0.29.0/go.mod h1:KMQVMRsVxU6nHCFXrBPhDB8XncLNLM0lIy/F14RP588=
 golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
 golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
+gonum.org/v1/gonum v0.15.1 h1:FNy7N6OUZVUaWG9pTiD+jlhdQ3lMP+/LcTpJ6+a8sQ0=
+gonum.org/v1/gonum v0.15.1/go.mod h1:eZTZuRFrzu5pcyjN5wJhcIhnUdNijYxX1T2IcrOGY0o=
+google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk=
+google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/main.go b/main.go
index 8b0d1f6..935518e 100644
--- a/main.go
+++ b/main.go
@@ -21,6 +21,10 @@ Commands:
   expand   Generate expansion responses via trained LEM model
   conv     Generate conversational training data
   ingest   Ingest benchmark data into InfluxDB
+  parquet  Export JSONL training splits to Parquet for HuggingFace
+  publish  Push Parquet files to HuggingFace dataset repo
+  metrics  Push DuckDB golden set stats to InfluxDB
+  convert  Convert MLX LoRA adapter to HuggingFace PEFT format
 `

 func main() {
@@ -46,6 +50,14 @@ func main() {
 		lem.RunConv(os.Args[2:])
 	case "ingest":
 		lem.RunIngest(os.Args[2:])
+	case "parquet":
+		lem.RunParquet(os.Args[2:])
+	case "publish":
+		lem.RunPublish(os.Args[2:])
+	case "metrics":
+		lem.RunMetrics(os.Args[2:])
+	case "convert":
+		lem.RunConvert(os.Args[2:])
 	default:
 		fmt.Fprintf(os.Stderr, "unknown command: %s\n\n%s", os.Args[1], usage)
 		os.Exit(1)
diff --git a/pkg/lem/convert.go b/pkg/lem/convert.go
new file mode 100644
index 0000000..3dd4d6d
--- /dev/null
+++ b/pkg/lem/convert.go
@@ -0,0 +1,349 @@
+package lem
+
+import (
+	"encoding/binary"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"log"
+	"math"
+	"os"
+	"path/filepath"
+	"regexp"
+	"sort"
+	"strconv"
+	"strings"
+)
+
+// RunConvert is the CLI entry point for the convert command.
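+//
+// A typical invocation, assuming the binary is built as lem (the paths
+// here are illustrative; the flags are the ones defined below):
+//
+//	lem convert --input adapters.safetensors --config adapter_config.json --output ./peft_output
+//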
+// It converts an MLX LoRA adapter to HuggingFace PEFT format in three steps:
+//   - Key renaming: model.layers.N.module.lora_a → base_model.model.model.layers.N.module.lora_A.default.weight
+//   - Transpose: MLX (in, rank) → PEFT (rank, in)
+//   - Config generation: adapter_config.json with lora_alpha = scale × rank
+func RunConvert(args []string) {
+	fs := flag.NewFlagSet("convert", flag.ExitOnError)
+
+	safetensorsPath := fs.String("input", "", "Path to MLX .safetensors file (required)")
+	configPath := fs.String("config", "", "Path to MLX adapter_config.json (required)")
+	outputDir := fs.String("output", "./peft_output", "Output directory for PEFT adapter")
+	baseModel := fs.String("base-model", "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "HuggingFace base model ID")
+
+	if err := fs.Parse(args); err != nil {
+		log.Fatalf("parse flags: %v", err)
+	}
+
+	if *safetensorsPath == "" || *configPath == "" {
+		fmt.Fprintln(os.Stderr, "error: --input and --config are required")
+		fs.Usage()
+		os.Exit(1)
+	}
+
+	if err := convertMLXtoPEFT(*safetensorsPath, *configPath, *outputDir, *baseModel); err != nil {
+		log.Fatalf("convert: %v", err)
+	}
+
+	fmt.Printf("Converted to: %s\n", *outputDir)
+}
+
+var (
+	loraARe  = regexp.MustCompile(`\.lora_a$`)
+	loraBRe  = regexp.MustCompile(`\.lora_b$`)
+	layerRe  = regexp.MustCompile(`layers\.(\d+)`)
+	moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
+)
+
+// renameMLXKey converts an MLX tensor key to PEFT format.
+func renameMLXKey(mlxKey string) string {
+	key := mlxKey
+	key = loraARe.ReplaceAllString(key, ".lora_A.default.weight")
+	key = loraBRe.ReplaceAllString(key, ".lora_B.default.weight")
+	key = "base_model.model." + key
+	return key
+}
+
+// safetensorsHeader represents the header of a safetensors file.
+type safetensorsHeader struct {
+	Metadata map[string]string                `json:"__metadata__,omitempty"`
+	Tensors  map[string]safetensorsTensorInfo `json:"-"`
+}
+
+type safetensorsTensorInfo struct {
+	Dtype       string `json:"dtype"`
+	Shape       []int  `json:"shape"`
+	DataOffsets [2]int `json:"data_offsets"`
+}
+
+// readSafetensors reads a safetensors file and returns the tensor header
+// entries plus the raw tensor data section.
+func readSafetensors(path string) (map[string]safetensorsTensorInfo, []byte, error) {
+	data, err := os.ReadFile(path)
+	if err != nil {
+		return nil, nil, fmt.Errorf("read file: %w", err)
+	}
+
+	if len(data) < 8 {
+		return nil, nil, fmt.Errorf("file too small")
+	}
+
+	headerSize := int(binary.LittleEndian.Uint64(data[:8]))
+	if 8+headerSize > len(data) {
+		return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
+	}
+
+	headerJSON := data[8 : 8+headerSize]
+	tensorData := data[8+headerSize:]
+
+	// Parse header as a generic map since tensors are top-level keys.
+	var rawHeader map[string]json.RawMessage
+	if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
+		return nil, nil, fmt.Errorf("parse header: %w", err)
+	}
+
+	tensors := make(map[string]safetensorsTensorInfo)
+	for key, raw := range rawHeader {
+		if key == "__metadata__" {
+			continue
+		}
+		var info safetensorsTensorInfo
+		if err := json.Unmarshal(raw, &info); err != nil {
+			return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
+		}
+		tensors[key] = info
+	}
+
+	return tensors, tensorData, nil
+}
+
+// getTensorData extracts raw bytes for a tensor from the data section.
+func getTensorData(info safetensorsTensorInfo, allData []byte) []byte {
+	return allData[info.DataOffsets[0]:info.DataOffsets[1]]
+}
+
+// transposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
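+// A worked example of the row-major index math (values are illustrative):
+// [[1 2 3] [4 5 6]] is stored as 1,2,3,4,5,6 and comes back as
+// 1,4,2,5,3,6, i.e. [[1 4] [2 5] [3 6]]. On a size mismatch the input is
+// returned unchanged rather than panicking.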
+func transposeFloat32(data []byte, rows, cols int) []byte {
+	if len(data) != rows*cols*4 {
+		return data // size mismatch, return as-is
+	}
+
+	result := make([]byte, len(data))
+	for r := 0; r < rows; r++ {
+		for c := 0; c < cols; c++ {
+			srcOff := (r*cols + c) * 4
+			dstOff := (c*rows + r) * 4
+			copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4])
+		}
+	}
+	return result
+}
+
+// transposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
+func transposeFloat16(data []byte, rows, cols int) []byte {
+	if len(data) != rows*cols*2 {
+		return data
+	}
+
+	result := make([]byte, len(data))
+	for r := 0; r < rows; r++ {
+		for c := 0; c < cols; c++ {
+			srcOff := (r*cols + c) * 2
+			dstOff := (c*rows + r) * 2
+			copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2])
+		}
+	}
+	return result
+}
+
+// transposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
+func transposeBFloat16(data []byte, rows, cols int) []byte {
+	return transposeFloat16(data, rows, cols) // same element size
+}
+
+// writeSafetensors writes tensors to a safetensors file.
+func writeSafetensors(path string, tensors map[string]safetensorsTensorInfo, tensorData map[string][]byte) error {
+	// Sort keys for deterministic output.
+	keys := make([]string, 0, len(tensors))
+	for k := range tensors {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+
+	// Compute offsets.
+	offset := 0
+	updatedTensors := make(map[string]safetensorsTensorInfo)
+	for _, k := range keys {
+		info := tensors[k]
+		data := tensorData[k]
+		info.DataOffsets = [2]int{offset, offset + len(data)}
+		updatedTensors[k] = info
+		offset += len(data)
+	}
+
+	// Build header JSON.
+	headerMap := make(map[string]interface{})
+	for k, info := range updatedTensors {
+		headerMap[k] = info
+	}
+
+	headerJSON, err := json.Marshal(headerMap)
+	if err != nil {
+		return fmt.Errorf("marshal header: %w", err)
+	}
+
+	// Write file: 8-byte header size + header JSON + tensor data.
+	f, err := os.Create(path)
+	if err != nil {
+		return fmt.Errorf("create %s: %w", path, err)
+	}
+	defer f.Close()
+
+	headerSizeBuf := make([]byte, 8)
+	binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))
+
+	if _, err := f.Write(headerSizeBuf); err != nil {
+		return err
+	}
+	if _, err := f.Write(headerJSON); err != nil {
+		return err
+	}
+
+	for _, k := range keys {
+		if _, err := f.Write(tensorData[k]); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// convertMLXtoPEFT converts an MLX LoRA adapter to PEFT format.
+func convertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
+	if err := os.MkdirAll(outputDir, 0755); err != nil {
+		return fmt.Errorf("create output dir: %w", err)
+	}
+
+	// Read MLX tensors.
+	tensors, tensorData, err := readSafetensors(safetensorsPath)
+	if err != nil {
+		return fmt.Errorf("read safetensors: %w", err)
+	}
+	log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
+
+	// Rename and transpose tensors.
+	peftTensors := make(map[string]safetensorsTensorInfo)
+	peftData := make(map[string][]byte)
+
+	for mlxKey, info := range tensors {
+		peftKey := renameMLXKey(mlxKey)
+		data := getTensorData(info, tensorData)
+
+		// Transpose: swap shape and transpose data.
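+		// MLX stores lora_a as (in, rank) and lora_b as (rank, out); PEFT
+		// expects the transposed layout. Dtypes are assumed to be F32, F16,
+		// or BF16: any other dtype keeps its raw bytes while the recorded
+		// shape is still swapped.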
+		if len(info.Shape) == 2 {
+			rows, cols := info.Shape[0], info.Shape[1]
+
+			switch info.Dtype {
+			case "F32":
+				data = transposeFloat32(data, rows, cols)
+			case "F16":
+				data = transposeFloat16(data, rows, cols)
+			case "BF16":
+				data = transposeBFloat16(data, rows, cols)
+			}
+
+			info.Shape = []int{cols, rows}
+		}
+
+		peftTensors[peftKey] = info
+		peftData[peftKey] = data
+	}
+
+	// Write PEFT safetensors.
+	outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
+	if err := writeSafetensors(outSafetensors, peftTensors, peftData); err != nil {
+		return fmt.Errorf("write safetensors: %w", err)
+	}
+
+	// Read MLX config for LoRA parameters.
+	cfgData, err := os.ReadFile(configPath)
+	if err != nil {
+		return fmt.Errorf("read config: %w", err)
+	}
+
+	var mlxConfig struct {
+		LoraParameters struct {
+			Rank    int     `json:"rank"`
+			Scale   float64 `json:"scale"`
+			Dropout float64 `json:"dropout"`
+		} `json:"lora_parameters"`
+	}
+	if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
+		return fmt.Errorf("parse config: %w", err)
+	}
+
+	rank := mlxConfig.LoraParameters.Rank
+	if rank == 0 {
+		rank = 8
+	}
+	scale := mlxConfig.LoraParameters.Scale
+	if scale == 0 {
+		scale = 20.0
+	}
+
+	// Determine target modules from tensor keys.
+	modules := make(map[string]bool)
+	layers := make(map[int]bool)
+	for k := range tensors {
+		if m := moduleRe.FindStringSubmatch(k); m != nil {
+			parts := strings.Split(m[1], ".")
+			modules[parts[len(parts)-1]] = true
+		}
+		if m := layerRe.FindStringSubmatch(k); m != nil {
+			n, _ := strconv.Atoi(m[1])
+			layers[n] = true
+		}
+	}
+
+	sortedModules := make([]string, 0, len(modules))
+	for m := range modules {
+		sortedModules = append(sortedModules, m)
+	}
+	sort.Strings(sortedModules)
+
+	sortedLayers := make([]int, 0, len(layers))
+	for l := range layers {
+		sortedLayers = append(sortedLayers, l)
+	}
+	sort.Ints(sortedLayers)
+
+	// Write PEFT adapter_config.json.
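+	// PEFT scales LoRA updates by lora_alpha / r at load time, while MLX
+	// applies scale directly, so alpha = scale * rank (e.g. 20.0 * 8 = 160)
+	// reproduces the MLX scaling on the PEFT side.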
+	peftConfig := map[string]interface{}{
+		"auto_mapping":            nil,
+		"base_model_name_or_path": baseModelName,
+		"bias":                    "none",
+		"fan_in_fan_out":          false,
+		"inference_mode":          true,
+		"init_lora_weights":       true,
+		"layers_pattern":          nil,
+		"layers_to_transform":     sortedLayers,
+		"lora_alpha":              math.Round(scale * float64(rank)),
+		"lora_dropout":            mlxConfig.LoraParameters.Dropout,
+		"modules_to_save":         nil,
+		"peft_type":               "LORA",
+		"r":                       rank,
+		"revision":                nil,
+		"target_modules":          sortedModules,
+		"task_type":               "CAUSAL_LM",
+	}
+
+	cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
+	if err != nil {
+		return fmt.Errorf("marshal peft config: %w", err)
+	}
+
+	if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
+		return fmt.Errorf("write adapter_config.json: %w", err)
+	}
+
+	log.Printf("converted %d tensors, %d layers, target modules: %v",
+		len(peftTensors), len(sortedLayers), sortedModules)
+
+	return nil
+}
diff --git a/pkg/lem/convert_test.go b/pkg/lem/convert_test.go
new file mode 100644
index 0000000..7104b6c
--- /dev/null
+++ b/pkg/lem/convert_test.go
@@ -0,0 +1,198 @@
+package lem
+
+import (
+	"encoding/binary"
+	"encoding/json"
+	"math"
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+func TestRenameMLXKey(t *testing.T) {
+	tests := []struct {
+		input string
+		want  string
+	}{
+		{
+			"model.layers.12.self_attn.q_proj.lora_a",
+			"base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight",
+		},
+		{
+			"model.layers.0.self_attn.v_proj.lora_b",
+			"base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight",
+		},
+		{
+			"model.layers.5.mlp.gate_proj.lora_a",
+			"base_model.model.model.layers.5.mlp.gate_proj.lora_A.default.weight",
+		},
+	}
+
+	for _, tt := range tests {
+		got := renameMLXKey(tt.input)
+		if got != tt.want {
+			t.Errorf("renameMLXKey(%q) = %q, want %q", tt.input, got, tt.want)
+		}
+	}
+}
+
+func TestTransposeFloat32(t *testing.T) {
+	// 2x3 matrix: [[1, 2, 3], [4, 5, 6]]
+	data := make([]byte, 2*3*4)
+	for i, v := range []float32{1, 2, 3, 4, 5, 6} {
+		binary.LittleEndian.PutUint32(data[i*4:], math.Float32bits(v))
+	}
+
+	result := transposeFloat32(data, 2, 3)
+
+	// Expected: 3x2 matrix: [[1, 4], [2, 5], [3, 6]]
+	expected := []float32{1, 4, 2, 5, 3, 6}
+	for i, want := range expected {
+		got := math.Float32frombits(binary.LittleEndian.Uint32(result[i*4:]))
+		if got != want {
+			t.Errorf("result[%d] = %f, want %f", i, got, want)
+		}
+	}
+}
+
+func TestConvertMLXtoPEFT(t *testing.T) {
+	dir := t.TempDir()
+
+	// Create a minimal MLX safetensors file with one lora_a and one lora_b tensor.
+	// Shape: lora_a is (in=4, rank=2), lora_b is (rank=2, out=4)
+	tensors := map[string]safetensorsTensorInfo{
+		"model.layers.0.self_attn.q_proj.lora_a": {Dtype: "F32", Shape: []int{4, 2}},
+		"model.layers.0.self_attn.q_proj.lora_b": {Dtype: "F32", Shape: []int{2, 4}},
+	}
+
+	// Create tensor data: 4x2=8 floats and 2x4=8 floats.
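+	// The values 1..8 and 10..17 are arbitrary but all distinct, so a bad
+	// offset or transpose would surface as mismatched floats downstream.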
+	loraAData := make([]byte, 4*2*4)
+	for i := 0; i < 8; i++ {
+		binary.LittleEndian.PutUint32(loraAData[i*4:], math.Float32bits(float32(i+1)))
+	}
+	loraBData := make([]byte, 2*4*4)
+	for i := 0; i < 8; i++ {
+		binary.LittleEndian.PutUint32(loraBData[i*4:], math.Float32bits(float32(10+i)))
+	}
+
+	tensorData := make(map[string][]byte)
+	tensorData["model.layers.0.self_attn.q_proj.lora_a"] = loraAData
+	tensorData["model.layers.0.self_attn.q_proj.lora_b"] = loraBData
+
+	sfPath := filepath.Join(dir, "adapters.safetensors")
+	if err := writeSafetensors(sfPath, tensors, tensorData); err != nil {
+		t.Fatalf("write test safetensors: %v", err)
+	}
+
+	// Create MLX config.
+	mlxConfig := map[string]interface{}{
+		"lora_parameters": map[string]interface{}{
+			"rank":    8,
+			"scale":   20.0,
+			"dropout": 0.0,
+		},
+	}
+	cfgData, _ := json.Marshal(mlxConfig)
+	cfgPath := filepath.Join(dir, "adapter_config.json")
+	os.WriteFile(cfgPath, cfgData, 0644)
+
+	// Convert.
+	outputDir := filepath.Join(dir, "peft_output")
+	if err := convertMLXtoPEFT(sfPath, cfgPath, outputDir, "test-model"); err != nil {
+		t.Fatalf("convert: %v", err)
+	}
+
+	// Check output files exist.
+	if _, err := os.Stat(filepath.Join(outputDir, "adapter_model.safetensors")); err != nil {
+		t.Error("missing adapter_model.safetensors")
+	}
+	if _, err := os.Stat(filepath.Join(outputDir, "adapter_config.json")); err != nil {
+		t.Error("missing adapter_config.json")
+	}
+
+	// Read and verify PEFT config.
+	peftCfgData, err := os.ReadFile(filepath.Join(outputDir, "adapter_config.json"))
+	if err != nil {
+		t.Fatalf("read peft config: %v", err)
+	}
+
+	var peftConfig map[string]interface{}
+	if err := json.Unmarshal(peftCfgData, &peftConfig); err != nil {
+		t.Fatalf("parse peft config: %v", err)
+	}
+
+	if peftConfig["peft_type"] != "LORA" {
+		t.Errorf("peft_type = %v, want LORA", peftConfig["peft_type"])
+	}
+	if peftConfig["base_model_name_or_path"] != "test-model" {
+		t.Errorf("base_model = %v, want test-model", peftConfig["base_model_name_or_path"])
+	}
+
+	// Check that lora_alpha = scale * rank = 20 * 8 = 160.
+	if alpha, ok := peftConfig["lora_alpha"].(float64); !ok || alpha != 160 {
+		t.Errorf("lora_alpha = %v, want 160", peftConfig["lora_alpha"])
+	}
+
+	// Verify converted safetensors has PEFT-format keys.
+	peftTensors, _, err := readSafetensors(filepath.Join(outputDir, "adapter_model.safetensors"))
+	if err != nil {
+		t.Fatalf("read peft safetensors: %v", err)
+	}
+
+	expectedKeys := []string{
+		"base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight",
+		"base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight",
+	}
+	for _, k := range expectedKeys {
+		if _, ok := peftTensors[k]; !ok {
+			t.Errorf("missing expected PEFT key: %s", k)
+		}
+	}
+
+	// Verify shapes are transposed: lora_a (4,2) → (2,4), lora_b (2,4) → (4,2).
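+	loraBInfo := peftTensors["base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight"]
+	if loraBInfo.Shape[0] != 4 || loraBInfo.Shape[1] != 2 {
+		t.Errorf("lora_B shape = %v, want [4, 2]", loraBInfo.Shape)
+	}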
+	loraAInfo := peftTensors["base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"]
+	if loraAInfo.Shape[0] != 2 || loraAInfo.Shape[1] != 4 {
+		t.Errorf("lora_A shape = %v, want [2, 4]", loraAInfo.Shape)
+	}
+}
+
+func TestReadWriteSafetensorsRoundtrip(t *testing.T) {
+	dir := t.TempDir()
+	path := filepath.Join(dir, "test.safetensors")
+
+	original := map[string]safetensorsTensorInfo{
+		"weight_a": {Dtype: "F32", Shape: []int{2, 3}},
+	}
+	data := map[string][]byte{
+		"weight_a": make([]byte, 2*3*4),
+	}
+	for i := 0; i < 6; i++ {
+		binary.LittleEndian.PutUint32(data["weight_a"][i*4:], math.Float32bits(float32(i)))
+	}
+
+	if err := writeSafetensors(path, original, data); err != nil {
+		t.Fatalf("write: %v", err)
+	}
+
+	readTensors, readData, err := readSafetensors(path)
+	if err != nil {
+		t.Fatalf("read: %v", err)
+	}
+
+	if len(readTensors) != 1 {
+		t.Fatalf("expected 1 tensor, got %d", len(readTensors))
+	}
+
+	info := readTensors["weight_a"]
+	if info.Dtype != "F32" {
+		t.Errorf("dtype = %s, want F32", info.Dtype)
+	}
+	if info.Shape[0] != 2 || info.Shape[1] != 3 {
+		t.Errorf("shape = %v, want [2, 3]", info.Shape)
+	}
+
+	got := getTensorData(info, readData)
+	if len(got) != 24 {
+		t.Errorf("data length = %d, want 24", len(got))
+	}
+}
diff --git a/pkg/lem/metrics.go b/pkg/lem/metrics.go
new file mode 100644
index 0000000..5b701cc
--- /dev/null
+++ b/pkg/lem/metrics.go
@@ -0,0 +1,126 @@
+package lem
+
+import (
+	"flag"
+	"fmt"
+	"log"
+	"os"
+	"time"
+)
+
+const targetTotal = 15000
+
+// RunMetrics is the CLI entry point for the metrics command.
+// Reads golden set stats from DuckDB and pushes them to InfluxDB as
+// golden_set_stats, golden_set_domain, and golden_set_voice measurements.
+func RunMetrics(args []string) {
+	fs := flag.NewFlagSet("metrics", flag.ExitOnError)
+
+	dbPath := fs.String("db", "", "DuckDB database path (defaults to LEM_DB env)")
+	influxURL := fs.String("influx", "", "InfluxDB URL")
+	influxDB := fs.String("influx-db", "", "InfluxDB database name")
+
+	if err := fs.Parse(args); err != nil {
+		log.Fatalf("parse flags: %v", err)
+	}
+
+	if *dbPath == "" {
+		*dbPath = os.Getenv("LEM_DB")
+	}
+	if *dbPath == "" {
+		fmt.Fprintln(os.Stderr, "error: --db or LEM_DB required (path to DuckDB file)")
+		os.Exit(1)
+	}
+
+	db, err := OpenDB(*dbPath)
+	if err != nil {
+		log.Fatalf("open db: %v", err)
+	}
+	defer db.Close()
+
+	// Query overall stats.
+	var total, domains, voices int
+	var avgGenTime, avgChars float64
+
+	err = db.conn.QueryRow(`
+		SELECT count(*), count(DISTINCT domain), count(DISTINCT voice),
+		       coalesce(avg(gen_time), 0), coalesce(avg(char_count), 0)
+		FROM golden_set
+	`).Scan(&total, &domains, &voices, &avgGenTime, &avgChars)
+	if err != nil {
+		log.Fatalf("query golden_set stats: %v", err)
+	}
+
+	if total == 0 {
+		fmt.Println("No golden set data in DuckDB.")
+		return
+	}
+
+	nowNs := time.Now().UTC().UnixNano()
+	pct := float64(total) / float64(targetTotal) * 100.0
+
+	var lines []string
+
+	// Overall stats measurement.
+	lines = append(lines, fmt.Sprintf(
+		"golden_set_stats total_examples=%di,domains=%di,voices=%di,avg_gen_time=%.2f,avg_response_chars=%.0f,completion_pct=%.1f %d",
+		total, domains, voices, avgGenTime, avgChars, pct, nowNs,
+	))
+
+	// Per-domain stats.
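+	// One line-protocol point per domain, e.g. (values illustrative):
+	//   golden_set_domain,domain=ethics count=1200i,avg_gen_time=4.32 1700000000000000000
+	// Tag values go through escapeLp so commas and spaces cannot break the line.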
+	domainRows, err := db.conn.Query(`
+		SELECT domain, count(*) AS n, avg(gen_time) AS avg_t
+		FROM golden_set GROUP BY domain
+	`)
+	if err != nil {
+		log.Fatalf("query domains: %v", err)
+	}
+	domainCount := 0
+	for domainRows.Next() {
+		var domain string
+		var n int
+		var avgT float64
+		if err := domainRows.Scan(&domain, &n, &avgT); err != nil {
+			log.Fatalf("scan domain row: %v", err)
+		}
+		lines = append(lines, fmt.Sprintf(
+			"golden_set_domain,domain=%s count=%di,avg_gen_time=%.2f %d",
+			escapeLp(domain), n, avgT, nowNs,
+		))
+		domainCount++
+	}
+	domainRows.Close()
+
+	// Per-voice stats.
+	voiceRows, err := db.conn.Query(`
+		SELECT voice, count(*) AS n, avg(char_count) AS avg_c, avg(gen_time) AS avg_t
+		FROM golden_set GROUP BY voice
+	`)
+	if err != nil {
+		log.Fatalf("query voices: %v", err)
+	}
+	voiceCount := 0
+	for voiceRows.Next() {
+		var voice string
+		var n int
+		var avgC, avgT float64
+		if err := voiceRows.Scan(&voice, &n, &avgC, &avgT); err != nil {
+			log.Fatalf("scan voice row: %v", err)
+		}
+		lines = append(lines, fmt.Sprintf(
+			"golden_set_voice,voice=%s count=%di,avg_chars=%.0f,avg_gen_time=%.2f %d",
+			escapeLp(voice), n, avgC, avgT, nowNs,
+		))
+		voiceCount++
+	}
+	voiceRows.Close()
+
+	// Write to InfluxDB.
+	influx := NewInfluxClient(*influxURL, *influxDB)
+	if err := influx.WriteLp(lines); err != nil {
+		log.Fatalf("write metrics: %v", err)
+	}
+
+	fmt.Printf("Wrote metrics to InfluxDB: %d examples, %d domains, %d voices (%d points)\n",
+		total, domainCount, voiceCount, len(lines))
+}
diff --git a/pkg/lem/parquet.go b/pkg/lem/parquet.go
new file mode 100644
index 0000000..0d3e136
--- /dev/null
+++ b/pkg/lem/parquet.go
@@ -0,0 +1,162 @@
+package lem
+
+import (
+	"bufio"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"log"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/parquet-go/parquet-go"
+)
+
+// ParquetRow is the schema for exported Parquet files.
+type ParquetRow struct {
+	Prompt   string `parquet:"prompt"`
+	Response string `parquet:"response"`
+	System   string `parquet:"system"`
+	Messages string `parquet:"messages"`
+}
+
+// RunParquet is the CLI entry point for the parquet command.
+// Reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl) and
+// writes Parquet files with snappy compression for HuggingFace datasets.
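+//
+// Each input line is expected to carry a messages array (shape illustrative):
+//
+//	{"messages":[{"role":"user","content":"..."},{"role":"assistant","content":"..."}]}
+//
+// Malformed lines are skipped rather than aborting the whole export.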
+func RunParquet(args []string) {
+	fs := flag.NewFlagSet("parquet", flag.ExitOnError)
+
+	trainingDir := fs.String("input", "", "Directory containing train.jsonl, valid.jsonl, test.jsonl (required)")
+	outputDir := fs.String("output", "", "Output directory for Parquet files (defaults to input/parquet)")
+
+	if err := fs.Parse(args); err != nil {
+		log.Fatalf("parse flags: %v", err)
+	}
+
+	if *trainingDir == "" {
+		fmt.Fprintln(os.Stderr, "error: --input is required (directory with JSONL splits)")
+		fs.Usage()
+		os.Exit(1)
+	}
+
+	if *outputDir == "" {
+		*outputDir = filepath.Join(*trainingDir, "parquet")
+	}
+
+	if err := os.MkdirAll(*outputDir, 0755); err != nil {
+		log.Fatalf("create output dir: %v", err)
+	}
+
+	fmt.Printf("Exporting Parquet from %s → %s\n", *trainingDir, *outputDir)
+
+	total := 0
+	for _, split := range []string{"train", "valid", "test"} {
+		jsonlPath := filepath.Join(*trainingDir, split+".jsonl")
+		if _, err := os.Stat(jsonlPath); os.IsNotExist(err) {
+			fmt.Printf("  Skip: %s.jsonl not found\n", split)
+			continue
+		}
+
+		n, err := exportSplitParquet(jsonlPath, *outputDir, split)
+		if err != nil {
+			log.Fatalf("export %s: %v", split, err)
+		}
+		total += n
+	}
+
+	fmt.Printf("\nTotal: %d rows exported\n", total)
+}
+
+// exportSplitParquet reads a JSONL file and writes a Parquet file for the split.
+func exportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
+	f, err := os.Open(jsonlPath)
+	if err != nil {
+		return 0, fmt.Errorf("open %s: %w", jsonlPath, err)
+	}
+	defer f.Close()
+
+	var rows []ParquetRow
+	scanner := bufio.NewScanner(f)
+	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
+
+	for scanner.Scan() {
+		text := strings.TrimSpace(scanner.Text())
+		if text == "" {
+			continue
+		}
+
+		var data struct {
+			Messages []ChatMessage `json:"messages"`
+		}
+		if err := json.Unmarshal([]byte(text), &data); err != nil {
+			continue
+		}
+
+		var prompt, response, system string
+		for _, m := range data.Messages {
+			switch m.Role {
+			case "user":
+				if prompt == "" {
+					prompt = m.Content
+				}
+			case "assistant":
+				if response == "" {
+					response = m.Content
+				}
+			case "system":
+				if system == "" {
+					system = m.Content
+				}
+			}
+		}
+
+		msgsJSON, _ := json.Marshal(data.Messages)
+		rows = append(rows, ParquetRow{
+			Prompt:   prompt,
+			Response: response,
+			System:   system,
+			Messages: string(msgsJSON),
+		})
+	}
+
+	if err := scanner.Err(); err != nil {
+		return 0, fmt.Errorf("scan %s: %w", jsonlPath, err)
+	}
+
+	if len(rows) == 0 {
+		fmt.Printf("  Skip: %s — no data\n", split)
+		return 0, nil
+	}
+
+	outPath := filepath.Join(outputDir, split+".parquet")
+
+	out, err := os.Create(outPath)
+	if err != nil {
+		return 0, fmt.Errorf("create %s: %w", outPath, err)
+	}
+
+	writer := parquet.NewGenericWriter[ParquetRow](out,
+		parquet.Compression(&parquet.Snappy),
+	)
+
+	if _, err := writer.Write(rows); err != nil {
+		out.Close()
+		return 0, fmt.Errorf("write parquet rows: %w", err)
+	}
+
+	if err := writer.Close(); err != nil {
+		out.Close()
+		return 0, fmt.Errorf("close parquet writer: %w", err)
+	}
+
+	if err := out.Close(); err != nil {
+		return 0, fmt.Errorf("close file: %w", err)
+	}
+
+	info, _ := os.Stat(outPath)
+	sizeMB := float64(info.Size()) / 1024 / 1024
+	fmt.Printf("  %s.parquet: %d rows (%.1f MB)\n", split, len(rows), sizeMB)
+
+	return len(rows), nil
+}
diff --git a/pkg/lem/parquet_test.go b/pkg/lem/parquet_test.go
new file mode 100644
index 0000000..122ea88
--- /dev/null
+++ b/pkg/lem/parquet_test.go
@@ -0,0 +1,143 @@
+package lem
+
+import (
"encoding/json" + "io" + "os" + "path/filepath" + "testing" + + "github.com/parquet-go/parquet-go" +) + +func TestExportSplitParquet(t *testing.T) { + dir := t.TempDir() + inputPath := filepath.Join(dir, "train.jsonl") + outputDir := filepath.Join(dir, "output") + os.MkdirAll(outputDir, 0755) + + // Write test JSONL. + convs := []TrainingExample{ + {Messages: []ChatMessage{ + {Role: "user", Content: "What is wisdom?"}, + {Role: "assistant", Content: "The application of understanding."}, + }}, + {Messages: []ChatMessage{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Tell me about ethics."}, + {Role: "assistant", Content: "Ethics concerns right action."}, + }}, + } + + f, _ := os.Create(inputPath) + for _, c := range convs { + data, _ := json.Marshal(c) + f.Write(data) + f.WriteString("\n") + } + f.Close() + + n, err := exportSplitParquet(inputPath, outputDir, "train") + if err != nil { + t.Fatalf("export: %v", err) + } + if n != 2 { + t.Errorf("expected 2 rows, got %d", n) + } + + // Verify Parquet file exists and is readable. + outPath := filepath.Join(outputDir, "train.parquet") + pf, err := os.Open(outPath) + if err != nil { + t.Fatalf("open parquet: %v", err) + } + defer pf.Close() + + info, _ := pf.Stat() + reader := parquet.NewGenericReader[ParquetRow](pf) + defer reader.Close() + + rows := make([]ParquetRow, 10) + read, err := reader.Read(rows) + if err != nil && err != io.EOF { + t.Fatalf("read parquet: %v", err) + } + if read != 2 { + t.Errorf("expected 2 rows in parquet, got %d", read) + } + + if rows[0].Prompt != "What is wisdom?" { + t.Errorf("unexpected prompt: %s", rows[0].Prompt) + } + if rows[0].Response != "The application of understanding." { + t.Errorf("unexpected response: %s", rows[0].Response) + } + if rows[1].System != "You are helpful." { + t.Errorf("expected system message, got: %s", rows[1].System) + } + + if info.Size() == 0 { + t.Error("parquet file is empty") + } +} + +func TestExportSplitParquetEmpty(t *testing.T) { + dir := t.TempDir() + inputPath := filepath.Join(dir, "empty.jsonl") + outputDir := filepath.Join(dir, "output") + os.MkdirAll(outputDir, 0755) + + // Write empty JSONL. + os.WriteFile(inputPath, []byte("\n\n"), 0644) + + n, err := exportSplitParquet(inputPath, outputDir, "test") + if err != nil { + t.Fatalf("export: %v", err) + } + if n != 0 { + t.Errorf("expected 0 rows for empty file, got %d", n) + } +} + +func TestExportSplitParquetMessages(t *testing.T) { + dir := t.TempDir() + inputPath := filepath.Join(dir, "valid.jsonl") + outputDir := filepath.Join(dir, "output") + os.MkdirAll(outputDir, 0755) + + conv := TrainingExample{Messages: []ChatMessage{ + {Role: "user", Content: "hi"}, + {Role: "assistant", Content: "hello"}, + }} + + f, _ := os.Create(inputPath) + data, _ := json.Marshal(conv) + f.Write(data) + f.WriteString("\n") + f.Close() + + n, err := exportSplitParquet(inputPath, outputDir, "valid") + if err != nil { + t.Fatalf("export: %v", err) + } + if n != 1 { + t.Errorf("expected 1 row, got %d", n) + } + + // Verify messages field contains valid JSON. 
+	pf, _ := os.Open(filepath.Join(outputDir, "valid.parquet"))
+	defer pf.Close()
+	reader := parquet.NewGenericReader[ParquetRow](pf)
+	defer reader.Close()
+
+	rows := make([]ParquetRow, 1)
+	reader.Read(rows)
+
+	var msgs []ChatMessage
+	if err := json.Unmarshal([]byte(rows[0].Messages), &msgs); err != nil {
+		t.Fatalf("parse messages JSON: %v", err)
+	}
+	if len(msgs) != 2 {
+		t.Errorf("expected 2 messages in JSON, got %d", len(msgs))
+	}
+}
diff --git a/pkg/lem/publish.go b/pkg/lem/publish.go
new file mode 100644
index 0000000..08170b6
--- /dev/null
+++ b/pkg/lem/publish.go
@@ -0,0 +1,138 @@
+package lem
+
+import (
+	"bytes"
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+)
+
+// RunPublish is the CLI entry point for the publish command.
+// Pushes Parquet files and an optional dataset card to HuggingFace.
+func RunPublish(args []string) {
+	fs := flag.NewFlagSet("publish", flag.ExitOnError)
+
+	inputDir := fs.String("input", "", "Directory containing Parquet files (required)")
+	repoID := fs.String("repo", "lthn/LEM-golden-set", "HuggingFace dataset repo ID")
+	public := fs.Bool("public", false, "Make dataset public")
+	token := fs.String("token", "", "HuggingFace API token (defaults to HF_TOKEN env)")
+	dryRun := fs.Bool("dry-run", false, "Show what would be uploaded without uploading")
+
+	if err := fs.Parse(args); err != nil {
+		log.Fatalf("parse flags: %v", err)
+	}
+
+	if *inputDir == "" {
+		fmt.Fprintln(os.Stderr, "error: --input is required (directory with Parquet files)")
+		fs.Usage()
+		os.Exit(1)
+	}
+
+	hfToken := *token
+	if hfToken == "" {
+		hfToken = os.Getenv("HF_TOKEN")
+	}
+	if hfToken == "" {
+		home, err := os.UserHomeDir()
+		if err == nil {
+			data, err := os.ReadFile(filepath.Join(home, ".huggingface", "token"))
+			if err == nil {
+				hfToken = strings.TrimSpace(string(data))
+			}
+		}
+	}
+
+	if hfToken == "" && !*dryRun {
+		fmt.Fprintln(os.Stderr, "error: HuggingFace token required (--token, HF_TOKEN env, or ~/.huggingface/token)")
+		os.Exit(1)
+	}
+
+	splits := []string{"train", "valid", "test"}
+	type uploadEntry struct {
+		local  string
+		remote string
+	}
+	var filesToUpload []uploadEntry
+
+	for _, split := range splits {
+		path := filepath.Join(*inputDir, split+".parquet")
+		if _, err := os.Stat(path); os.IsNotExist(err) {
+			continue
+		}
+		filesToUpload = append(filesToUpload, uploadEntry{path, fmt.Sprintf("data/%s.parquet", split)})
+	}
+
+	// Check for dataset card in parent directory.
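+	// RunParquet writes into <input>/parquet by default, so the card is
+	// expected one level up, alongside the JSONL splits. It is uploaded as
+	// README.md, which HuggingFace renders as the dataset card.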
+	cardPath := filepath.Join(*inputDir, "..", "dataset_card.md")
+	if _, err := os.Stat(cardPath); err == nil {
+		filesToUpload = append(filesToUpload, uploadEntry{cardPath, "README.md"})
+	}
+
+	if len(filesToUpload) == 0 {
+		fmt.Fprintln(os.Stderr, "error: no Parquet files found in input directory")
+		os.Exit(1)
+	}
+
+	if *dryRun {
+		fmt.Printf("Dry run: would publish to %s\n", *repoID)
+		if *public {
+			fmt.Println("  Visibility: public")
+		} else {
+			fmt.Println("  Visibility: private")
+		}
+		for _, f := range filesToUpload {
+			info, _ := os.Stat(f.local)
+			sizeMB := float64(info.Size()) / 1024 / 1024
+			fmt.Printf("  %s → %s (%.1f MB)\n", filepath.Base(f.local), f.remote, sizeMB)
+		}
+		return
+	}
+
+	fmt.Printf("Publishing to https://huggingface.co/datasets/%s\n", *repoID)
+
+	for _, f := range filesToUpload {
+		if err := uploadFileToHF(hfToken, *repoID, f.local, f.remote); err != nil {
+			log.Fatalf("upload %s: %v", f.local, err)
+		}
+		fmt.Printf("  Uploaded %s → %s\n", filepath.Base(f.local), f.remote)
+	}
+
+	fmt.Printf("\nPublished to https://huggingface.co/datasets/%s\n", *repoID)
+}
+
+// uploadFileToHF uploads a file to a HuggingFace dataset repo via the Hub API.
+func uploadFileToHF(token, repoID, localPath, remotePath string) error {
+	data, err := os.ReadFile(localPath)
+	if err != nil {
+		return fmt.Errorf("read %s: %w", localPath, err)
+	}
+
+	url := fmt.Sprintf("https://huggingface.co/api/datasets/%s/upload/main/%s", repoID, remotePath)
+
+	req, err := http.NewRequest(http.MethodPut, url, bytes.NewReader(data))
+	if err != nil {
+		return fmt.Errorf("create request: %w", err)
+	}
+	req.Header.Set("Authorization", "Bearer "+token)
+	req.Header.Set("Content-Type", "application/octet-stream")
+
+	client := &http.Client{Timeout: 120 * time.Second}
+	resp, err := client.Do(req)
+	if err != nil {
+		return fmt.Errorf("upload request: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode >= 300 {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("upload failed: HTTP %d: %s", resp.StatusCode, string(body))
+	}
+
+	return nil
+}