feat(ml): add format converters, data pipeline, and scoring agent
Some checks are pending
Security Scan / Go Vulnerability Check (push) Waiting to run
Security Scan / Secret Detection (push) Waiting to run
Security Scan / Dependency & Config Scan (push) Waiting to run

Port remaining lem-repo components into pkg/ml/:
- convert.go: safetensors reader/writer, MLX→PEFT converter
- gguf.go: GGUF v3 writer, MLX→GGUF LoRA converter
- export.go: training data JSONL export with split/filter
- parquet.go: Parquet export with snappy compression
- db.go: DuckDB wrapper for golden set and expansion prompts
- influx.go: InfluxDB v3 client for metrics/status
- ollama.go: Ollama model management (create/delete with adapters)
- status.go: training and generation status display
- expand.go: expansion generation pipeline (Backend interface)
- agent.go: scoring agent with probe running and InfluxDB push
- worker.go: distributed worker for LEM API task processing

Adds parquet-go and go-duckdb dependencies.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude 2026-02-15 23:46:24 +00:00
parent 3fdc3f3086
commit fcd1758b7d
No known key found for this signature in database
GPG key ID: AF404715446AEB41
13 changed files with 3332 additions and 1 deletions

16
go.mod
View file

@ -38,6 +38,8 @@ require (
github.com/Snider/Enchantrix v0.0.2 // indirect
github.com/TwiN/go-color v1.4.1 // indirect
github.com/adrg/xdg v0.5.3 // indirect
github.com/andybalholm/brotli v1.1.1 // indirect
github.com/apache/arrow-go/v18 v18.1.0 // indirect
github.com/aws/aws-sdk-go-v2 v1.41.1 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect
@ -71,9 +73,11 @@ require (
github.com/go-openapi/jsonpointer v0.22.4 // indirect
github.com/go-openapi/swag/jsonname v0.25.4 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/godbus/dbus/v5 v5.2.2 // indirect
github.com/gofrs/flock v0.12.1 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/google/flatbuffers v25.1.24+incompatible // indirect
github.com/google/go-github/v39 v39.2.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
@ -85,11 +89,13 @@ require (
github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/compress v1.18.3 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
github.com/leaanthony/u v1.1.1 // indirect
github.com/lmittmann/tint v1.1.2 // indirect
github.com/mailru/easyjson v0.9.1 // indirect
github.com/marcboeker/go-duckdb v1.8.5 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
@ -97,8 +103,12 @@ require (
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect
github.com/parquet-go/bitpack v1.0.0 // indirect
github.com/parquet-go/jsonlite v1.0.0 // indirect
github.com/parquet-go/parquet-go v0.27.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
@ -119,9 +129,9 @@ require (
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twpayne/go-geom v1.6.1 // indirect
github.com/ugorji/go/codec v1.3.0 // indirect
github.com/ulikunitz/xz v0.5.15 // indirect
github.com/unpoller/unifi/v5 v5.17.0 // indirect
github.com/wI2L/jsondiff v0.7.0 // indirect
github.com/wailsapp/go-webview2 v1.0.23 // indirect
github.com/wailsapp/wails/v3 v3.0.0-alpha.64 // indirect
@ -130,10 +140,14 @@ require (
github.com/xanzy/ssh-agent v0.3.3 // indirect
github.com/yargevad/filepathx v1.0.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/zeebo/xxh3 v1.1.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.40.0 // indirect
golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2 // indirect
golang.org/x/tools v0.41.0 // indirect
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect
google.golang.org/grpc v1.76.0 // indirect
google.golang.org/protobuf v1.36.10 // indirect

33
go.sum
View file

@ -18,12 +18,17 @@ github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBi
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
github.com/Snider/Borg v0.2.0 h1:iCyDhY4WTXi39+FexRwXbn2YpZ2U9FUXVXDZk9xRCXQ=
github.com/Snider/Borg v0.2.0/go.mod h1:TqlKnfRo9okioHbgrZPfWjQsztBV0Nfskz4Om1/vdMY=
github.com/Snider/Enchantrix v0.0.2/go.mod h1:CtFcLAvnDT1KcuF1JBb/DJj0KplY8jHryO06KzQ1hsQ=
github.com/TwiN/go-color v1.4.1 h1:mqG0P/KBgHKVqmtL5ye7K0/Gr4l6hTksPgTgMk3mUzc=
github.com/TwiN/go-color v1.4.1/go.mod h1:WcPf/jtiW95WBIsEeY1Lc/b8aaWoiqQpu5cf8WFxu+s=
github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78=
github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/apache/arrow-go/v18 v18.1.0 h1:agLwJUiVuwXZdwPYVrlITfx7bndULJ/dggbnLFgDp/Y=
github.com/apache/arrow-go/v18 v18.1.0/go.mod h1:tigU/sIgKNXaesf5d7Y95jBBKS5KsxTqYBKXFsvKzo0=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU=
@ -115,6 +120,8 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
@ -127,6 +134,8 @@ github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/flatbuffers v25.1.24+incompatible h1:4wPqL3K7GzBd1CwyhSd3usxLKOaJN/AC6puCca6Jm7o=
github.com/google/flatbuffers v25.1.24+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@ -153,6 +162,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kluctl/go-embed-python v0.0.0-3.13.1-20241219-1 h1:x1cSEj4Ug5mpuZgUHLvUmlc5r//KHFn6iYiRSrRcVy4=
@ -179,6 +190,8 @@ github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w=
github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE=
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/marcboeker/go-duckdb v1.8.5 h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0=
github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ=
github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
@ -206,10 +219,18 @@ github.com/ollama/ollama v0.15.4 h1:y841GH5lsi5j5BTFyX/E+UOC3Yiw+JBfdjBVRGw+I0M=
github.com/ollama/ollama v0.15.4/go.mod h1:4Yn3jw2hZ4VqyJ1XciYawDRE8bzv4RT3JiVZR1kCfwE=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA=
github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs=
github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU=
github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0=
github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g=
github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
@ -274,6 +295,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4=
github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028=
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
@ -292,10 +315,13 @@ github.com/woodsbury/decimal128 v1.4.0 h1:xJATj7lLu4f2oObouMt2tgGiElE5gO6mSWUjQs
github.com/woodsbury/decimal128 v1.4.0/go.mod h1:BP46FUrVjVhdTbKT+XuQh2xfQaGki9LMIRJSFuh6THU=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc=
github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs=
github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
@ -314,6 +340,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
@ -323,6 +350,7 @@ golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHi
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
@ -346,11 +374,14 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2 h1:O1cMQHRfwNpDfDJerqRoE2oD+AFlyid87D40L/OkkJo=
golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
@ -359,6 +390,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=

1070
pkg/ml/agent.go Normal file

File diff suppressed because it is too large Load diff

303
pkg/ml/convert.go Normal file
View file

@ -0,0 +1,303 @@
package ml
import (
"encoding/binary"
"encoding/json"
"fmt"
"log"
"math"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
)
var (
	loraARe  = regexp.MustCompile(`\.lora_a$`)
	loraBRe  = regexp.MustCompile(`\.lora_b$`)
	layerRe  = regexp.MustCompile(`layers\.(\d+)`)
	moduleRe = regexp.MustCompile(`model\.layers\.(\d+)` + `` + ``)
)

// RenameMLXKey converts an MLX tensor key to PEFT format.
// The ".lora_a"/".lora_b" suffixes become PEFT's ".lora_A.default.weight" /
// ".lora_B.default.weight", and the whole key is prefixed with
// "base_model.model.".
func RenameMLXKey(mlxKey string) string {
	renamed := loraBRe.ReplaceAllString(
		loraARe.ReplaceAllString(mlxKey, ".lora_A.default.weight"),
		".lora_B.default.weight",
	)
	return "base_model.model." + renamed
}
// SafetensorsHeader represents the header of a safetensors file.
//
// NOTE(review): Tensors is tagged `json:"-"` and is not populated by
// json.Unmarshal; ReadSafetensors in this file returns the tensor map as a
// separate value instead of filling this field.
type SafetensorsHeader struct {
	// Metadata holds the optional "__metadata__" key of the header.
	Metadata map[string]string `json:"__metadata__,omitempty"`
	// Tensors maps tensor name to its dtype/shape/offset descriptor.
	Tensors map[string]SafetensorsTensorInfo `json:"-"`
}
// SafetensorsTensorInfo describes a tensor's dtype, shape, and data location.
type SafetensorsTensorInfo struct {
Dtype string `json:"dtype"`
Shape []int `json:"shape"`
DataOffsets [2]int `json:"data_offsets"`
}
// ReadSafetensors reads a safetensors file and returns tensor info and raw data.
func ReadSafetensors(path string) (map[string]SafetensorsTensorInfo, []byte, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, nil, fmt.Errorf("read file: %w", err)
}
if len(data) < 8 {
return nil, nil, fmt.Errorf("file too small")
}
headerSize := int(binary.LittleEndian.Uint64(data[:8]))
if 8+headerSize > len(data) {
return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
}
headerJSON := data[8 : 8+headerSize]
tensorData := data[8+headerSize:]
var rawHeader map[string]json.RawMessage
if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
return nil, nil, fmt.Errorf("parse header: %w", err)
}
tensors := make(map[string]SafetensorsTensorInfo)
for key, raw := range rawHeader {
if key == "__metadata__" {
continue
}
var info SafetensorsTensorInfo
if err := json.Unmarshal(raw, &info); err != nil {
return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
}
tensors[key] = info
}
return tensors, tensorData, nil
}
// GetTensorData extracts raw bytes for a tensor from the data section.
// The returned slice aliases allData; it is not a copy.
func GetTensorData(info SafetensorsTensorInfo, allData []byte) []byte {
	start, end := info.DataOffsets[0], info.DataOffsets[1]
	return allData[start:end]
}
// transposeBytes transposes a row-major (rows, cols) matrix of fixed-size
// elements into (cols, rows) order. If the buffer length does not match
// rows*cols*elemSize, the input is returned unchanged (preserving the
// lenient behavior of the original per-dtype implementations).
func transposeBytes(data []byte, rows, cols, elemSize int) []byte {
	if len(data) != rows*cols*elemSize {
		return data
	}
	result := make([]byte, len(data))
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			srcOff := (r*cols + c) * elemSize
			dstOff := (c*rows + r) * elemSize
			copy(result[dstOff:dstOff+elemSize], data[srcOff:srcOff+elemSize])
		}
	}
	return result
}

// TransposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
func TransposeFloat32(data []byte, rows, cols int) []byte {
	return transposeBytes(data, rows, cols, 4)
}

// TransposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
func TransposeFloat16(data []byte, rows, cols int) []byte {
	return transposeBytes(data, rows, cols, 2)
}

// TransposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
// bfloat16 is also 2 bytes per element, so the float16 path applies directly.
func TransposeBFloat16(data []byte, rows, cols int) []byte {
	return TransposeFloat16(data, rows, cols)
}
// WriteSafetensors writes tensors to a safetensors file.
//
// Tensor data is laid out in sorted key order and each tensor's DataOffsets
// is recomputed to match, so callers need not pre-compute offsets. The file
// format is: 8-byte little-endian header length, JSON header, raw data.
// Unlike the original version, the file handle's Close error is checked and
// returned: on many filesystems buffered write failures only surface at
// close time, and a deferred Close would silently drop them.
func WriteSafetensors(path string, tensors map[string]SafetensorsTensorInfo, tensorData map[string][]byte) error {
	// Sort keys for a deterministic layout (and stable output bytes).
	keys := make([]string, 0, len(tensors))
	for k := range tensors {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Recompute contiguous data offsets in key order.
	offset := 0
	headerMap := make(map[string]interface{}, len(tensors))
	for _, k := range keys {
		info := tensors[k]
		data := tensorData[k]
		info.DataOffsets = [2]int{offset, offset + len(data)}
		headerMap[k] = info
		offset += len(data)
	}
	headerJSON, err := json.Marshal(headerMap)
	if err != nil {
		return fmt.Errorf("marshal header: %w", err)
	}
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	writeAll := func() error {
		headerSizeBuf := make([]byte, 8)
		binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))
		if _, err := f.Write(headerSizeBuf); err != nil {
			return err
		}
		if _, err := f.Write(headerJSON); err != nil {
			return err
		}
		for _, k := range keys {
			if _, err := f.Write(tensorData[k]); err != nil {
				return err
			}
		}
		return nil
	}
	if err := writeAll(); err != nil {
		f.Close() // best-effort cleanup; the write error is the one to report
		return fmt.Errorf("write %s: %w", path, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s: %w", path, err)
	}
	return nil
}
// ConvertMLXtoPEFT converts an MLX LoRA adapter to HuggingFace PEFT format.
//
// It reads the MLX safetensors adapter at safetensorsPath and the MLX
// training config at configPath, renames tensor keys to the PEFT layout,
// transposes every 2-D tensor (swapping its recorded shape), and writes
// adapter_model.safetensors plus adapter_config.json into outputDir.
// baseModelName is recorded verbatim in the generated adapter config.
func ConvertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return fmt.Errorf("create output dir: %w", err)
	}
	tensors, tensorData, err := ReadSafetensors(safetensorsPath)
	if err != nil {
		return fmt.Errorf("read safetensors: %w", err)
	}
	log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
	peftTensors := make(map[string]SafetensorsTensorInfo)
	peftData := make(map[string][]byte)
	for mlxKey, info := range tensors {
		peftKey := RenameMLXKey(mlxKey)
		data := GetTensorData(info, tensorData)
		// 2-D matrices are transposed and their shape swapped — presumably
		// because MLX stores LoRA factors transposed relative to PEFT's
		// expected layout (TODO confirm against MLX docs). Non-2-D tensors
		// pass through untouched.
		if len(info.Shape) == 2 {
			rows, cols := info.Shape[0], info.Shape[1]
			switch info.Dtype {
			case "F32":
				data = TransposeFloat32(data, rows, cols)
			case "F16":
				data = TransposeFloat16(data, rows, cols)
			case "BF16":
				data = TransposeBFloat16(data, rows, cols)
			}
			info.Shape = []int{cols, rows}
		}
		peftTensors[peftKey] = info
		peftData[peftKey] = data
	}
	outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
	if err := WriteSafetensors(outSafetensors, peftTensors, peftData); err != nil {
		return fmt.Errorf("write safetensors: %w", err)
	}
	cfgData, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	// Only the lora_parameters section of the MLX config is needed.
	var mlxConfig struct {
		LoraParameters struct {
			Rank    int     `json:"rank"`
			Scale   float64 `json:"scale"`
			Dropout float64 `json:"dropout"`
		} `json:"lora_parameters"`
	}
	if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}
	// Fall back to defaults when the config omits rank/scale (zero values).
	rank := mlxConfig.LoraParameters.Rank
	if rank == 0 {
		rank = 8
	}
	scale := mlxConfig.LoraParameters.Scale
	if scale == 0 {
		scale = 20.0
	}
	// Recover target modules (e.g. q_proj) and transformed layer indices
	// from the tensor key names.
	modules := make(map[string]bool)
	layers := make(map[int]bool)
	for k := range tensors {
		if m := moduleRe.FindStringSubmatch(k); m != nil {
			parts := strings.Split(m[1], ".")
			modules[parts[len(parts)-1]] = true
		}
		if m := layerRe.FindStringSubmatch(k); m != nil {
			n, _ := strconv.Atoi(m[1])
			layers[n] = true
		}
	}
	sortedModules := make([]string, 0, len(modules))
	for m := range modules {
		sortedModules = append(sortedModules, m)
	}
	sort.Strings(sortedModules)
	sortedLayers := make([]int, 0, len(layers))
	for l := range layers {
		sortedLayers = append(sortedLayers, l)
	}
	sort.Ints(sortedLayers)
	// NOTE(review): lora_alpha is reconstructed as scale*rank — MLX's scale
	// is presumably alpha/rank; confirm against MLX/PEFT documentation.
	peftConfig := map[string]interface{}{
		"auto_mapping":            nil,
		"base_model_name_or_path": baseModelName,
		"bias":                    "none",
		"fan_in_fan_out":          false,
		"inference_mode":          true,
		"init_lora_weights":       true,
		"layers_pattern":          nil,
		"layers_to_transform":     sortedLayers,
		"lora_alpha":              math.Round(scale * float64(rank)),
		"lora_dropout":            mlxConfig.LoraParameters.Dropout,
		"modules_to_save":         nil,
		"peft_type":               "LORA",
		"r":                       rank,
		"revision":                nil,
		"target_modules":          sortedModules,
		"task_type":               "CAUSAL_LM",
	}
	cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
	if err != nil {
		return fmt.Errorf("marshal peft config: %w", err)
	}
	if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
		return fmt.Errorf("write adapter_config.json: %w", err)
	}
	log.Printf("converted %d tensors, %d layers, target modules: %v",
		len(peftTensors), len(sortedLayers), sortedModules)
	return nil
}

241
pkg/ml/db.go Normal file
View file

@ -0,0 +1,241 @@
package ml
import (
"database/sql"
"fmt"
_ "github.com/marcboeker/go-duckdb"
)
// DB wraps a DuckDB connection.
type DB struct {
	conn *sql.DB // underlying database/sql handle (duckdb driver)
	path string  // filesystem path the database was opened from
}

// openDuckDB opens a DuckDB database with the given DSN and verifies the
// connection with a ping. path is used only for error messages. This helper
// removes the duplicated open/ping logic the two exported constructors
// previously carried.
func openDuckDB(path, dsn string) (*DB, error) {
	conn, err := sql.Open("duckdb", dsn)
	if err != nil {
		return nil, fmt.Errorf("open duckdb %s: %w", path, err)
	}
	if err := conn.Ping(); err != nil {
		conn.Close()
		return nil, fmt.Errorf("ping duckdb %s: %w", path, err)
	}
	return &DB{conn: conn, path: path}, nil
}

// OpenDB opens a DuckDB database file in read-only mode to avoid locking
// issues with the Python pipeline.
func OpenDB(path string) (*DB, error) {
	return openDuckDB(path, path+"?access_mode=READ_ONLY")
}

// OpenDBReadWrite opens a DuckDB database in read-write mode.
func OpenDBReadWrite(path string) (*DB, error) {
	return openDuckDB(path, path)
}

// Close closes the database connection.
func (db *DB) Close() error {
	return db.conn.Close()
}
// GoldenSetRow represents one row from the golden_set table.
// Field order mirrors the column list used by QueryGoldenSet.
type GoldenSetRow struct {
	Idx       int     // idx column: row index
	SeedID    string  // seed_id column
	Domain    string  // domain column
	Voice     string  // voice column
	Prompt    string  // prompt column
	Response  string  // response column
	GenTime   float64 // gen_time column — presumably generation seconds; confirm against writer
	CharCount int     // char_count column: response length used for filtering
}
// ExpansionPromptRow represents one row from the expansion_prompts table.
// Field order mirrors the column list used by QueryExpansionPrompts.
type ExpansionPromptRow struct {
	Idx      int64  // idx column: row identifier (used by UpdateExpansionStatus)
	SeedID   string // seed_id column
	Region   string // region column
	Domain   string // domain column
	Language string // language column
	Prompt   string // prompt column
	PromptEn string // prompt_en column — presumably the English translation; confirm
	Priority int    // priority column: primary sort key for processing order
	Status   string // status column, e.g. "pending" (see CountExpansionPrompts)
}
// QueryGoldenSet returns all golden set rows with responses >= minChars.
// Rows are returned in ascending idx order.
func (db *DB) QueryGoldenSet(minChars int) ([]GoldenSetRow, error) {
	const q = "SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count " +
		"FROM golden_set WHERE char_count >= ? ORDER BY idx"
	rows, err := db.conn.Query(q, minChars)
	if err != nil {
		return nil, fmt.Errorf("query golden_set: %w", err)
	}
	defer rows.Close()
	var out []GoldenSetRow
	for rows.Next() {
		var row GoldenSetRow
		err := rows.Scan(&row.Idx, &row.SeedID, &row.Domain, &row.Voice,
			&row.Prompt, &row.Response, &row.GenTime, &row.CharCount)
		if err != nil {
			return nil, fmt.Errorf("scan golden_set row: %w", err)
		}
		out = append(out, row)
	}
	return out, rows.Err()
}
// CountGoldenSet returns the total count of golden set rows.
func (db *DB) CountGoldenSet() (int, error) {
	var total int
	if err := db.conn.QueryRow("SELECT COUNT(*) FROM golden_set").Scan(&total); err != nil {
		return 0, fmt.Errorf("count golden_set: %w", err)
	}
	return total, nil
}
// QueryExpansionPrompts returns expansion prompts filtered by status.
// An empty status matches all rows; limit <= 0 means no limit. Results are
// ordered by priority, then idx.
func (db *DB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) {
	q := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " +
		"FROM expansion_prompts"
	var args []interface{}
	if status != "" {
		q += " WHERE status = ?"
		args = append(args, status)
	}
	q += " ORDER BY priority, idx"
	if limit > 0 {
		// limit is an int, so interpolation here cannot inject SQL.
		q += fmt.Sprintf(" LIMIT %d", limit)
	}
	rows, err := db.conn.Query(q, args...)
	if err != nil {
		return nil, fmt.Errorf("query expansion_prompts: %w", err)
	}
	defer rows.Close()
	var out []ExpansionPromptRow
	for rows.Next() {
		var p ExpansionPromptRow
		err := rows.Scan(&p.Idx, &p.SeedID, &p.Region, &p.Domain,
			&p.Language, &p.Prompt, &p.PromptEn, &p.Priority, &p.Status)
		if err != nil {
			return nil, fmt.Errorf("scan expansion_prompt row: %w", err)
		}
		out = append(out, p)
	}
	return out, rows.Err()
}
// CountExpansionPrompts returns counts by status: the total number of
// expansion prompts and how many still have status 'pending'.
func (db *DB) CountExpansionPrompts() (total int, pending int, err error) {
	if err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts").Scan(&total); err != nil {
		return 0, 0, fmt.Errorf("count expansion_prompts: %w", err)
	}
	if err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts WHERE status = 'pending'").Scan(&pending); err != nil {
		return total, 0, fmt.Errorf("count pending expansion_prompts: %w", err)
	}
	return total, pending, nil
}
// UpdateExpansionStatus updates the status of an expansion prompt by idx.
// Matching zero rows is not treated as an error.
func (db *DB) UpdateExpansionStatus(idx int64, status string) error {
	if _, err := db.conn.Exec("UPDATE expansion_prompts SET status = ? WHERE idx = ?", status, idx); err != nil {
		return fmt.Errorf("update expansion_prompt %d: %w", idx, err)
	}
	return nil
}
// QueryRows executes an arbitrary SQL query and returns results as maps.
// Each map is keyed by column name; values carry whatever dynamic type the
// driver produced for that cell.
func (db *DB) QueryRows(query string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := db.conn.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("query: %w", err)
	}
	defer rows.Close()
	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("columns: %w", err)
	}
	var out []map[string]interface{}
	for rows.Next() {
		// Scan into addressable interface{} cells via a parallel pointer slice.
		cells := make([]interface{}, len(cols))
		dests := make([]interface{}, len(cols))
		for i := range cells {
			dests[i] = &cells[i]
		}
		if err := rows.Scan(dests...); err != nil {
			return nil, fmt.Errorf("scan: %w", err)
		}
		record := make(map[string]interface{}, len(cols))
		for i, name := range cols {
			record[name] = cells[i]
		}
		out = append(out, record)
	}
	return out, rows.Err()
}
// EnsureScoringTables creates the scoring tables if they don't exist.
//
// Creation stays best-effort — every statement is attempted even if an
// earlier one fails — but unlike the original (which discarded Exec errors
// entirely), the first failure is now returned so callers can surface it
// instead of discovering missing tables on a later insert. Existing call
// sites that invoke this as a bare statement still compile.
func (db *DB) EnsureScoringTables() error {
	stmts := []string{
		`CREATE TABLE IF NOT EXISTS checkpoint_scores (
			model TEXT, run_id TEXT, label TEXT, iteration INTEGER,
			correct INTEGER, total INTEGER, accuracy DOUBLE,
			scored_at TIMESTAMP DEFAULT current_timestamp,
			PRIMARY KEY (run_id, label)
		)`,
		`CREATE TABLE IF NOT EXISTS probe_results (
			model TEXT, run_id TEXT, label TEXT, probe_id TEXT,
			passed BOOLEAN, response TEXT, iteration INTEGER,
			scored_at TIMESTAMP DEFAULT current_timestamp,
			PRIMARY KEY (run_id, label, probe_id)
		)`,
		`CREATE TABLE IF NOT EXISTS scoring_results (
			model TEXT, prompt_id TEXT, suite TEXT,
			dimension TEXT, score DOUBLE,
			scored_at TIMESTAMP DEFAULT current_timestamp
		)`,
	}
	var firstErr error
	for _, stmt := range stmts {
		if _, err := db.conn.Exec(stmt); err != nil && firstErr == nil {
			firstErr = fmt.Errorf("ensure scoring tables: %w", err)
		}
	}
	return firstErr
}
// WriteScoringResult writes a single scoring dimension result to DuckDB.
// The error is wrapped with context for consistency with every other method
// in this file (the original returned the bare driver error).
func (db *DB) WriteScoringResult(model, promptID, suite, dimension string, score float64) error {
	_, err := db.conn.Exec(
		`INSERT INTO scoring_results (model, prompt_id, suite, dimension, score) VALUES (?, ?, ?, ?, ?)`,
		model, promptID, suite, dimension, score,
	)
	if err != nil {
		return fmt.Errorf("insert scoring_result: %w", err)
	}
	return nil
}
// TableCounts returns row counts for all known tables.
// Tables that are missing (or otherwise fail to count) are silently omitted
// from the result rather than treated as errors.
func (db *DB) TableCounts() (map[string]int, error) {
	known := []string{"golden_set", "expansion_prompts", "seeds", "prompts",
		"training_examples", "gemini_responses", "benchmark_questions", "benchmark_results", "validations",
		"checkpoint_scores", "probe_results", "scoring_results"}
	result := make(map[string]int, len(known))
	for _, name := range known {
		var n int
		// name comes from the fixed list above, so Sprintf here is safe.
		if err := db.conn.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", name)).Scan(&n); err != nil {
			continue
		}
		result[name] = n
	}
	return result, nil
}

153
pkg/ml/expand.go Normal file
View file

@ -0,0 +1,153 @@
package ml
import (
"context"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"time"
)
// ExpandOutput is the JSONL output structure for expansion generation.
// One ExpandOutput line is appended per successfully generated prompt.
type ExpandOutput struct {
	ID             string  `json:"id"`               // prompt/seed identifier
	Domain         string  `json:"domain,omitempty"` // subject domain, when known
	Prompt         string  `json:"prompt"`           // input prompt text
	Response       string  `json:"response"`         // generated model output
	Model          string  `json:"model"`            // model used for generation
	ElapsedSeconds float64 `json:"elapsed_seconds"`  // wall-clock generation time
	Chars          int     `json:"chars"`            // len(Response) in bytes
}
// GetCompletedIDs queries InfluxDB for prompt IDs that have already been
// processed in the expansion_gen measurement. Rows with an empty seed_id
// are ignored.
func GetCompletedIDs(influx *InfluxClient) (map[string]bool, error) {
	rows, err := influx.QuerySQL("SELECT DISTINCT seed_id FROM expansion_gen")
	if err != nil {
		return nil, fmt.Errorf("query expansion_gen: %w", err)
	}
	done := make(map[string]bool, len(rows))
	for _, r := range rows {
		if id := strVal(r, "seed_id"); id != "" {
			done[id] = true
		}
	}
	return done, nil
}
// ExpandPrompts generates responses for expansion prompts using the given
// backend and reports progress to InfluxDB. Already-completed prompts (per
// InfluxDB) are skipped, as are individual prompts whose generation, JSON
// marshalling, or file write fails (logged, not fatal). InfluxDB reporting
// is best-effort. If ctx is cancelled, the loop stops between prompts and
// ctx.Err() is returned.
func ExpandPrompts(ctx context.Context, backend Backend, influx *InfluxClient, prompts []Response,
	modelName, worker, outputDir string, dryRun bool, limit int) error {
	remaining := prompts
	// Check InfluxDB for already-completed IDs; on failure process the full
	// set rather than aborting (duplicates are preferable to a dead run).
	completed, err := GetCompletedIDs(influx)
	if err != nil {
		log.Printf("warning: could not check completed IDs: %v", err)
	} else {
		remaining = nil
		for _, p := range prompts {
			if !completed[p.ID] {
				remaining = append(remaining, p)
			}
		}
		skipped := len(prompts) - len(remaining)
		if skipped > 0 {
			log.Printf("skipping %d already-completed prompts, %d remaining", skipped, len(remaining))
		}
	}
	if limit > 0 && limit < len(remaining) {
		remaining = remaining[:limit]
	}
	if len(remaining) == 0 {
		log.Println("all prompts already completed, nothing to do")
		return nil
	}
	if dryRun {
		log.Printf("dry-run: would process %d prompts with model %s (worker: %s)", len(remaining), modelName, worker)
		for i, p := range remaining {
			if i >= 10 {
				log.Printf(" ... and %d more", len(remaining)-10)
				break
			}
			log.Printf(" %s (domain: %s)", p.ID, p.Domain)
		}
		return nil
	}
	// Open in append mode so restarts resume the same per-worker file.
	outputPath := filepath.Join(outputDir, fmt.Sprintf("expand-%s.jsonl", worker))
	f, err := os.OpenFile(outputPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return fmt.Errorf("open output file: %w", err)
	}
	defer f.Close()
	total := len(remaining)
	completedCount := 0
	for idx, p := range remaining {
		// Honour cancellation between prompts so a long run can be stopped
		// cleanly instead of grinding through the whole remaining set.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
		start := time.Now()
		response, err := backend.Generate(ctx, p.Prompt, GenOpts{Temperature: 0.7, MaxTokens: 2048})
		elapsed := time.Since(start).Seconds()
		if err != nil {
			log.Printf("[%d/%d] id=%s ERROR: %v", idx+1, total, p.ID, err)
			continue
		}
		chars := len(response)
		completedCount++
		out := ExpandOutput{
			ID:             p.ID,
			Domain:         p.Domain,
			Prompt:         p.Prompt,
			Response:       response,
			Model:          modelName,
			ElapsedSeconds: elapsed,
			Chars:          chars,
		}
		line, err := json.Marshal(out)
		if err != nil {
			log.Printf("[%d/%d] id=%s marshal error: %v", idx+1, total, p.ID, err)
			continue
		}
		if _, err := f.Write(append(line, '\n')); err != nil {
			log.Printf("[%d/%d] id=%s write error: %v", idx+1, total, p.ID, err)
			continue
		}
		// Best-effort metrics: one point for this generation, one for
		// overall worker progress. Failures are logged, never fatal.
		genLine := fmt.Sprintf("expansion_gen,i=%d,w=%s,d=%s seed_id=\"%s\",gen_time=%f,chars=%di,model=\"%s\"",
			idx, EscapeLp(worker), EscapeLp(p.Domain),
			p.ID, elapsed, chars, modelName)
		pct := float64(completedCount) / float64(total) * 100.0
		progressLine := fmt.Sprintf("expansion_progress,worker=%s completed=%di,target=%di,pct=%f",
			EscapeLp(worker), completedCount, total, pct)
		if writeErr := influx.WriteLp([]string{genLine, progressLine}); writeErr != nil {
			log.Printf("[%d/%d] id=%s influx write error: %v", idx+1, total, p.ID, writeErr)
		}
		log.Printf("[%d/%d] id=%s chars=%d time=%.1fs", idx+1, total, p.ID, chars, elapsed)
	}
	log.Printf("expand complete: %d/%d prompts generated, output: %s", completedCount, total, outputPath)
	return nil
}

112
pkg/ml/export.go Normal file
View file

@ -0,0 +1,112 @@
package ml
import (
"bufio"
"encoding/json"
"fmt"
"math/rand"
"os"
"strings"
)
// ChatMessage is a single message in the chat training format.
type ChatMessage struct {
	Role    string `json:"role"`    // "user", "assistant", or "system"
	Content string `json:"content"` // message text
}

// TrainingExample is a single training example in chat JSONL format:
// one JSON object per line, each holding an ordered message list.
type TrainingExample struct {
	Messages []ChatMessage `json:"messages"`
}
// ValidatePercentages checks that train+valid+test percentages sum to 100
// and that none are negative.
func ValidatePercentages(trainPct, validPct, testPct int) error {
	switch sum := trainPct + validPct + testPct; {
	case trainPct < 0 || validPct < 0 || testPct < 0:
		return fmt.Errorf("percentages must be non-negative: train=%d, valid=%d, test=%d", trainPct, validPct, testPct)
	case sum != 100:
		return fmt.Errorf("percentages must sum to 100, got %d (train=%d + valid=%d + test=%d)", sum, trainPct, validPct, testPct)
	default:
		return nil
	}
}
// FilterResponses removes responses with empty content, an "ERROR:" prefix,
// or fewer than 50 bytes of response text.
func FilterResponses(responses []Response) []Response {
	usable := func(r Response) bool {
		if r.Response == "" || strings.HasPrefix(r.Response, "ERROR:") {
			return false
		}
		return len(r.Response) >= 50
	}
	var kept []Response
	for _, r := range responses {
		if usable(r) {
			kept = append(kept, r)
		}
	}
	return kept
}
// SplitData shuffles responses with a deterministic seed and splits them
// into train, valid, and test sets by the given percentages. The test set
// receives whatever remains after train and valid are taken, so testPct is
// implied by the other two.
func SplitData(responses []Response, trainPct, validPct, testPct int, seed int64) (train, valid, test []Response) {
	_ = testPct // test is the remainder; parameter kept for a symmetric signature
	shuffled := make([]Response, len(responses))
	copy(shuffled, responses)
	rand.New(rand.NewSource(seed)).Shuffle(len(shuffled), func(i, j int) {
		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
	})
	trainEnd := len(shuffled) * trainPct / 100
	validEnd := trainEnd + len(shuffled)*validPct/100
	return shuffled[:trainEnd], shuffled[trainEnd:validEnd], shuffled[validEnd:]
}
// WriteTrainingJSONL writes responses in chat JSONL format suitable for
// MLX LoRA fine-tuning. Each line is one TrainingExample holding a user
// message (the prompt) and an assistant message (the response).
func WriteTrainingJSONL(path string, responses []Response) error {
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	// Deferred close covers the error paths below; the success path closes
	// explicitly so a close failure is reported (double close is harmless).
	defer f.Close()
	w := bufio.NewWriter(f)
	for _, r := range responses {
		example := TrainingExample{
			Messages: []ChatMessage{
				{Role: "user", Content: r.Prompt},
				{Role: "assistant", Content: r.Response},
			},
		}
		data, err := json.Marshal(example)
		if err != nil {
			return fmt.Errorf("marshal example: %w", err)
		}
		if _, err := w.Write(data); err != nil {
			return fmt.Errorf("write line: %w", err)
		}
		if _, err := w.WriteString("\n"); err != nil {
			return fmt.Errorf("write newline: %w", err)
		}
	}
	// Flush explicitly: a deferred Flush would discard its error, which
	// could silently truncate the file (e.g. on a full disk).
	if err := w.Flush(); err != nil {
		return fmt.Errorf("flush %s: %w", path, err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("close %s: %w", path, err)
	}
	return nil
}

369
pkg/ml/gguf.go Normal file
View file

@ -0,0 +1,369 @@
package ml
import (
"encoding/binary"
"encoding/json"
"fmt"
"log"
"math"
"os"
"regexp"
"sort"
"strconv"
"strings"
)
// GGUF format constants.
const (
	ggufMagic     = 0x46554747 // "GGUF" little-endian
	ggufVersion   = 3          // GGUF v3 container format
	ggufAlignment = 32         // tensor-data alignment in bytes
)

// GGUF metadata value types (subset used by this writer).
const (
	ggufTypeUint32  = 4
	ggufTypeFloat32 = 6
	ggufTypeString  = 8
)

// GGML tensor data types (subset used by this writer).
const (
	ggmlTypeF32  = 0
	ggmlTypeF16  = 1
	ggmlTypeBF16 = 30
)
// ggufMetadata is a key-value pair in the GGUF header.
type ggufMetadata struct {
	key       string
	valueType uint32      // one of the ggufType* constants
	value     interface{} // string, uint32, or float32, matching valueType
}

// ggufTensor describes a tensor in the GGUF file.
type ggufTensor struct {
	name  string
	dims  []uint64
	dtype uint32 // ggmlType*
	data  []byte // raw tensor bytes, written verbatim to the data section
}
// gemma3ModuleMap maps HuggingFace module names to GGUF tensor names.
var gemma3ModuleMap = map[string]string{
	"self_attn.q_proj": "attn_q",
	"self_attn.k_proj": "attn_k",
	"self_attn.v_proj": "attn_v",
	"self_attn.o_proj": "attn_output",
	"mlp.gate_proj":    "ffn_gate",
	"mlp.up_proj":      "ffn_up",
	"mlp.down_proj":    "ffn_down",
}

// mlxLoraKeyRe matches MLX LoRA keys of the form
// "model.layers.<N>.<module path>.lora_a" / ".lora_b".
var mlxLoraKeyRe = regexp.MustCompile(`^model\.layers\.(\d+)\.(.*?)\.(lora_[ab])$`)
// MLXTensorToGGUF converts an MLX LoRA tensor name to GGUF LoRA tensor name.
// Input: "model.layers.0.self_attn.q_proj.lora_a"
// Output: "blk.0.attn_q.weight.lora_a"
func MLXTensorToGGUF(mlxName string) (string, error) {
	parts := mlxLoraKeyRe.FindStringSubmatch(mlxName)
	if parts == nil {
		return "", fmt.Errorf("unrecognised MLX LoRA key: %s", mlxName)
	}
	layer, module, suffix := parts[1], parts[2], parts[3]
	mapped, known := gemma3ModuleMap[module]
	if !known {
		return "", fmt.Errorf("unknown module %q in %s", module, mlxName)
	}
	return "blk." + layer + "." + mapped + ".weight." + suffix, nil
}
// SafetensorsDtypeToGGML maps safetensors dtype strings to GGML types.
// Only the float dtypes used by LoRA adapters are supported.
func SafetensorsDtypeToGGML(dtype string) (uint32, error) {
	types := map[string]uint32{
		"F32":  ggmlTypeF32,
		"F16":  ggmlTypeF16,
		"BF16": ggmlTypeBF16,
	}
	t, ok := types[dtype]
	if !ok {
		return 0, fmt.Errorf("unsupported dtype %q for GGUF", dtype)
	}
	return t, nil
}
// ConvertMLXtoGGUFLoRA converts an MLX LoRA adapter to GGUF LoRA format.
// It reads rank/scale from the MLX config, derives the GGUF lora alpha,
// renames and (for 2-D tensors) transposes every tensor, and writes a
// GGUF v3 adapter file. Returns an error on any unrecognised tensor name
// or dtype.
func ConvertMLXtoGGUFLoRA(safetensorsPath, configPath, outputPath, architecture string) error {
	cfgData, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}
	var mlxConfig struct {
		LoraParameters struct {
			Rank  int     `json:"rank"`
			Scale float64 `json:"scale"`
		} `json:"lora_parameters"`
	}
	if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}
	// Fallback defaults when the config omits the fields (zero values).
	rank := mlxConfig.LoraParameters.Rank
	if rank == 0 {
		rank = 8
	}
	scale := mlxConfig.LoraParameters.Scale
	if scale == 0 {
		scale = 20.0
	}
	// GGUF stores alpha rather than scale; alpha = scale * rank.
	loraAlpha := float32(math.Round(scale * float64(rank)))
	tensors, tensorData, err := ReadSafetensors(safetensorsPath)
	if err != nil {
		return fmt.Errorf("read safetensors: %w", err)
	}
	log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
	var ggufTensors []ggufTensor
	for mlxKey, info := range tensors {
		ggufName, err := MLXTensorToGGUF(mlxKey)
		if err != nil {
			return err
		}
		ggmlType, err := SafetensorsDtypeToGGML(info.Dtype)
		if err != nil {
			return fmt.Errorf("tensor %s: %w", mlxKey, err)
		}
		data := GetTensorData(info, tensorData)
		if len(info.Shape) == 2 {
			// 2-D LoRA matrices are transposed in a dtype-specific way.
			// NOTE(review): presumably MLX stores these row-major in the
			// opposite orientation to what GGUF consumers expect — confirm
			// against the adapter loader.
			rows, cols := info.Shape[0], info.Shape[1]
			switch info.Dtype {
			case "F32":
				data = TransposeFloat32(data, rows, cols)
			case "F16":
				data = TransposeFloat16(data, rows, cols)
			case "BF16":
				data = TransposeBFloat16(data, rows, cols)
			}
			ggufTensors = append(ggufTensors, ggufTensor{
				name:  ggufName,
				dims:  []uint64{uint64(rows), uint64(cols)},
				dtype: ggmlType,
				data:  data,
			})
		} else {
			// Non-2-D tensors pass through with their shape unchanged.
			dims := make([]uint64, len(info.Shape))
			for i, s := range info.Shape {
				dims[i] = uint64(s)
			}
			ggufTensors = append(ggufTensors, ggufTensor{
				name:  ggufName,
				dims:  dims,
				dtype: ggmlType,
				data:  data,
			})
		}
	}
	// Deterministic tensor order: map iteration above is random.
	sort.Slice(ggufTensors, func(i, j int) bool {
		return ggufTensors[i].name < ggufTensors[j].name
	})
	metadata := []ggufMetadata{
		{key: "general.type", valueType: ggufTypeString, value: "adapter"},
		{key: "general.architecture", valueType: ggufTypeString, value: architecture},
		{key: "adapter.type", valueType: ggufTypeString, value: "lora"},
		{key: "adapter.lora.alpha", valueType: ggufTypeFloat32, value: loraAlpha},
	}
	if err := writeGGUF(outputPath, metadata, ggufTensors); err != nil {
		return fmt.Errorf("write GGUF: %w", err)
	}
	log.Printf("wrote GGUF LoRA: %s (%d tensors, alpha=%.0f)", outputPath, len(ggufTensors), loraAlpha)
	return nil
}
// writeGGUF writes a GGUF v3 file: header, metadata KV pairs, tensor infos,
// alignment padding, then tensor data with each tensor padded to
// ggufAlignment (matching the offsets recorded in the info section).
func writeGGUF(path string, metadata []ggufMetadata, tensors []ggufTensor) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	w := &ggufWriter{f: f}
	// Header: magic, version, tensor count, metadata count.
	w.writeUint32(ggufMagic)
	w.writeUint32(ggufVersion)
	w.writeUint64(uint64(len(tensors)))
	w.writeUint64(uint64(len(metadata)))
	// Metadata KV section.
	for _, kv := range metadata {
		w.writeString(kv.key)
		w.writeUint32(kv.valueType)
		switch kv.valueType {
		case ggufTypeString:
			w.writeString(kv.value.(string))
		case ggufTypeUint32:
			w.writeUint32(kv.value.(uint32))
		case ggufTypeFloat32:
			w.writeFloat32(kv.value.(float32))
		}
	}
	// Tensor info section. Offsets are relative to the start of the data
	// section, with each tensor's data aligned to ggufAlignment.
	dataOffset := uint64(0)
	for _, t := range tensors {
		w.writeString(t.name)
		w.writeUint32(uint32(len(t.dims)))
		for _, d := range t.dims {
			w.writeUint64(d)
		}
		w.writeUint32(t.dtype)
		w.writeUint64(dataOffset)
		dataOffset += uint64(len(t.data))
		if rem := dataOffset % ggufAlignment; rem != 0 {
			dataOffset += ggufAlignment - rem
		}
	}
	// Pad so the data section itself starts on an alignment boundary.
	if rem := w.pos % ggufAlignment; rem != 0 {
		w.writeBytes(make([]byte, ggufAlignment-rem))
	}
	// Tensor data, padded exactly as the recorded offsets assumed.
	for _, t := range tensors {
		w.writeBytes(t.data)
		if rem := uint64(len(t.data)) % ggufAlignment; rem != 0 {
			w.writeBytes(make([]byte, ggufAlignment-rem))
		}
	}
	if w.err != nil {
		f.Close() // best-effort; the write error takes precedence
		return w.err
	}
	// Close explicitly so a close failure is reported instead of being
	// discarded by a deferred Close.
	return f.Close()
}
// ggufWriter tracks the current file position and accumulates the first
// write error; once err is set, all further writes become no-ops.
type ggufWriter struct {
	f   *os.File
	pos uint64 // bytes written so far, used for alignment padding
	err error  // first write error, sticky
}
// writeBytes appends b to the file and advances pos; it is a no-op once an
// earlier write has failed.
func (w *ggufWriter) writeBytes(b []byte) {
	if w.err != nil {
		return
	}
	var n int
	n, w.err = w.f.Write(b)
	w.pos += uint64(n)
}
// writeUint32 writes v as 4 little-endian bytes.
func (w *ggufWriter) writeUint32(v uint32) {
	var buf [4]byte
	binary.LittleEndian.PutUint32(buf[:], v)
	w.writeBytes(buf[:])
}

// writeUint64 writes v as 8 little-endian bytes.
func (w *ggufWriter) writeUint64(v uint64) {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], v)
	w.writeBytes(buf[:])
}

// writeFloat32 writes v as a little-endian IEEE-754 bit pattern.
func (w *ggufWriter) writeFloat32(v float32) {
	var buf [4]byte
	binary.LittleEndian.PutUint32(buf[:], math.Float32bits(v))
	w.writeBytes(buf[:])
}

// writeString writes a GGUF string: a uint64 byte length followed by the
// raw, unterminated bytes.
func (w *ggufWriter) writeString(s string) {
	w.writeUint64(uint64(len(s)))
	w.writeBytes([]byte(s))
}
// DetectArchFromConfig tries to infer the model architecture from
// adapter_config.json. Only gemma3 is currently recognised, so this always
// returns "gemma3". The original implementation parsed the config into a
// struct it never read (with the Unmarshal error ignored); that dead parse
// is removed, leaving a cheap existence probe as the seam where real
// detection will live.
func DetectArchFromConfig(configPath string) string {
	if _, err := os.ReadFile(configPath); err != nil {
		// Missing or unreadable config: fall back to the default.
		return "gemma3"
	}
	// TODO: inspect the config (e.g. base-model name) once more than one
	// architecture is supported.
	return "gemma3"
}
// ArchitectureGGUFMap maps model tags to GGUF architecture names.
var ArchitectureGGUFMap = map[string]string{
	"gemma-3-1b":  "gemma3",
	"gemma-3-4b":  "gemma3",
	"gemma-3-12b": "gemma3",
	"gemma-3-27b": "gemma3",
}

// ModelTagToGGUFArch returns the GGUF architecture for a model tag,
// defaulting to "gemma3" for unknown tags.
func ModelTagToGGUFArch(modelTag string) string {
	arch, known := ArchitectureGGUFMap[modelTag]
	if !known {
		return "gemma3"
	}
	return arch
}
// GGUFModelBlobPath returns the path to the GGUF model blob in Ollama's
// store by reading the model's manifest (tag defaults to "latest") and
// locating the layer with the model media type.
func GGUFModelBlobPath(ollamaModelsDir, modelName string) (string, error) {
	parts := strings.SplitN(modelName, ":", 2)
	family, tag := parts[0], "latest"
	if len(parts) == 2 {
		tag = parts[1]
	}
	manifestPath := fmt.Sprintf("%s/manifests/registry.ollama.ai/library/%s/%s", ollamaModelsDir, family, tag)
	data, err := os.ReadFile(manifestPath)
	if err != nil {
		return "", fmt.Errorf("read manifest %s: %w", manifestPath, err)
	}
	var manifest struct {
		Layers []struct {
			MediaType string `json:"mediaType"`
			Digest    string `json:"digest"`
		} `json:"layers"`
	}
	if err := json.Unmarshal(data, &manifest); err != nil {
		return "", fmt.Errorf("parse manifest: %w", err)
	}
	for _, layer := range manifest.Layers {
		if layer.MediaType != "application/vnd.ollama.image.model" {
			continue
		}
		// Blob filenames use "sha256-<hex>" rather than "sha256:<hex>".
		blob := strings.Replace(layer.Digest, ":", "-", 1)
		return fmt.Sprintf("%s/blobs/%s", ollamaModelsDir, blob), nil
	}
	return "", fmt.Errorf("no model layer found in manifest for %s", modelName)
}
// blkLayerRe matches the "blk.<N>." fragment of GGUF tensor names. Compiled
// once at package scope so ParseLayerFromTensorName does not recompile the
// pattern on every call.
var blkLayerRe = regexp.MustCompile(`blk\.(\d+)\.`)

// ParseLayerFromTensorName extracts the layer number from a GGUF tensor
// name such as "blk.12.attn_q.weight.lora_a". Returns an error when the
// name carries no "blk.<N>." fragment.
func ParseLayerFromTensorName(name string) (int, error) {
	m := blkLayerRe.FindStringSubmatch(name)
	if m == nil {
		return 0, fmt.Errorf("no layer number in %s", name)
	}
	return strconv.Atoi(m[1])
}

132
pkg/ml/influx.go Normal file
View file

@ -0,0 +1,132 @@
package ml
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
)
// InfluxClient talks to an InfluxDB v3 instance over its HTTP API
// (write_lp and query_sql endpoints). Construct with NewInfluxClient.
type InfluxClient struct {
	url   string // base URL, no trailing slash
	db    string // target database name
	token string // bearer token; may be empty if none was found
}
// NewInfluxClient creates an InfluxClient for the given URL and database.
// Reads the token from the INFLUX_TOKEN env var first, then the
// ~/.influx_token file. If url is empty, defaults to
// "http://10.69.69.165:8181". If db is empty, defaults to "training".
func NewInfluxClient(url, db string) *InfluxClient {
	c := &InfluxClient{url: url, db: db, token: os.Getenv("INFLUX_TOKEN")}
	if c.url == "" {
		c.url = "http://10.69.69.165:8181"
	}
	if c.db == "" {
		c.db = "training"
	}
	if c.token == "" {
		// Fall back to the token file; all failures leave token empty.
		if home, err := os.UserHomeDir(); err == nil {
			if data, err := os.ReadFile(filepath.Join(home, ".influx_token")); err == nil {
				c.token = strings.TrimSpace(string(data))
			}
		}
	}
	return c
}
// WriteLp writes line protocol data to InfluxDB's /api/v3/write_lp
// endpoint, joining lines with newlines. 200 and 204 both count as success.
func (c *InfluxClient) WriteLp(lines []string) error {
	endpoint := fmt.Sprintf("%s/api/v3/write_lp?db=%s", c.url, c.db)
	req, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(strings.Join(lines, "\n")))
	if err != nil {
		return fmt.Errorf("create write request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.token)
	req.Header.Set("Content-Type", "text/plain")
	resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req)
	if err != nil {
		return fmt.Errorf("write request: %w", err)
	}
	defer resp.Body.Close()
	switch resp.StatusCode {
	case http.StatusOK, http.StatusNoContent:
		return nil
	}
	respBody, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("write failed %d: %s", resp.StatusCode, string(respBody))
}
// QuerySQL runs a SQL query against InfluxDB's /api/v3/query_sql endpoint
// and returns the result rows as generic maps (JSON numbers arrive as
// float64).
func (c *InfluxClient) QuerySQL(sql string) ([]map[string]interface{}, error) {
	payload, err := json.Marshal(map[string]string{"db": c.db, "q": sql})
	if err != nil {
		return nil, fmt.Errorf("marshal query request: %w", err)
	}
	req, err := http.NewRequest(http.MethodPost, c.url+"/api/v3/query_sql", bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("create query request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.token)
	req.Header.Set("Content-Type", "application/json")
	resp, err := (&http.Client{Timeout: 10 * time.Second}).Do(req)
	if err != nil {
		return nil, fmt.Errorf("query request: %w", err)
	}
	defer resp.Body.Close()
	// Read the body before the status check so error responses can be
	// echoed back to the caller.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read query response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("query failed %d: %s", resp.StatusCode, string(respBody))
	}
	var rows []map[string]interface{}
	if err := json.Unmarshal(respBody, &rows); err != nil {
		return nil, fmt.Errorf("unmarshal query response: %w", err)
	}
	return rows, nil
}
// lpEscaper rewrites the characters that would terminate a line-protocol
// tag value. Replacements introduce only backslashes, so a single pass is
// equivalent to sequential ReplaceAll calls.
var lpEscaper = strings.NewReplacer(`,`, `\,`, `=`, `\=`, ` `, `\ `)

// EscapeLp escapes spaces, commas, and equals signs for InfluxDB line
// protocol tag values.
func EscapeLp(s string) string {
	return lpEscaper.Replace(s)
}

152
pkg/ml/ollama.go Normal file
View file

@ -0,0 +1,152 @@
package ml
import (
"bytes"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"time"
)
// OllamaBaseModelMap maps model tags to Ollama model names.
var OllamaBaseModelMap = map[string]string{
	"gemma-3-1b":  "gemma3:1b",
	"gemma-3-4b":  "gemma3:4b",
	"gemma-3-12b": "gemma3:12b",
	"gemma-3-27b": "gemma3:27b",
}

// HFBaseModelMap maps model tags to HuggingFace model IDs.
var HFBaseModelMap = map[string]string{
	"gemma-3-1b":  "google/gemma-3-1b-it",
	"gemma-3-4b":  "google/gemma-3-4b-it",
	"gemma-3-12b": "google/gemma-3-12b-it",
	"gemma-3-27b": "google/gemma-3-27b-it",
}
// ollamaUploadBlob uploads a local file to Ollama's blob store, skipping
// the upload when the server already has a blob with the same digest.
// Returns the sha256 digest string (e.g. "sha256:abc123...").
func ollamaUploadBlob(ollamaURL, filePath string) (string, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return "", fmt.Errorf("read %s: %w", filePath, err)
	}
	sum := sha256.Sum256(data)
	digest := "sha256:" + hex.EncodeToString(sum[:])
	blobURL := ollamaURL + "/api/blobs/" + digest
	client := &http.Client{Timeout: 5 * time.Minute}
	// HEAD first: if the blob already exists we can skip the POST.
	headReq, _ := http.NewRequest(http.MethodHead, blobURL, nil)
	if headResp, headErr := client.Do(headReq); headErr == nil {
		exists := headResp.StatusCode == http.StatusOK
		headResp.Body.Close()
		if exists {
			return digest, nil
		}
	}
	req, err := http.NewRequest(http.MethodPost, blobURL, bytes.NewReader(data))
	if err != nil {
		return "", fmt.Errorf("blob request: %w", err)
	}
	req.Header.Set("Content-Type", "application/octet-stream")
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("blob upload: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("blob upload HTTP %d: %s", resp.StatusCode, string(body))
	}
	return digest, nil
}
// OllamaCreateModel creates a temporary Ollama model with a LoRA adapter.
// peftDir is a local directory containing adapter_model.safetensors and
// adapter_config.json; both are uploaded to the blob store first and then
// referenced by digest in the create request. The create endpoint streams
// JSON status objects; an "error" field or decode failure aborts, a
// "success" status (or clean end of stream) means the model exists.
func OllamaCreateModel(ollamaURL, modelName, baseModel, peftDir string) error {
	sfDigest, err := ollamaUploadBlob(ollamaURL, peftDir+"/adapter_model.safetensors")
	if err != nil {
		return fmt.Errorf("upload adapter safetensors: %w", err)
	}
	cfgDigest, err := ollamaUploadBlob(ollamaURL, peftDir+"/adapter_config.json")
	if err != nil {
		return fmt.Errorf("upload adapter config: %w", err)
	}
	reqBody, _ := json.Marshal(map[string]interface{}{
		"model": modelName,
		"from":  baseModel,
		"adapters": map[string]string{
			"adapter_model.safetensors": sfDigest,
			"adapter_config.json":       cfgDigest,
		},
	})
	client := &http.Client{Timeout: 10 * time.Minute}
	resp, err := client.Post(ollamaURL+"/api/create", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("ollama create: %w", err)
	}
	defer resp.Body.Close()
	// Check the HTTP status before decoding: on a non-200 the body may not
	// be the JSON status stream at all, and the original post-loop check
	// let a decode error mask the real HTTP failure.
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("ollama create: HTTP %d: %s", resp.StatusCode, string(body))
	}
	decoder := json.NewDecoder(resp.Body)
	for decoder.More() {
		var status struct {
			Status string `json:"status"`
			Error  string `json:"error"`
		}
		if err := decoder.Decode(&status); err != nil {
			if err == io.EOF {
				break
			}
			return fmt.Errorf("ollama create decode: %w", err)
		}
		if status.Error != "" {
			return fmt.Errorf("ollama create: %s", status.Error)
		}
		if status.Status == "success" {
			return nil
		}
	}
	return nil
}
// OllamaDeleteModel removes a temporary Ollama model via the delete API.
func OllamaDeleteModel(ollamaURL, modelName string) error {
	payload, _ := json.Marshal(map[string]string{"model": modelName})
	req, err := http.NewRequest(http.MethodDelete, ollamaURL+"/api/delete", bytes.NewReader(payload))
	if err != nil {
		return fmt.Errorf("ollama delete request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := (&http.Client{Timeout: 30 * time.Second}).Do(req)
	if err != nil {
		return fmt.Errorf("ollama delete: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusOK {
		return nil
	}
	respBody, _ := io.ReadAll(resp.Body)
	return fmt.Errorf("ollama delete %d: %s", resp.StatusCode, string(respBody))
}

137
pkg/ml/parquet.go Normal file
View file

@ -0,0 +1,137 @@
package ml
import (
"bufio"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/parquet-go/parquet-go"
)
// ParquetRow is the schema for exported Parquet files. Prompt/Response/
// System hold the first message of each role; Messages is the full message
// list re-serialised as a JSON string.
type ParquetRow struct {
	Prompt   string `parquet:"prompt"`
	Response string `parquet:"response"`
	System   string `parquet:"system"`
	Messages string `parquet:"messages"`
}
// ExportParquet reads JSONL training splits (train.jsonl, valid.jsonl,
// test.jsonl) from trainingDir and writes Parquet files with snappy
// compression to outputDir (default: trainingDir/parquet). Splits whose
// JSONL file is absent are skipped. Returns total rows exported.
func ExportParquet(trainingDir, outputDir string) (int, error) {
	if outputDir == "" {
		outputDir = filepath.Join(trainingDir, "parquet")
	}
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return 0, fmt.Errorf("create output dir: %w", err)
	}
	total := 0
	for _, split := range []string{"train", "valid", "test"} {
		src := filepath.Join(trainingDir, split+".jsonl")
		if _, err := os.Stat(src); os.IsNotExist(err) {
			continue // split not present; nothing to export
		}
		n, err := ExportSplitParquet(src, outputDir, split)
		if err != nil {
			return total, fmt.Errorf("export %s: %w", split, err)
		}
		total += n
	}
	return total, nil
}
// ExportSplitParquet reads a chat JSONL file and writes a Parquet file for
// the given split name. Blank and unparseable lines are skipped. If no
// usable rows are found, no file is written and (0, nil) is returned.
// Returns the number of rows written.
func ExportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
	f, err := os.Open(jsonlPath)
	if err != nil {
		return 0, fmt.Errorf("open %s: %w", jsonlPath, err)
	}
	defer f.Close()
	var rows []ParquetRow
	scanner := bufio.NewScanner(f)
	// Allow lines up to 1 MiB; default scanner buffer is too small for
	// long training examples.
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
	for scanner.Scan() {
		text := strings.TrimSpace(scanner.Text())
		if text == "" {
			continue
		}
		var data struct {
			Messages []ChatMessage `json:"messages"`
		}
		if err := json.Unmarshal([]byte(text), &data); err != nil {
			continue // malformed line: skip rather than abort the export
		}
		// Keep the first message of each role for the flat columns.
		var prompt, response, system string
		for _, m := range data.Messages {
			switch m.Role {
			case "user":
				if prompt == "" {
					prompt = m.Content
				}
			case "assistant":
				if response == "" {
					response = m.Content
				}
			case "system":
				if system == "" {
					system = m.Content
				}
			}
		}
		msgsJSON, _ := json.Marshal(data.Messages)
		rows = append(rows, ParquetRow{
			Prompt:   prompt,
			Response: response,
			System:   system,
			Messages: string(msgsJSON),
		})
	}
	if err := scanner.Err(); err != nil {
		return 0, fmt.Errorf("scan %s: %w", jsonlPath, err)
	}
	if len(rows) == 0 {
		return 0, nil
	}
	outPath := filepath.Join(outputDir, split+".parquet")
	out, err := os.Create(outPath)
	if err != nil {
		return 0, fmt.Errorf("create %s: %w", outPath, err)
	}
	writer := parquet.NewGenericWriter[ParquetRow](out,
		parquet.Compression(&parquet.Snappy),
	)
	if _, err := writer.Write(rows); err != nil {
		out.Close()
		return 0, fmt.Errorf("write parquet rows: %w", err)
	}
	// Close the parquet writer first (flushes the footer), then the file.
	if err := writer.Close(); err != nil {
		out.Close()
		return 0, fmt.Errorf("close parquet writer: %w", err)
	}
	if err := out.Close(); err != nil {
		return 0, fmt.Errorf("close file: %w", err)
	}
	return len(rows), nil
}

212
pkg/ml/status.go Normal file
View file

@ -0,0 +1,212 @@
package ml
import (
"fmt"
"io"
"sort"
)
// trainingRow holds deduplicated training status + loss for a single model.
type trainingRow struct {
	model      string
	status     string
	iteration  int
	totalIters int
	pct        float64 // completion percentage as reported by training_status
	loss       float64 // latest train loss; only meaningful when hasLoss
	hasLoss    bool    // true when a loss row existed for this model
}

// genRow holds deduplicated generation progress for a single worker.
type genRow struct {
	worker    string
	completed int
	target    int
	pct       float64 // completion percentage as reported by the worker
}
// PrintStatus queries InfluxDB for training and generation progress and
// writes a formatted summary to w. Query failures are treated as "no data"
// rather than errors; the function always returns nil.
func PrintStatus(influx *InfluxClient, w io.Writer) error {
	// Best-effort query: a failure simply yields no rows for that section.
	query := func(sql string) []map[string]interface{} {
		rows, err := influx.QuerySQL(sql)
		if err != nil {
			return nil
		}
		return rows
	}
	statusRows := query("SELECT model, run_id, status, iteration, total_iters, pct FROM training_status ORDER BY time DESC LIMIT 10")
	lossRows := query("SELECT model, loss_type, loss, iteration, tokens_per_sec FROM training_loss WHERE loss_type = 'train' ORDER BY time DESC LIMIT 10")
	goldenRows := query("SELECT worker, completed, target, pct FROM golden_gen_progress ORDER BY time DESC LIMIT 5")
	expansionRows := query("SELECT worker, completed, target, pct FROM expansion_progress ORDER BY time DESC LIMIT 5")

	fmt.Fprintln(w, "Training:")
	training := dedupeTraining(statusRows, lossRows)
	if len(training) == 0 {
		fmt.Fprintln(w, " (no data)")
	}
	for _, tr := range training {
		progress := fmt.Sprintf("%d/%d", tr.iteration, tr.totalIters)
		pct := fmt.Sprintf("%.1f%%", tr.pct)
		if tr.hasLoss {
			fmt.Fprintf(w, " %-13s %-9s %9s %7s loss=%.3f\n",
				tr.model, tr.status, progress, pct, tr.loss)
		} else {
			fmt.Fprintf(w, " %-13s %-9s %9s %7s\n",
				tr.model, tr.status, progress, pct)
		}
	}

	fmt.Fprintln(w)
	fmt.Fprintln(w, "Generation:")
	printGen := func(label string, rows []genRow) bool {
		for _, g := range rows {
			progress := fmt.Sprintf("%d/%d", g.completed, g.target)
			pct := fmt.Sprintf("%.1f%%", g.pct)
			fmt.Fprintf(w, " %-13s %11s %7s (%s)\n", label, progress, pct, g.worker)
		}
		return len(rows) > 0
	}
	goldenShown := printGen("golden", dedupeGeneration(goldenRows))
	expansionShown := printGen("expansion", dedupeGeneration(expansionRows))
	if !goldenShown && !expansionShown {
		fmt.Fprintln(w, " (no data)")
	}
	return nil
}
// dedupeTraining merges training status and loss rows, keeping only the
// first (latest) row per model, and returns the result sorted by model.
func dedupeTraining(statusRows, lossRows []map[string]interface{}) []trainingRow {
	// Latest train loss per model: rows are newest-first, so keep the
	// first occurrence only.
	latestLoss := make(map[string]float64)
	for _, row := range lossRows {
		model := strVal(row, "model")
		if model == "" {
			continue
		}
		if _, have := latestLoss[model]; !have {
			latestLoss[model] = floatVal(row, "loss")
		}
	}
	seen := make(map[string]bool)
	var out []trainingRow
	for _, row := range statusRows {
		model := strVal(row, "model")
		if model == "" || seen[model] {
			continue
		}
		seen[model] = true
		tr := trainingRow{
			model:      model,
			status:     strVal(row, "status"),
			iteration:  intVal(row, "iteration"),
			totalIters: intVal(row, "total_iters"),
			pct:        floatVal(row, "pct"),
		}
		if loss, ok := latestLoss[model]; ok {
			tr.loss, tr.hasLoss = loss, true
		}
		out = append(out, tr)
	}
	sort.Slice(out, func(i, j int) bool { return out[i].model < out[j].model })
	return out
}
// dedupeGeneration deduplicates generation progress rows by worker, keeping
// the first (latest) row per worker, sorted by worker name.
func dedupeGeneration(rows []map[string]interface{}) []genRow {
	taken := make(map[string]bool)
	var out []genRow
	for _, row := range rows {
		name := strVal(row, "worker")
		if name == "" || taken[name] {
			continue
		}
		taken[name] = true
		out = append(out, genRow{
			worker:    name,
			completed: intVal(row, "completed"),
			target:    intVal(row, "target"),
			pct:       floatVal(row, "pct"),
		})
	}
	sort.Slice(out, func(i, j int) bool { return out[i].worker < out[j].worker })
	return out
}
// strVal extracts a string value from a row map, returning "" when the key
// is absent or the value is not a string.
func strVal(row map[string]interface{}, key string) string {
	if s, ok := row[key].(string); ok {
		return s
	}
	return ""
}
// floatVal extracts a float64 value from a row map, returning 0 when the
// key is absent or the value is not a float64.
func floatVal(row map[string]interface{}, key string) float64 {
	if f, ok := row[key].(float64); ok {
		return f
	}
	return 0
}
// intVal extracts an integer value from a row map. InfluxDB JSON returns all
// numbers as float64, so this truncates to int. Absent or non-numeric
// values yield 0 (via floatVal).
func intVal(row map[string]interface{}, key string) int {
	return int(floatVal(row, key))
}

403
pkg/ml/worker.go Normal file
View file

@ -0,0 +1,403 @@
package ml
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"time"
)
// WorkerConfig holds the worker's runtime configuration.
type WorkerConfig struct {
	APIBase      string        // base URL of the LEM API
	WorkerID     string        // unique identifier reported at registration
	Name         string        // human-readable worker name
	APIKey       string        // API credential; presumably consumed by apiGet/apiPost (not visible here)
	GPUType      string        // optional GPU model, reported when non-empty
	VRAMGb       int           // optional GPU memory in GB, reported when > 0
	Languages    []string      // languages this worker handles, reported when non-empty
	Models       []string      // supported model names, reported when non-empty
	InferURL     string        // inference backend endpoint
	TaskType     string        // optional task-type filter for polling
	BatchSize    int           // max tasks requested per poll
	PollInterval time.Duration // sleep between polls that return no tasks
	OneShot      bool          // process a single poll cycle, then exit
	DryRun       bool          // logged at startup; handling not visible in this chunk
}
// APITask represents a task from the LEM API.
type APITask struct {
	ID         int    `json:"id"`
	TaskType   string `json:"task_type"`
	Status     string `json:"status"`
	Language   string `json:"language"`
	Domain     string `json:"domain"`
	ModelName  string `json:"model_name"`
	PromptID   string `json:"prompt_id"`
	PromptText string `json:"prompt_text"`
	// Config carries optional generation overrides; nil when the task
	// supplies none.
	Config *struct {
		Temperature float64 `json:"temperature,omitempty"`
		MaxTokens   int     `json:"max_tokens,omitempty"`
	} `json:"config"`
	Priority int `json:"priority"`
}
// RunWorkerLoop is the main worker loop that polls for tasks and processes
// them. It registers with the LEM API first (fatal on failure), then polls
// repeatedly, sleeping cfg.PollInterval whenever a poll returns no tasks
// and sending a heartbeat after every cycle. In one-shot mode it returns
// after a single poll; otherwise it never returns.
func RunWorkerLoop(cfg *WorkerConfig) {
	log.Printf("LEM Worker starting")
	log.Printf(" ID: %s", cfg.WorkerID)
	log.Printf(" Name: %s", cfg.Name)
	log.Printf(" API: %s", cfg.APIBase)
	log.Printf(" Infer: %s", cfg.InferURL)
	log.Printf(" GPU: %s (%d GB)", cfg.GPUType, cfg.VRAMGb)
	log.Printf(" Langs: %v", cfg.Languages)
	log.Printf(" Models: %v", cfg.Models)
	log.Printf(" Batch: %d", cfg.BatchSize)
	log.Printf(" Dry-run: %v", cfg.DryRun)
	// Registration is mandatory: without it the API will not hand out tasks.
	if err := workerRegister(cfg); err != nil {
		log.Fatalf("Registration failed: %v", err)
	}
	log.Println("Registered with LEM API")
	for {
		processed := workerPoll(cfg)
		if cfg.OneShot {
			log.Printf("One-shot mode: processed %d tasks, exiting", processed)
			return
		}
		if processed == 0 {
			log.Printf("No tasks available, sleeping %v", cfg.PollInterval)
			time.Sleep(cfg.PollInterval)
		}
		workerHeartbeat(cfg)
	}
}
// workerRegister announces this worker to the LEM API, sending its identity
// plus any hardware and capability metadata that is configured. Optional
// fields are omitted from the payload entirely when unset.
func workerRegister(cfg *WorkerConfig) error {
	payload := map[string]interface{}{
		"worker_id": cfg.WorkerID,
		"name":      cfg.Name,
		"version":   "0.1.0",
		"os":        runtime.GOOS,
		"arch":      runtime.GOARCH,
	}
	if len(cfg.Languages) > 0 {
		payload["languages"] = cfg.Languages
	}
	if len(cfg.Models) > 0 {
		payload["supported_models"] = cfg.Models
	}
	if cfg.GPUType != "" {
		payload["gpu_type"] = cfg.GPUType
	}
	if cfg.VRAMGb > 0 {
		payload["vram_gb"] = cfg.VRAMGb
	}
	_, err := apiPost(cfg, "/api/lem/workers/register", payload)
	return err
}
// workerHeartbeat notifies the API that this worker is still alive.
// Heartbeats are best-effort: a failure is logged but never interrupts the
// worker loop. (The original silently discarded the error, hiding network
// or auth problems from the operator.)
func workerHeartbeat(cfg *WorkerConfig) {
	body := map[string]interface{}{
		"worker_id": cfg.WorkerID,
	}
	if _, err := apiPost(cfg, "/api/lem/workers/heartbeat", body); err != nil {
		log.Printf("Heartbeat failed: %v", err)
	}
}
// workerPoll fetches up to BatchSize tasks from the API and processes each
// one, returning the number completed successfully. A task that fails has
// its claim released (best-effort) so another worker can pick it up.
func workerPoll(cfg *WorkerConfig) int {
	path := fmt.Sprintf("/api/lem/tasks/next?worker_id=%s&limit=%d", cfg.WorkerID, cfg.BatchSize)
	if cfg.TaskType != "" {
		path = path + "&type=" + cfg.TaskType
	}
	raw, err := apiGet(cfg, path)
	if err != nil {
		log.Printf("Error fetching tasks: %v", err)
		return 0
	}
	var payload struct {
		Tasks []APITask `json:"tasks"`
		Count int       `json:"count"`
	}
	if err := json.Unmarshal(raw, &payload); err != nil {
		log.Printf("Error parsing tasks: %v", err)
		return 0
	}
	if payload.Count == 0 {
		return 0
	}
	log.Printf("Got %d tasks", payload.Count)
	done := 0
	for _, t := range payload.Tasks {
		if err := workerProcessTask(cfg, t); err == nil {
			done++
			continue
		} else {
			log.Printf("Task %d failed: %v", t.ID, err)
		}
		// Release the claim so the task can be rescheduled elsewhere.
		apiDelete(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", t.ID), map[string]interface{}{
			"worker_id": cfg.WorkerID,
		})
	}
	return done
}
// workerProcessTask claims a task, marks it in progress, runs inference, and
// submits the result. The sequence is order-sensitive: the claim must
// succeed before any status change or inference happens; the status PATCH
// calls are best-effort and their errors are deliberately ignored.
func workerProcessTask(cfg *WorkerConfig, task APITask) error {
	log.Printf("Processing task %d: %s [%s/%s] %d chars prompt",
		task.ID, task.TaskType, task.Language, task.Domain, len(task.PromptText))
	// Claim first; a failure here typically means another worker won the race.
	_, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
		"worker_id": cfg.WorkerID,
	})
	if err != nil {
		return fmt.Errorf("claim: %w", err)
	}
	// Best-effort status transition; errors ignored on purpose.
	apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
		"worker_id": cfg.WorkerID,
		"status":    "in_progress",
	})
	if cfg.DryRun {
		// NOTE(review): in dry-run the task remains claimed and in_progress
		// with no result ever submitted, yet returns nil (counted as
		// processed) — confirm this is the intended behavior.
		log.Printf(" [DRY-RUN] Would generate response for: %.80s...", task.PromptText)
		return nil
	}
	start := time.Now()
	response, err := workerInfer(cfg, task)
	genTime := time.Since(start)
	if err != nil {
		// Mark abandoned (best-effort) so the API can reschedule the task.
		apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
			"worker_id": cfg.WorkerID,
			"status":    "abandoned",
		})
		return fmt.Errorf("inference: %w", err)
	}
	// Fall back to "default" when the task did not specify a model.
	modelUsed := task.ModelName
	if modelUsed == "" {
		modelUsed = "default"
	}
	_, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]interface{}{
		"worker_id":     cfg.WorkerID,
		"response_text": response,
		"model_used":    modelUsed,
		"gen_time_ms":   int(genTime.Milliseconds()),
	})
	if err != nil {
		return fmt.Errorf("submit result: %w", err)
	}
	log.Printf(" Completed: %d chars in %v", len(response), genTime.Round(time.Millisecond))
	return nil
}
// workerInfer sends the task prompt to the OpenAI-compatible chat-completions
// endpoint at cfg.InferURL and returns the first choice's message content.
// Generation parameters default to temperature 0.7 / 2048 max tokens unless
// the task config overrides them. Responses shorter than 10 characters are
// rejected as likely degenerate output.
func workerInfer(cfg *WorkerConfig, task APITask) (string, error) {
	temperature := 0.7
	tokenLimit := 2048
	if c := task.Config; c != nil {
		if c.Temperature > 0 {
			temperature = c.Temperature
		}
		if c.MaxTokens > 0 {
			tokenLimit = c.MaxTokens
		}
	}
	payload, err := json.Marshal(map[string]interface{}{
		"model": task.ModelName,
		"messages": []map[string]string{
			{"role": "user", "content": task.PromptText},
		},
		"temperature": temperature,
		"max_tokens":  tokenLimit,
	})
	if err != nil {
		return "", err
	}
	req, err := http.NewRequest("POST", cfg.InferURL+"/v1/chat/completions", bytes.NewReader(payload))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")
	httpClient := &http.Client{Timeout: 5 * time.Minute}
	resp, err := httpClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("inference request: %w", err)
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != 200 {
		return "", fmt.Errorf("inference HTTP %d: %s", resp.StatusCode, truncStr(string(raw), 200))
	}
	var parsed struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}
	if len(parsed.Choices) == 0 {
		return "", fmt.Errorf("no choices in response")
	}
	content := parsed.Choices[0].Message.Content
	if len(content) < 10 {
		return "", fmt.Errorf("response too short: %d chars", len(content))
	}
	return content, nil
}
// HTTP helpers for the LEM API.

// apiGet performs an authenticated GET against the LEM API and returns the
// raw response body. Any status >= 400 becomes an error carrying a truncated
// excerpt of the body.
func apiGet(cfg *WorkerConfig, path string) ([]byte, error) {
	req, err := http.NewRequest("GET", cfg.APIBase+path, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
	httpClient := &http.Client{Timeout: 15 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 400 {
		return payload, nil
	}
	return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(payload), 200))
}
// apiPost sends an authenticated POST with a JSON body to the LEM API.
func apiPost(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) {
	return apiRequest(cfg, "POST", path, data)
}

// apiPatch sends an authenticated PATCH with a JSON body to the LEM API.
func apiPatch(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) {
	return apiRequest(cfg, "PATCH", path, data)
}

// apiDelete sends an authenticated DELETE with a JSON body to the LEM API.
func apiDelete(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) {
	return apiRequest(cfg, "DELETE", path, data)
}
// apiRequest marshals data as JSON and sends it to the LEM API with the
// given method, returning the raw response body. Any status >= 400 becomes
// an error carrying a truncated excerpt of the body.
func apiRequest(cfg *WorkerConfig, method, path string, data map[string]interface{}) ([]byte, error) {
	encoded, err := json.Marshal(data)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(method, cfg.APIBase+path, bytes.NewReader(encoded))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
	httpClient := &http.Client{Timeout: 15 * time.Second}
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, readErr
	}
	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(payload), 200))
	}
	return payload, nil
}
// MachineID returns the machine ID from /etc/machine-id, falling back to
// the hostname when the file is missing, unreadable, or empty.
func MachineID() string {
	if raw, err := os.ReadFile("/etc/machine-id"); err == nil {
		if id := string(bytes.TrimSpace(raw)); id != "" {
			return id
		}
	}
	name, _ := os.Hostname()
	return name
}
// Hostname returns the system hostname, or "" if it cannot be determined
// (the os.Hostname error is deliberately discarded).
func Hostname() string {
	h, _ := os.Hostname()
	return h
}
// ReadKeyFile reads the LEM API key from ~/.config/lem/api_key, returning
// "" when the file is missing or unreadable. Surrounding whitespace
// (including the trailing newline) is trimmed from the key.
func ReadKeyFile() string {
	home, _ := os.UserHomeDir()
	raw, err := os.ReadFile(filepath.Join(home, ".config", "lem", "api_key"))
	if err != nil {
		return ""
	}
	return string(bytes.TrimSpace(raw))
}
// SplitComma splits a comma-separated string into trimmed parts. Empty
// segments (including those produced by consecutive commas) are dropped,
// so an empty or all-whitespace input yields a nil slice.
func SplitComma(s string) []string {
	var out []string
	for _, raw := range bytes.Split([]byte(s), []byte{','}) {
		if piece := bytes.TrimSpace(raw); len(piece) != 0 {
			out = append(out, string(piece))
		}
	}
	return out
}
// truncStr limits s to at most n bytes, appending "..." when truncation
// occurs; strings of n bytes or fewer are returned unchanged.
func truncStr(s string, n int) string {
	if len(s) > n {
		return s[:n] + "..."
	}
	return s
}