feat(ml): add format converters, data pipeline, and scoring agent
Port remaining lem-repo components into pkg/ml/:

- convert.go: safetensors reader/writer, MLX→PEFT converter
- gguf.go: GGUF v3 writer, MLX→GGUF LoRA converter
- export.go: training data JSONL export with split/filter
- parquet.go: Parquet export with snappy compression
- db.go: DuckDB wrapper for golden set and expansion prompts
- influx.go: InfluxDB v3 client for metrics/status
- ollama.go: Ollama model management (create/delete with adapters)
- status.go: training and generation status display
- expand.go: expansion generation pipeline (Backend interface)
- agent.go: scoring agent with probe running and InfluxDB push
- worker.go: distributed worker for LEM API task processing

Adds parquet-go and go-duckdb dependencies.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent 3fdc3f3086
commit fcd1758b7d

13 changed files with 3332 additions and 1 deletion
go.mod · 16 changes

@@ -38,6 +38,8 @@ require (
	github.com/Snider/Enchantrix v0.0.2 // indirect
	github.com/TwiN/go-color v1.4.1 // indirect
	github.com/adrg/xdg v0.5.3 // indirect
	github.com/andybalholm/brotli v1.1.1 // indirect
	github.com/apache/arrow-go/v18 v18.1.0 // indirect
	github.com/aws/aws-sdk-go-v2 v1.41.1 // indirect
	github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect
	github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect
@@ -71,9 +73,11 @@ require (
	github.com/go-openapi/jsonpointer v0.22.4 // indirect
	github.com/go-openapi/swag/jsonname v0.25.4 // indirect
	github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
	github.com/goccy/go-json v0.10.5 // indirect
	github.com/godbus/dbus/v5 v5.2.2 // indirect
	github.com/gofrs/flock v0.12.1 // indirect
	github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
	github.com/google/flatbuffers v25.1.24+incompatible // indirect
	github.com/google/go-github/v39 v39.2.0 // indirect
	github.com/google/go-querystring v1.1.0 // indirect
	github.com/google/jsonschema-go v0.4.2 // indirect
@@ -85,11 +89,13 @@ require (
	github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect
	github.com/josharian/intern v1.0.0 // indirect
	github.com/kevinburke/ssh_config v1.4.0 // indirect
	github.com/klauspost/compress v1.18.3 // indirect
	github.com/klauspost/cpuid/v2 v2.3.0 // indirect
	github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
	github.com/leaanthony/u v1.1.1 // indirect
	github.com/lmittmann/tint v1.1.2 // indirect
	github.com/mailru/easyjson v0.9.1 // indirect
	github.com/marcboeker/go-duckdb v1.8.5 // indirect
	github.com/mattn/go-colorable v0.1.14 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect
@@ -97,8 +103,12 @@ require (
	github.com/ncruces/go-strftime v1.0.0 // indirect
	github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect
	github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect
	github.com/parquet-go/bitpack v1.0.0 // indirect
	github.com/parquet-go/jsonlite v1.0.0 // indirect
	github.com/parquet-go/parquet-go v0.27.0 // indirect
	github.com/pelletier/go-toml/v2 v2.2.4 // indirect
	github.com/perimeterx/marshmallow v1.1.5 // indirect
	github.com/pierrec/lz4/v4 v4.1.22 // indirect
	github.com/pjbgf/sha1cd v0.5.0 // indirect
	github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
@@ -119,9 +129,9 @@ require (
	github.com/tidwall/match v1.2.0 // indirect
	github.com/tidwall/pretty v1.2.1 // indirect
	github.com/tidwall/sjson v1.2.5 // indirect
	github.com/twpayne/go-geom v1.6.1 // indirect
	github.com/ugorji/go/codec v1.3.0 // indirect
	github.com/ulikunitz/xz v0.5.15 // indirect
	github.com/unpoller/unifi/v5 v5.17.0 // indirect
	github.com/wI2L/jsondiff v0.7.0 // indirect
	github.com/wailsapp/go-webview2 v1.0.23 // indirect
	github.com/wailsapp/wails/v3 v3.0.0-alpha.64 // indirect
@@ -130,10 +140,14 @@ require (
	github.com/xanzy/ssh-agent v0.3.3 // indirect
	github.com/yargevad/filepathx v1.0.0 // indirect
	github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
	github.com/zeebo/xxh3 v1.1.0 // indirect
	go.yaml.in/yaml/v3 v3.0.4 // indirect
	golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect
	golang.org/x/sync v0.19.0 // indirect
	golang.org/x/sys v0.40.0 // indirect
	golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2 // indirect
	golang.org/x/tools v0.41.0 // indirect
	golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
	google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect
	google.golang.org/grpc v1.76.0 // indirect
	google.golang.org/protobuf v1.36.10 // indirect
go.sum · 33 changes

@@ -18,12 +18,17 @@ github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBi
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
github.com/Snider/Borg v0.2.0 h1:iCyDhY4WTXi39+FexRwXbn2YpZ2U9FUXVXDZk9xRCXQ=
github.com/Snider/Borg v0.2.0/go.mod h1:TqlKnfRo9okioHbgrZPfWjQsztBV0Nfskz4Om1/vdMY=
github.com/Snider/Enchantrix v0.0.2/go.mod h1:CtFcLAvnDT1KcuF1JBb/DJj0KplY8jHryO06KzQ1hsQ=
github.com/TwiN/go-color v1.4.1 h1:mqG0P/KBgHKVqmtL5ye7K0/Gr4l6hTksPgTgMk3mUzc=
github.com/TwiN/go-color v1.4.1/go.mod h1:WcPf/jtiW95WBIsEeY1Lc/b8aaWoiqQpu5cf8WFxu+s=
github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78=
github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ=
github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA=
github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/apache/arrow-go/v18 v18.1.0 h1:agLwJUiVuwXZdwPYVrlITfx7bndULJ/dggbnLFgDp/Y=
github.com/apache/arrow-go/v18 v18.1.0/go.mod h1:tigU/sIgKNXaesf5d7Y95jBBKS5KsxTqYBKXFsvKzo0=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU=
@@ -115,6 +120,8 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
@@ -127,6 +134,8 @@ github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/flatbuffers v25.1.24+incompatible h1:4wPqL3K7GzBd1CwyhSd3usxLKOaJN/AC6puCca6Jm7o=
github.com/google/flatbuffers v25.1.24+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -153,6 +162,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kluctl/go-embed-python v0.0.0-3.13.1-20241219-1 h1:x1cSEj4Ug5mpuZgUHLvUmlc5r//KHFn6iYiRSrRcVy4=
@@ -179,6 +190,8 @@ github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w=
github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE=
github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8=
github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/marcboeker/go-duckdb v1.8.5 h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0=
github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8=
github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ=
github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU=
@@ -206,10 +219,18 @@ github.com/ollama/ollama v0.15.4 h1:y841GH5lsi5j5BTFyX/E+UOC3Yiw+JBfdjBVRGw+I0M=
github.com/ollama/ollama v0.15.4/go.mod h1:4Yn3jw2hZ4VqyJ1XciYawDRE8bzv4RT3JiVZR1kCfwE=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA=
github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs=
github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU=
github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0=
github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g=
github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw=
github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU=
github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
@@ -274,6 +295,8 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4=
github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028=
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
@@ -292,10 +315,13 @@ github.com/woodsbury/decimal128 v1.4.0 h1:xJATj7lLu4f2oObouMt2tgGiElE5gO6mSWUjQs
github.com/woodsbury/decimal128 v1.4.0/go.mod h1:BP46FUrVjVhdTbKT+XuQh2xfQaGki9LMIRJSFuh6THU=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc=
github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs=
github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
@@ -314,6 +340,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
@@ -323,6 +350,7 @@ golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHi
golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
@@ -346,11 +374,14 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2 h1:O1cMQHRfwNpDfDJerqRoE2oD+AFlyid87D40L/OkkJo=
golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
@@ -359,6 +390,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY=
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
pkg/ml/agent.go · 1070 lines · new file

(Diff suppressed because it is too large.)
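Because agent.go's diff is suppressed, the shared types it provides to the other files are not visible here. From their call sites in expand.go and export.go (`p.ID`, `p.Domain`, `p.Prompt`, `r.Response`, and `backend.Generate(ctx, prompt, GenOpts{Temperature: 0.7, MaxTokens: 2048})`), the shapes look roughly like the sketch below. This is an inference from usage only, not the actual agent.go definitions.

// Inferred shapes only -- the real definitions live in pkg/ml/agent.go,
// whose diff is suppressed above; field sets may differ.
package ml

import "context"

// Response carries one prompt/response pair through the pipeline
// (fields inferred from expand.go and export.go usage).
type Response struct {
	ID       string
	Domain   string
	Prompt   string
	Response string
}

// GenOpts holds generation parameters (fields inferred from the
// GenOpts{Temperature: 0.7, MaxTokens: 2048} literal in expand.go).
type GenOpts struct {
	Temperature float64
	MaxTokens   int
}

// Backend is the generation interface used by ExpandPrompts; the
// method signature is inferred from its single call site.
type Backend interface {
	Generate(ctx context.Context, prompt string, opts GenOpts) (string, error)
}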
pkg/ml/convert.go · 303 lines · new file

@@ -0,0 +1,303 @@
package ml

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"log"
	"math"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"
)

var (
	loraARe  = regexp.MustCompile(`\.lora_a$`)
	loraBRe  = regexp.MustCompile(`\.lora_b$`)
	layerRe  = regexp.MustCompile(`layers\.(\d+)`)
	moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`)
)

// RenameMLXKey converts an MLX tensor key to PEFT format.
func RenameMLXKey(mlxKey string) string {
	key := mlxKey
	key = loraARe.ReplaceAllString(key, ".lora_A.default.weight")
	key = loraBRe.ReplaceAllString(key, ".lora_B.default.weight")
	key = "base_model.model." + key
	return key
}

// SafetensorsHeader represents the header of a safetensors file.
type SafetensorsHeader struct {
	Metadata map[string]string                `json:"__metadata__,omitempty"`
	Tensors  map[string]SafetensorsTensorInfo `json:"-"`
}

// SafetensorsTensorInfo describes a tensor's dtype, shape, and data location.
type SafetensorsTensorInfo struct {
	Dtype       string `json:"dtype"`
	Shape       []int  `json:"shape"`
	DataOffsets [2]int `json:"data_offsets"`
}

// ReadSafetensors reads a safetensors file and returns tensor info and raw data.
func ReadSafetensors(path string) (map[string]SafetensorsTensorInfo, []byte, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, nil, fmt.Errorf("read file: %w", err)
	}

	if len(data) < 8 {
		return nil, nil, fmt.Errorf("file too small")
	}

	headerSize := int(binary.LittleEndian.Uint64(data[:8]))
	if 8+headerSize > len(data) {
		return nil, nil, fmt.Errorf("invalid header size %d", headerSize)
	}

	headerJSON := data[8 : 8+headerSize]
	tensorData := data[8+headerSize:]

	var rawHeader map[string]json.RawMessage
	if err := json.Unmarshal(headerJSON, &rawHeader); err != nil {
		return nil, nil, fmt.Errorf("parse header: %w", err)
	}

	tensors := make(map[string]SafetensorsTensorInfo)
	for key, raw := range rawHeader {
		if key == "__metadata__" {
			continue
		}
		var info SafetensorsTensorInfo
		if err := json.Unmarshal(raw, &info); err != nil {
			return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err)
		}
		tensors[key] = info
	}

	return tensors, tensorData, nil
}

// GetTensorData extracts raw bytes for a tensor from the data section.
func GetTensorData(info SafetensorsTensorInfo, allData []byte) []byte {
	return allData[info.DataOffsets[0]:info.DataOffsets[1]]
}

// TransposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows).
func TransposeFloat32(data []byte, rows, cols int) []byte {
	if len(data) != rows*cols*4 {
		return data
	}
	result := make([]byte, len(data))
	for r := range rows {
		for c := range cols {
			srcOff := (r*cols + c) * 4
			dstOff := (c*rows + r) * 4
			copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4])
		}
	}
	return result
}

// TransposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows).
func TransposeFloat16(data []byte, rows, cols int) []byte {
	if len(data) != rows*cols*2 {
		return data
	}
	result := make([]byte, len(data))
	for r := range rows {
		for c := range cols {
			srcOff := (r*cols + c) * 2
			dstOff := (c*rows + r) * 2
			copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2])
		}
	}
	return result
}

// TransposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows).
func TransposeBFloat16(data []byte, rows, cols int) []byte {
	return TransposeFloat16(data, rows, cols)
}

// WriteSafetensors writes tensors to a safetensors file.
func WriteSafetensors(path string, tensors map[string]SafetensorsTensorInfo, tensorData map[string][]byte) error {
	keys := make([]string, 0, len(tensors))
	for k := range tensors {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	offset := 0
	updatedTensors := make(map[string]SafetensorsTensorInfo)
	for _, k := range keys {
		info := tensors[k]
		data := tensorData[k]
		info.DataOffsets = [2]int{offset, offset + len(data)}
		updatedTensors[k] = info
		offset += len(data)
	}

	headerMap := make(map[string]interface{})
	for k, info := range updatedTensors {
		headerMap[k] = info
	}

	headerJSON, err := json.Marshal(headerMap)
	if err != nil {
		return fmt.Errorf("marshal header: %w", err)
	}

	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	defer f.Close()

	headerSizeBuf := make([]byte, 8)
	binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON)))

	if _, err := f.Write(headerSizeBuf); err != nil {
		return err
	}
	if _, err := f.Write(headerJSON); err != nil {
		return err
	}

	for _, k := range keys {
		if _, err := f.Write(tensorData[k]); err != nil {
			return err
		}
	}

	return nil
}

// ConvertMLXtoPEFT converts an MLX LoRA adapter to HuggingFace PEFT format.
func ConvertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error {
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return fmt.Errorf("create output dir: %w", err)
	}

	tensors, tensorData, err := ReadSafetensors(safetensorsPath)
	if err != nil {
		return fmt.Errorf("read safetensors: %w", err)
	}
	log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)

	peftTensors := make(map[string]SafetensorsTensorInfo)
	peftData := make(map[string][]byte)

	for mlxKey, info := range tensors {
		peftKey := RenameMLXKey(mlxKey)
		data := GetTensorData(info, tensorData)

		if len(info.Shape) == 2 {
			rows, cols := info.Shape[0], info.Shape[1]
			switch info.Dtype {
			case "F32":
				data = TransposeFloat32(data, rows, cols)
			case "F16":
				data = TransposeFloat16(data, rows, cols)
			case "BF16":
				data = TransposeBFloat16(data, rows, cols)
			}
			info.Shape = []int{cols, rows}
		}

		peftTensors[peftKey] = info
		peftData[peftKey] = data
	}

	outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors")
	if err := WriteSafetensors(outSafetensors, peftTensors, peftData); err != nil {
		return fmt.Errorf("write safetensors: %w", err)
	}

	cfgData, err := os.ReadFile(configPath)
	if err != nil {
		return fmt.Errorf("read config: %w", err)
	}

	var mlxConfig struct {
		LoraParameters struct {
			Rank    int     `json:"rank"`
			Scale   float64 `json:"scale"`
			Dropout float64 `json:"dropout"`
		} `json:"lora_parameters"`
	}
	if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
		return fmt.Errorf("parse config: %w", err)
	}

	rank := mlxConfig.LoraParameters.Rank
	if rank == 0 {
		rank = 8
	}
	scale := mlxConfig.LoraParameters.Scale
	if scale == 0 {
		scale = 20.0
	}

	modules := make(map[string]bool)
	layers := make(map[int]bool)
	for k := range tensors {
		if m := moduleRe.FindStringSubmatch(k); m != nil {
			parts := strings.Split(m[1], ".")
			modules[parts[len(parts)-1]] = true
		}
		if m := layerRe.FindStringSubmatch(k); m != nil {
			n, _ := strconv.Atoi(m[1])
			layers[n] = true
		}
	}

	sortedModules := make([]string, 0, len(modules))
	for m := range modules {
		sortedModules = append(sortedModules, m)
	}
	sort.Strings(sortedModules)

	sortedLayers := make([]int, 0, len(layers))
	for l := range layers {
		sortedLayers = append(sortedLayers, l)
	}
	sort.Ints(sortedLayers)

	peftConfig := map[string]interface{}{
		"auto_mapping":            nil,
		"base_model_name_or_path": baseModelName,
		"bias":                    "none",
		"fan_in_fan_out":          false,
		"inference_mode":          true,
		"init_lora_weights":       true,
		"layers_pattern":          nil,
		"layers_to_transform":     sortedLayers,
		"lora_alpha":              math.Round(scale * float64(rank)),
		"lora_dropout":            mlxConfig.LoraParameters.Dropout,
		"modules_to_save":         nil,
		"peft_type":               "LORA",
		"r":                       rank,
		"revision":                nil,
		"target_modules":          sortedModules,
		"task_type":               "CAUSAL_LM",
	}

	cfgJSON, err := json.MarshalIndent(peftConfig, "", " ")
	if err != nil {
		return fmt.Errorf("marshal peft config: %w", err)
	}

	if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil {
		return fmt.Errorf("write adapter_config.json: %w", err)
	}

	log.Printf("converted %d tensors, %d layers, target modules: %v",
		len(peftTensors), len(sortedLayers), sortedModules)

	return nil
}
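For orientation, a minimal sketch of driving the converter above; the module import path, file paths, and base model name are hypothetical, but the ConvertMLXtoPEFT signature matches the code in this diff.

package main

import (
	"log"

	"example.invalid/lethean/pkg/ml" // assumption: substitute the repo's real module path
)

func main() {
	// Convert an MLX LoRA checkpoint into a PEFT adapter directory
	// (adapter_model.safetensors + adapter_config.json).
	err := ml.ConvertMLXtoPEFT(
		"adapters/adapters.safetensors", // hypothetical MLX LoRA weights
		"adapters/adapter_config.json",  // hypothetical MLX training config
		"out/peft",                      // output directory
		"google/gemma-3-4b-it",          // base model recorded in the PEFT config
	)
	if err != nil {
		log.Fatal(err)
	}
}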
pkg/ml/db.go · 241 lines · new file

@@ -0,0 +1,241 @@
package ml

import (
	"database/sql"
	"fmt"

	_ "github.com/marcboeker/go-duckdb"
)

// DB wraps a DuckDB connection.
type DB struct {
	conn *sql.DB
	path string
}

// OpenDB opens a DuckDB database file in read-only mode to avoid locking
// issues with the Python pipeline.
func OpenDB(path string) (*DB, error) {
	conn, err := sql.Open("duckdb", path+"?access_mode=READ_ONLY")
	if err != nil {
		return nil, fmt.Errorf("open duckdb %s: %w", path, err)
	}
	if err := conn.Ping(); err != nil {
		conn.Close()
		return nil, fmt.Errorf("ping duckdb %s: %w", path, err)
	}
	return &DB{conn: conn, path: path}, nil
}

// OpenDBReadWrite opens a DuckDB database in read-write mode.
func OpenDBReadWrite(path string) (*DB, error) {
	conn, err := sql.Open("duckdb", path)
	if err != nil {
		return nil, fmt.Errorf("open duckdb %s: %w", path, err)
	}
	if err := conn.Ping(); err != nil {
		conn.Close()
		return nil, fmt.Errorf("ping duckdb %s: %w", path, err)
	}
	return &DB{conn: conn, path: path}, nil
}

// Close closes the database connection.
func (db *DB) Close() error {
	return db.conn.Close()
}

// GoldenSetRow represents one row from the golden_set table.
type GoldenSetRow struct {
	Idx       int
	SeedID    string
	Domain    string
	Voice     string
	Prompt    string
	Response  string
	GenTime   float64
	CharCount int
}

// ExpansionPromptRow represents one row from the expansion_prompts table.
type ExpansionPromptRow struct {
	Idx      int64
	SeedID   string
	Region   string
	Domain   string
	Language string
	Prompt   string
	PromptEn string
	Priority int
	Status   string
}

// QueryGoldenSet returns all golden set rows with responses >= minChars.
func (db *DB) QueryGoldenSet(minChars int) ([]GoldenSetRow, error) {
	rows, err := db.conn.Query(
		"SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count "+
			"FROM golden_set WHERE char_count >= ? ORDER BY idx",
		minChars,
	)
	if err != nil {
		return nil, fmt.Errorf("query golden_set: %w", err)
	}
	defer rows.Close()

	var result []GoldenSetRow
	for rows.Next() {
		var r GoldenSetRow
		if err := rows.Scan(&r.Idx, &r.SeedID, &r.Domain, &r.Voice,
			&r.Prompt, &r.Response, &r.GenTime, &r.CharCount); err != nil {
			return nil, fmt.Errorf("scan golden_set row: %w", err)
		}
		result = append(result, r)
	}
	return result, rows.Err()
}

// CountGoldenSet returns the total count of golden set rows.
func (db *DB) CountGoldenSet() (int, error) {
	var count int
	err := db.conn.QueryRow("SELECT COUNT(*) FROM golden_set").Scan(&count)
	if err != nil {
		return 0, fmt.Errorf("count golden_set: %w", err)
	}
	return count, nil
}

// QueryExpansionPrompts returns expansion prompts filtered by status.
func (db *DB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) {
	query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " +
		"FROM expansion_prompts"
	var args []interface{}

	if status != "" {
		query += " WHERE status = ?"
		args = append(args, status)
	}
	query += " ORDER BY priority, idx"

	if limit > 0 {
		query += fmt.Sprintf(" LIMIT %d", limit)
	}

	rows, err := db.conn.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("query expansion_prompts: %w", err)
	}
	defer rows.Close()

	var result []ExpansionPromptRow
	for rows.Next() {
		var r ExpansionPromptRow
		if err := rows.Scan(&r.Idx, &r.SeedID, &r.Region, &r.Domain,
			&r.Language, &r.Prompt, &r.PromptEn, &r.Priority, &r.Status); err != nil {
			return nil, fmt.Errorf("scan expansion_prompt row: %w", err)
		}
		result = append(result, r)
	}
	return result, rows.Err()
}

// CountExpansionPrompts returns counts by status.
func (db *DB) CountExpansionPrompts() (total int, pending int, err error) {
	err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts").Scan(&total)
	if err != nil {
		return 0, 0, fmt.Errorf("count expansion_prompts: %w", err)
	}
	err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts WHERE status = 'pending'").Scan(&pending)
	if err != nil {
		return total, 0, fmt.Errorf("count pending expansion_prompts: %w", err)
	}
	return total, pending, nil
}

// UpdateExpansionStatus updates the status of an expansion prompt by idx.
func (db *DB) UpdateExpansionStatus(idx int64, status string) error {
	_, err := db.conn.Exec("UPDATE expansion_prompts SET status = ? WHERE idx = ?", status, idx)
	if err != nil {
		return fmt.Errorf("update expansion_prompt %d: %w", idx, err)
	}
	return nil
}

// QueryRows executes an arbitrary SQL query and returns results as maps.
func (db *DB) QueryRows(query string, args ...interface{}) ([]map[string]interface{}, error) {
	rows, err := db.conn.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("query: %w", err)
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return nil, fmt.Errorf("columns: %w", err)
	}

	var result []map[string]interface{}
	for rows.Next() {
		values := make([]interface{}, len(cols))
		ptrs := make([]interface{}, len(cols))
		for i := range values {
			ptrs[i] = &values[i]
		}
		if err := rows.Scan(ptrs...); err != nil {
			return nil, fmt.Errorf("scan: %w", err)
		}
		row := make(map[string]interface{}, len(cols))
		for i, col := range cols {
			row[col] = values[i]
		}
		result = append(result, row)
	}
	return result, rows.Err()
}

// EnsureScoringTables creates the scoring tables if they don't exist.
func (db *DB) EnsureScoringTables() {
	db.conn.Exec(`CREATE TABLE IF NOT EXISTS checkpoint_scores (
		model TEXT, run_id TEXT, label TEXT, iteration INTEGER,
		correct INTEGER, total INTEGER, accuracy DOUBLE,
		scored_at TIMESTAMP DEFAULT current_timestamp,
		PRIMARY KEY (run_id, label)
	)`)
	db.conn.Exec(`CREATE TABLE IF NOT EXISTS probe_results (
		model TEXT, run_id TEXT, label TEXT, probe_id TEXT,
		passed BOOLEAN, response TEXT, iteration INTEGER,
		scored_at TIMESTAMP DEFAULT current_timestamp,
		PRIMARY KEY (run_id, label, probe_id)
	)`)
	db.conn.Exec(`CREATE TABLE IF NOT EXISTS scoring_results (
		model TEXT, prompt_id TEXT, suite TEXT,
		dimension TEXT, score DOUBLE,
		scored_at TIMESTAMP DEFAULT current_timestamp
	)`)
}

// WriteScoringResult writes a single scoring dimension result to DuckDB.
func (db *DB) WriteScoringResult(model, promptID, suite, dimension string, score float64) error {
	_, err := db.conn.Exec(
		`INSERT INTO scoring_results (model, prompt_id, suite, dimension, score) VALUES (?, ?, ?, ?, ?)`,
		model, promptID, suite, dimension, score,
	)
	return err
}

// TableCounts returns row counts for all known tables.
func (db *DB) TableCounts() (map[string]int, error) {
	tables := []string{"golden_set", "expansion_prompts", "seeds", "prompts",
		"training_examples", "gemini_responses", "benchmark_questions", "benchmark_results", "validations",
		"checkpoint_scores", "probe_results", "scoring_results"}

	counts := make(map[string]int)
	for _, t := range tables {
		var count int
		err := db.conn.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", t)).Scan(&count)
		if err != nil {
			continue
		}
		counts[t] = count
	}
	return counts, nil
}
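A minimal sketch of reading the golden set through this wrapper; the database path and the module import path are hypothetical, while OpenDB/QueryGoldenSet/Close are the functions defined above. Opening read-only matters here because the Python pipeline may hold its own handle on the same file.

package main

import (
	"fmt"
	"log"

	"example.invalid/lethean/pkg/ml" // assumption: substitute the repo's real module path
)

func main() {
	db, err := ml.OpenDB("data/lem.duckdb") // hypothetical path; read-only open
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Keep only rows whose responses are at least 500 characters.
	rows, err := db.QueryGoldenSet(500)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("golden set: %d usable rows\n", len(rows))
}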
pkg/ml/expand.go · 153 lines · new file

@@ -0,0 +1,153 @@
package ml

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"time"
)

// ExpandOutput is the JSONL output structure for expansion generation.
type ExpandOutput struct {
	ID             string  `json:"id"`
	Domain         string  `json:"domain,omitempty"`
	Prompt         string  `json:"prompt"`
	Response       string  `json:"response"`
	Model          string  `json:"model"`
	ElapsedSeconds float64 `json:"elapsed_seconds"`
	Chars          int     `json:"chars"`
}

// GetCompletedIDs queries InfluxDB for prompt IDs that have already been
// processed in the expansion_gen measurement.
func GetCompletedIDs(influx *InfluxClient) (map[string]bool, error) {
	rows, err := influx.QuerySQL("SELECT DISTINCT seed_id FROM expansion_gen")
	if err != nil {
		return nil, fmt.Errorf("query expansion_gen: %w", err)
	}

	ids := make(map[string]bool, len(rows))
	for _, row := range rows {
		id := strVal(row, "seed_id")
		if id != "" {
			ids[id] = true
		}
	}

	return ids, nil
}

// ExpandPrompts generates responses for expansion prompts using the given
// backend and reports progress to InfluxDB. Already-completed prompts (per
// InfluxDB) are skipped. API errors for individual prompts are logged and
// skipped. InfluxDB reporting is best-effort.
func ExpandPrompts(ctx context.Context, backend Backend, influx *InfluxClient, prompts []Response,
	modelName, worker, outputDir string, dryRun bool, limit int) error {

	remaining := prompts

	// Check InfluxDB for already-completed IDs.
	completed, err := GetCompletedIDs(influx)
	if err != nil {
		log.Printf("warning: could not check completed IDs: %v", err)
	} else {
		remaining = nil
		for _, p := range prompts {
			if !completed[p.ID] {
				remaining = append(remaining, p)
			}
		}

		skipped := len(prompts) - len(remaining)
		if skipped > 0 {
			log.Printf("skipping %d already-completed prompts, %d remaining", skipped, len(remaining))
		}
	}

	if limit > 0 && limit < len(remaining) {
		remaining = remaining[:limit]
	}

	if len(remaining) == 0 {
		log.Println("all prompts already completed, nothing to do")
		return nil
	}

	if dryRun {
		log.Printf("dry-run: would process %d prompts with model %s (worker: %s)", len(remaining), modelName, worker)
		for i, p := range remaining {
			if i >= 10 {
				log.Printf(" ... and %d more", len(remaining)-10)
				break
			}
			log.Printf(" %s (domain: %s)", p.ID, p.Domain)
		}
		return nil
	}

	outputPath := filepath.Join(outputDir, fmt.Sprintf("expand-%s.jsonl", worker))
	f, err := os.OpenFile(outputPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		return fmt.Errorf("open output file: %w", err)
	}
	defer f.Close()

	total := len(remaining)
	completedCount := 0

	for idx, p := range remaining {
		start := time.Now()
		response, err := backend.Generate(ctx, p.Prompt, GenOpts{Temperature: 0.7, MaxTokens: 2048})
		elapsed := time.Since(start).Seconds()

		if err != nil {
			log.Printf("[%d/%d] id=%s ERROR: %v", idx+1, total, p.ID, err)
			continue
		}

		chars := len(response)
		completedCount++

		out := ExpandOutput{
			ID:             p.ID,
			Domain:         p.Domain,
			Prompt:         p.Prompt,
			Response:       response,
			Model:          modelName,
			ElapsedSeconds: elapsed,
			Chars:          chars,
		}

		line, err := json.Marshal(out)
		if err != nil {
			log.Printf("[%d/%d] id=%s marshal error: %v", idx+1, total, p.ID, err)
			continue
		}

		if _, err := f.Write(append(line, '\n')); err != nil {
			log.Printf("[%d/%d] id=%s write error: %v", idx+1, total, p.ID, err)
			continue
		}

		genLine := fmt.Sprintf("expansion_gen,i=%d,w=%s,d=%s seed_id=\"%s\",gen_time=%f,chars=%di,model=\"%s\"",
			idx, EscapeLp(worker), EscapeLp(p.Domain),
			p.ID, elapsed, chars, modelName)

		pct := float64(completedCount) / float64(total) * 100.0
		progressLine := fmt.Sprintf("expansion_progress,worker=%s completed=%di,target=%di,pct=%f",
			EscapeLp(worker), completedCount, total, pct)

		if writeErr := influx.WriteLp([]string{genLine, progressLine}); writeErr != nil {
			log.Printf("[%d/%d] id=%s influx write error: %v", idx+1, total, p.ID, writeErr)
		}

		log.Printf("[%d/%d] id=%s chars=%d time=%.1fs", idx+1, total, p.ID, chars, elapsed)
	}

	log.Printf("expand complete: %d/%d prompts generated, output: %s", completedCount, total, outputPath)

	return nil
}
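ExpandPrompts only needs a Backend that turns a prompt into text, so any generation client can be plugged in. A toy sketch of a stub backend for exercising the loop locally, assuming the Backend/GenOpts shapes inferred earlier from this file's call site (the real definitions live in the suppressed agent.go):

package ml

import "context"

// echoBackend is a toy Backend for exercising ExpandPrompts in tests;
// a real implementation would wrap Ollama or the LEM API client.
type echoBackend struct{}

func (echoBackend) Generate(ctx context.Context, prompt string, _ GenOpts) (string, error) {
	// Honour cancellation the way a real network client would.
	if err := ctx.Err(); err != nil {
		return "", err
	}
	return "echo: " + prompt, nil
}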
pkg/ml/export.go · 112 lines · new file

@@ -0,0 +1,112 @@
package ml

import (
	"bufio"
	"encoding/json"
	"fmt"
	"math/rand"
	"os"
	"strings"
)

// ChatMessage is a single message in the chat training format.
type ChatMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// TrainingExample is a single training example in chat JSONL format.
type TrainingExample struct {
	Messages []ChatMessage `json:"messages"`
}

// ValidatePercentages checks that train+valid+test percentages sum to 100
// and that none are negative.
func ValidatePercentages(trainPct, validPct, testPct int) error {
	if trainPct < 0 || validPct < 0 || testPct < 0 {
		return fmt.Errorf("percentages must be non-negative: train=%d, valid=%d, test=%d", trainPct, validPct, testPct)
	}
	sum := trainPct + validPct + testPct
	if sum != 100 {
		return fmt.Errorf("percentages must sum to 100, got %d (train=%d + valid=%d + test=%d)", sum, trainPct, validPct, testPct)
	}
	return nil
}

// FilterResponses removes responses with empty content, "ERROR:" prefix,
// or response length < 50 characters.
func FilterResponses(responses []Response) []Response {
	var filtered []Response
	for _, r := range responses {
		if r.Response == "" {
			continue
		}
		if strings.HasPrefix(r.Response, "ERROR:") {
			continue
		}
		if len(r.Response) < 50 {
			continue
		}
		filtered = append(filtered, r)
	}
	return filtered
}

// SplitData shuffles responses with a deterministic seed and splits them
// into train, valid, and test sets by the given percentages.
func SplitData(responses []Response, trainPct, validPct, testPct int, seed int64) (train, valid, test []Response) {
	shuffled := make([]Response, len(responses))
	copy(shuffled, responses)

	rng := rand.New(rand.NewSource(seed))
	rng.Shuffle(len(shuffled), func(i, j int) {
		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
	})

	n := len(shuffled)
	trainN := n * trainPct / 100
	validN := n * validPct / 100
	_ = testPct

	train = shuffled[:trainN]
	valid = shuffled[trainN : trainN+validN]
	test = shuffled[trainN+validN:]

	return train, valid, test
}

// WriteTrainingJSONL writes responses in chat JSONL format suitable for
// MLX LoRA fine-tuning.
func WriteTrainingJSONL(path string, responses []Response) error {
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("create %s: %w", path, err)
	}
	defer f.Close()

	w := bufio.NewWriter(f)
	defer w.Flush()

	for _, r := range responses {
		example := TrainingExample{
			Messages: []ChatMessage{
				{Role: "user", Content: r.Prompt},
				{Role: "assistant", Content: r.Response},
			},
		}

		data, err := json.Marshal(example)
		if err != nil {
			return fmt.Errorf("marshal example: %w", err)
		}

		if _, err := w.Write(data); err != nil {
			return fmt.Errorf("write line: %w", err)
		}
		if _, err := w.WriteString("\n"); err != nil {
			return fmt.Errorf("write newline: %w", err)
		}
	}

	return nil
}
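A minimal sketch of the intended export flow (filter, then a deterministic 80/10/10 split, then three JSONL files). The helper name, output file names, and import path are illustrative; the ml.* calls match the functions above, and the Response shape is the one inferred earlier.

package main

import (
	"log"
	"path/filepath"

	"example.invalid/lethean/pkg/ml" // assumption: substitute the repo's real module path
)

func exportSplits(all []ml.Response, dir string) error {
	// Drop empty, "ERROR:"-prefixed, and <50-char responses.
	usable := ml.FilterResponses(all)

	if err := ml.ValidatePercentages(80, 10, 10); err != nil {
		return err
	}
	// Fixed seed so reruns produce the same split.
	train, valid, test := ml.SplitData(usable, 80, 10, 10, 42)

	for name, part := range map[string][]ml.Response{
		"train.jsonl": train,
		"valid.jsonl": valid,
		"test.jsonl":  test,
	} {
		if err := ml.WriteTrainingJSONL(filepath.Join(dir, name), part); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	if err := exportSplits(nil, "data"); err != nil {
		log.Fatal(err)
	}
}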
369
pkg/ml/gguf.go
Normal file
369
pkg/ml/gguf.go
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
package ml
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GGUF format constants.
|
||||
const (
|
||||
ggufMagic = 0x46554747 // "GGUF" little-endian
|
||||
ggufVersion = 3
|
||||
ggufAlignment = 32
|
||||
)
|
||||
|
||||
// GGUF metadata value types.
|
||||
const (
|
||||
ggufTypeUint32 = 4
|
||||
ggufTypeFloat32 = 6
|
||||
ggufTypeString = 8
|
||||
)
|
||||
|
||||
// GGML tensor data types.
|
||||
const (
|
||||
ggmlTypeF32 = 0
|
||||
ggmlTypeF16 = 1
|
||||
ggmlTypeBF16 = 30
|
||||
)
|
||||
|
||||
// ggufMetadata is a key-value pair in the GGUF header.
|
||||
type ggufMetadata struct {
|
||||
key string
|
||||
valueType uint32
|
||||
value interface{} // string, uint32, or float32
|
||||
}
|
||||
|
||||
// ggufTensor describes a tensor in the GGUF file.
|
||||
type ggufTensor struct {
|
||||
name string
|
||||
dims []uint64
|
||||
dtype uint32 // ggmlType*
|
||||
data []byte
|
||||
}
|
||||
|
||||
// gemma3ModuleMap maps HuggingFace module names to GGUF tensor names.
|
||||
var gemma3ModuleMap = map[string]string{
|
||||
"self_attn.q_proj": "attn_q",
|
||||
"self_attn.k_proj": "attn_k",
|
||||
"self_attn.v_proj": "attn_v",
|
||||
"self_attn.o_proj": "attn_output",
|
||||
"mlp.gate_proj": "ffn_gate",
|
||||
"mlp.up_proj": "ffn_up",
|
||||
"mlp.down_proj": "ffn_down",
|
||||
}
|
||||
|
||||
var mlxLoraKeyRe = regexp.MustCompile(`^model\.layers\.(\d+)\.(.*?)\.(lora_[ab])$`)
|
||||
|
||||
// MLXTensorToGGUF converts an MLX LoRA tensor name to GGUF LoRA tensor name.
|
||||
// Input: "model.layers.0.self_attn.q_proj.lora_a"
|
||||
// Output: "blk.0.attn_q.weight.lora_a"
|
||||
func MLXTensorToGGUF(mlxName string) (string, error) {
|
||||
m := mlxLoraKeyRe.FindStringSubmatch(mlxName)
|
||||
if m == nil {
|
||||
return "", fmt.Errorf("unrecognised MLX LoRA key: %s", mlxName)
|
||||
}
|
||||
|
||||
layerNum := m[1]
|
||||
module := m[2]
|
||||
loraSuffix := m[3]
|
||||
|
||||
ggufModule, ok := gemma3ModuleMap[module]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown module %q in %s", module, mlxName)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("blk.%s.%s.weight.%s", layerNum, ggufModule, loraSuffix), nil
|
||||
}
|
||||
|
||||
// SafetensorsDtypeToGGML maps safetensors dtype strings to GGML types.
|
||||
func SafetensorsDtypeToGGML(dtype string) (uint32, error) {
|
||||
switch dtype {
|
||||
case "F32":
|
||||
return ggmlTypeF32, nil
|
||||
case "F16":
|
||||
return ggmlTypeF16, nil
|
||||
case "BF16":
|
||||
return ggmlTypeBF16, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported dtype %q for GGUF", dtype)
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertMLXtoGGUFLoRA converts an MLX LoRA adapter to GGUF LoRA format.
|
||||
func ConvertMLXtoGGUFLoRA(safetensorsPath, configPath, outputPath, architecture string) error {
|
||||
cfgData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read config: %w", err)
|
||||
}
|
||||
|
||||
var mlxConfig struct {
|
||||
LoraParameters struct {
|
||||
Rank int `json:"rank"`
|
||||
Scale float64 `json:"scale"`
|
||||
} `json:"lora_parameters"`
|
||||
}
|
||||
if err := json.Unmarshal(cfgData, &mlxConfig); err != nil {
|
||||
return fmt.Errorf("parse config: %w", err)
|
||||
}
|
||||
|
||||
rank := mlxConfig.LoraParameters.Rank
|
||||
if rank == 0 {
|
||||
rank = 8
|
||||
}
|
||||
scale := mlxConfig.LoraParameters.Scale
|
||||
if scale == 0 {
|
||||
scale = 20.0
|
||||
}
|
||||
loraAlpha := float32(math.Round(scale * float64(rank)))
|
||||
|
||||
tensors, tensorData, err := ReadSafetensors(safetensorsPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read safetensors: %w", err)
|
||||
}
|
||||
log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath)
|
||||
|
||||
var ggufTensors []ggufTensor
|
||||
for mlxKey, info := range tensors {
|
||||
ggufName, err := MLXTensorToGGUF(mlxKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ggmlType, err := SafetensorsDtypeToGGML(info.Dtype)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tensor %s: %w", mlxKey, err)
|
||||
}
|
||||
|
||||
data := GetTensorData(info, tensorData)
|
||||
|
||||
if len(info.Shape) == 2 {
|
||||
rows, cols := info.Shape[0], info.Shape[1]
|
||||
switch info.Dtype {
|
||||
case "F32":
|
||||
data = TransposeFloat32(data, rows, cols)
|
||||
case "F16":
|
||||
data = TransposeFloat16(data, rows, cols)
|
||||
case "BF16":
|
||||
data = TransposeBFloat16(data, rows, cols)
|
||||
}
|
||||
ggufTensors = append(ggufTensors, ggufTensor{
|
||||
name: ggufName,
|
||||
dims: []uint64{uint64(rows), uint64(cols)},
|
||||
dtype: ggmlType,
|
||||
data: data,
|
||||
})
|
||||
} else {
|
||||
dims := make([]uint64, len(info.Shape))
|
||||
for i, s := range info.Shape {
|
||||
dims[i] = uint64(s)
|
||||
}
|
||||
ggufTensors = append(ggufTensors, ggufTensor{
|
||||
name: ggufName,
|
||||
dims: dims,
|
||||
dtype: ggmlType,
|
||||
data: data,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
sort.Slice(ggufTensors, func(i, j int) bool {
|
||||
return ggufTensors[i].name < ggufTensors[j].name
|
||||
})
|
||||
|
||||
metadata := []ggufMetadata{
|
||||
{key: "general.type", valueType: ggufTypeString, value: "adapter"},
|
||||
{key: "general.architecture", valueType: ggufTypeString, value: architecture},
|
||||
{key: "adapter.type", valueType: ggufTypeString, value: "lora"},
|
||||
{key: "adapter.lora.alpha", valueType: ggufTypeFloat32, value: loraAlpha},
|
||||
}
|
||||
|
||||
if err := writeGGUF(outputPath, metadata, ggufTensors); err != nil {
|
||||
return fmt.Errorf("write GGUF: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("wrote GGUF LoRA: %s (%d tensors, alpha=%.0f)", outputPath, len(ggufTensors), loraAlpha)
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeGGUF writes a GGUF v3 file.
|
||||
func writeGGUF(path string, metadata []ggufMetadata, tensors []ggufTensor) error {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
w := &ggufWriter{f: f}
|
||||
|
||||
w.writeUint32(ggufMagic)
|
||||
w.writeUint32(ggufVersion)
|
||||
w.writeUint64(uint64(len(tensors)))
|
||||
w.writeUint64(uint64(len(metadata)))
|
||||
|
||||
for _, kv := range metadata {
|
||||
w.writeString(kv.key)
|
||||
w.writeUint32(kv.valueType)
|
||||
switch kv.valueType {
|
||||
case ggufTypeString:
|
||||
w.writeString(kv.value.(string))
|
||||
case ggufTypeUint32:
|
||||
w.writeUint32(kv.value.(uint32))
|
||||
case ggufTypeFloat32:
|
||||
w.writeFloat32(kv.value.(float32))
|
||||
}
|
||||
}
|
||||
|
||||
dataOffset := uint64(0)
|
||||
for _, t := range tensors {
|
||||
w.writeString(t.name)
|
||||
w.writeUint32(uint32(len(t.dims)))
|
||||
for _, d := range t.dims {
|
||||
w.writeUint64(d)
|
||||
}
|
||||
w.writeUint32(t.dtype)
|
||||
w.writeUint64(dataOffset)
|
||||
|
||||
dataOffset += uint64(len(t.data))
|
||||
if rem := dataOffset % ggufAlignment; rem != 0 {
|
||||
dataOffset += ggufAlignment - rem
|
||||
}
|
||||
}
|
||||
|
||||
pos := w.pos
|
||||
if rem := pos % ggufAlignment; rem != 0 {
|
||||
pad := ggufAlignment - rem
|
||||
w.writeBytes(make([]byte, pad))
|
||||
}
|
||||
|
||||
for _, t := range tensors {
|
||||
		w.writeBytes(t.data)
		if rem := uint64(len(t.data)) % ggufAlignment; rem != 0 {
			w.writeBytes(make([]byte, ggufAlignment-rem))
		}
	}

	return w.err
}

// ggufWriter tracks position and accumulates errors.
type ggufWriter struct {
	f   *os.File
	pos uint64
	err error
}

func (w *ggufWriter) writeBytes(b []byte) {
	if w.err != nil {
		return
	}
	n, err := w.f.Write(b)
	w.pos += uint64(n)
	if err != nil {
		w.err = err
	}
}

func (w *ggufWriter) writeUint32(v uint32) {
	b := make([]byte, 4)
	binary.LittleEndian.PutUint32(b, v)
	w.writeBytes(b)
}

func (w *ggufWriter) writeUint64(v uint64) {
	b := make([]byte, 8)
	binary.LittleEndian.PutUint64(b, v)
	w.writeBytes(b)
}

func (w *ggufWriter) writeFloat32(v float32) {
	w.writeUint32(math.Float32bits(v))
}

func (w *ggufWriter) writeString(s string) {
	w.writeUint64(uint64(len(s)))
	w.writeBytes([]byte(s))
}

// DetectArchFromConfig tries to infer the model architecture from adapter_config.json.
// Only Gemma-3 is supported today, so this always returns "gemma3"; the config is
// parsed on a best-effort basis so a future version can branch on its contents.
func DetectArchFromConfig(configPath string) string {
	data, err := os.ReadFile(configPath)
	if err != nil {
		return "gemma3"
	}
	var cfg struct {
		LoraParameters struct {
			Rank int `json:"rank"`
		} `json:"lora_parameters"`
	}
	_ = json.Unmarshal(data, &cfg) // parse errors are non-fatal: the default arch is returned either way
	return "gemma3"
}

// ArchitectureGGUFMap maps model tags to GGUF architecture names.
var ArchitectureGGUFMap = map[string]string{
	"gemma-3-1b":  "gemma3",
	"gemma-3-4b":  "gemma3",
	"gemma-3-12b": "gemma3",
	"gemma-3-27b": "gemma3",
}

// ModelTagToGGUFArch returns the GGUF architecture for a model tag.
func ModelTagToGGUFArch(modelTag string) string {
	if arch, ok := ArchitectureGGUFMap[modelTag]; ok {
		return arch
	}
	return "gemma3"
}

// GGUFModelBlobPath returns the path to the GGUF model blob in Ollama's store.
func GGUFModelBlobPath(ollamaModelsDir, modelName string) (string, error) {
	parts := strings.SplitN(modelName, ":", 2)
	family := parts[0]
	tag := "latest"
	if len(parts) > 1 {
		tag = parts[1]
	}

	manifestPath := fmt.Sprintf("%s/manifests/registry.ollama.ai/library/%s/%s", ollamaModelsDir, family, tag)
	data, err := os.ReadFile(manifestPath)
	if err != nil {
		return "", fmt.Errorf("read manifest %s: %w", manifestPath, err)
	}

	var manifest struct {
		Layers []struct {
			MediaType string `json:"mediaType"`
			Digest    string `json:"digest"`
		} `json:"layers"`
	}
	if err := json.Unmarshal(data, &manifest); err != nil {
		return "", fmt.Errorf("parse manifest: %w", err)
	}

	for _, layer := range manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.model" {
			blobName := strings.Replace(layer.Digest, ":", "-", 1)
			return fmt.Sprintf("%s/blobs/%s", ollamaModelsDir, blobName), nil
		}
	}

	return "", fmt.Errorf("no model layer found in manifest for %s", modelName)
}

// ParseLayerFromTensorName extracts the layer number from a GGUF tensor name.
func ParseLayerFromTensorName(name string) (int, error) {
	re := regexp.MustCompile(`blk\.(\d+)\.`)
	m := re.FindStringSubmatch(name)
	if m == nil {
		return 0, fmt.Errorf("no layer number in %s", name)
	}
	return strconv.Atoi(m[1])
}
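For orientation, a minimal sketch composing the helpers above from within pkg/ml; the model tag and tensor name are illustrative, not taken from a real manifest:

	arch := ModelTagToGGUFArch("gemma-3-4b")                       // "gemma3"
	layer, err := ParseLayerFromTensorName("blk.17.attn_q.weight") // 17, nil
	if err != nil {
		log.Fatalf("unexpected tensor name: %v", err)
	}
	fmt.Printf("arch=%s layer=%d\n", arch, layer)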
132 pkg/ml/influx.go Normal file
@@ -0,0 +1,132 @@
package ml

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"
)

// InfluxClient talks to an InfluxDB v3 instance.
type InfluxClient struct {
	url   string
	db    string
	token string
}

// NewInfluxClient creates an InfluxClient for the given URL and database.
// Reads token from INFLUX_TOKEN env var first, then ~/.influx_token file.
// If url is empty, defaults to "http://10.69.69.165:8181".
// If db is empty, defaults to "training".
func NewInfluxClient(url, db string) *InfluxClient {
	if url == "" {
		url = "http://10.69.69.165:8181"
	}
	if db == "" {
		db = "training"
	}

	token := os.Getenv("INFLUX_TOKEN")
	if token == "" {
		home, err := os.UserHomeDir()
		if err == nil {
			data, err := os.ReadFile(filepath.Join(home, ".influx_token"))
			if err == nil {
				token = strings.TrimSpace(string(data))
			}
		}
	}

	return &InfluxClient{
		url:   url,
		db:    db,
		token: token,
	}
}

// WriteLp writes line protocol data to InfluxDB.
func (c *InfluxClient) WriteLp(lines []string) error {
	body := strings.Join(lines, "\n")

	url := fmt.Sprintf("%s/api/v3/write_lp?db=%s", c.url, c.db)

	req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body))
	if err != nil {
		return fmt.Errorf("create write request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.token)
	req.Header.Set("Content-Type", "text/plain")

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("write request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("write failed %d: %s", resp.StatusCode, string(respBody))
	}

	return nil
}

// QuerySQL runs a SQL query against InfluxDB and returns the result rows.
func (c *InfluxClient) QuerySQL(sql string) ([]map[string]interface{}, error) {
	reqBody := map[string]string{
		"db": c.db,
		"q":  sql,
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("marshal query request: %w", err)
	}

	url := c.url + "/api/v3/query_sql"

	req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("create query request: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+c.token)
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("query request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read query response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("query failed %d: %s", resp.StatusCode, string(respBody))
	}

	var rows []map[string]interface{}
	if err := json.Unmarshal(respBody, &rows); err != nil {
		return nil, fmt.Errorf("unmarshal query response: %w", err)
	}

	return rows, nil
}

// EscapeLp escapes spaces, commas, and equals signs for InfluxDB line protocol
// tag values.
func EscapeLp(s string) string {
	s = strings.ReplaceAll(s, `,`, `\,`)
	s = strings.ReplaceAll(s, `=`, `\=`)
	s = strings.ReplaceAll(s, ` `, `\ `)
	return s
}
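A usage sketch against the defaults above; the measurement and fields mirror what status.go queries, but the exact point written here is illustrative:

	influx := NewInfluxClient("", "")
	line := fmt.Sprintf("training_status,model=%s status=\"running\",pct=42.0", EscapeLp("gemma-3-4b"))
	if err := influx.WriteLp([]string{line}); err != nil {
		log.Fatalf("write: %v", err)
	}
	rows, _ := influx.QuerySQL("SELECT model, status, pct FROM training_status ORDER BY time DESC LIMIT 1")
	log.Printf("latest: %v", rows)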
152 pkg/ml/ollama.go Normal file
@@ -0,0 +1,152 @@
package ml

import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"time"
)

// OllamaBaseModelMap maps model tags to Ollama model names.
var OllamaBaseModelMap = map[string]string{
	"gemma-3-1b":  "gemma3:1b",
	"gemma-3-4b":  "gemma3:4b",
	"gemma-3-12b": "gemma3:12b",
	"gemma-3-27b": "gemma3:27b",
}

// HFBaseModelMap maps model tags to HuggingFace model IDs.
var HFBaseModelMap = map[string]string{
	"gemma-3-1b":  "google/gemma-3-1b-it",
	"gemma-3-4b":  "google/gemma-3-4b-it",
	"gemma-3-12b": "google/gemma-3-12b-it",
	"gemma-3-27b": "google/gemma-3-27b-it",
}

// ollamaUploadBlob uploads a local file to Ollama's blob store.
// Returns the sha256 digest string (e.g. "sha256:abc123...").
func ollamaUploadBlob(ollamaURL, filePath string) (string, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return "", fmt.Errorf("read %s: %w", filePath, err)
	}

	hash := sha256.Sum256(data)
	digest := "sha256:" + hex.EncodeToString(hash[:])

	// Skip the upload if Ollama already has a blob with this digest.
	headReq, _ := http.NewRequest(http.MethodHead, ollamaURL+"/api/blobs/"+digest, nil)
	client := &http.Client{Timeout: 5 * time.Minute}
	headResp, err := client.Do(headReq)
	if err == nil && headResp.StatusCode == http.StatusOK {
		headResp.Body.Close()
		return digest, nil
	}
	if headResp != nil {
		headResp.Body.Close()
	}

	req, err := http.NewRequest(http.MethodPost, ollamaURL+"/api/blobs/"+digest, bytes.NewReader(data))
	if err != nil {
		return "", fmt.Errorf("blob request: %w", err)
	}
	req.Header.Set("Content-Type", "application/octet-stream")

	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("blob upload: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("blob upload HTTP %d: %s", resp.StatusCode, string(body))
	}
	return digest, nil
}

// OllamaCreateModel creates a temporary Ollama model with a LoRA adapter.
// peftDir is a local directory containing adapter_model.safetensors and adapter_config.json.
func OllamaCreateModel(ollamaURL, modelName, baseModel, peftDir string) error {
	sfPath := peftDir + "/adapter_model.safetensors"
	cfgPath := peftDir + "/adapter_config.json"

	sfDigest, err := ollamaUploadBlob(ollamaURL, sfPath)
	if err != nil {
		return fmt.Errorf("upload adapter safetensors: %w", err)
	}

	cfgDigest, err := ollamaUploadBlob(ollamaURL, cfgPath)
	if err != nil {
		return fmt.Errorf("upload adapter config: %w", err)
	}

	reqBody, _ := json.Marshal(map[string]interface{}{
		"model": modelName,
		"from":  baseModel,
		"adapters": map[string]string{
			"adapter_model.safetensors": sfDigest,
			"adapter_config.json":       cfgDigest,
		},
	})

	client := &http.Client{Timeout: 10 * time.Minute}
	resp, err := client.Post(ollamaURL+"/api/create", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("ollama create: %w", err)
	}
	defer resp.Body.Close()

	// /api/create streams JSON status objects; watch for an error or "success".
	decoder := json.NewDecoder(resp.Body)
	for decoder.More() {
		var status struct {
			Status string `json:"status"`
			Error  string `json:"error"`
		}
		if err := decoder.Decode(&status); err != nil {
			if err == io.EOF {
				break
			}
			return fmt.Errorf("ollama create decode: %w", err)
		}
		if status.Error != "" {
			return fmt.Errorf("ollama create: %s", status.Error)
		}
		if status.Status == "success" {
			return nil
		}
	}

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("ollama create: HTTP %d", resp.StatusCode)
	}
	return nil
}

// OllamaDeleteModel removes a temporary Ollama model.
func OllamaDeleteModel(ollamaURL, modelName string) error {
	body, _ := json.Marshal(map[string]string{"model": modelName})

	req, err := http.NewRequest(http.MethodDelete, ollamaURL+"/api/delete", bytes.NewReader(body))
	if err != nil {
		return fmt.Errorf("ollama delete request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("ollama delete: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		respBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("ollama delete %d: %s", resp.StatusCode, string(respBody))
	}
	return nil
}
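A lifecycle sketch for a temporary scoring model; the model name and adapter directory are hypothetical, and a local Ollama daemon is assumed:

	const ollamaURL = "http://localhost:11434"
	base := OllamaBaseModelMap["gemma-3-4b"] // "gemma3:4b"
	if err := OllamaCreateModel(ollamaURL, "lem-score-tmp", base, "/tmp/peft-adapter"); err != nil {
		log.Fatalf("create: %v", err)
	}
	defer OllamaDeleteModel(ollamaURL, "lem-score-tmp")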
137 pkg/ml/parquet.go Normal file
@@ -0,0 +1,137 @@
package ml

import (
	"bufio"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/parquet-go/parquet-go"
)

// ParquetRow is the schema for exported Parquet files.
type ParquetRow struct {
	Prompt   string `parquet:"prompt"`
	Response string `parquet:"response"`
	System   string `parquet:"system"`
	Messages string `parquet:"messages"`
}

// ExportParquet reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl)
// from trainingDir and writes Parquet files with snappy compression to outputDir.
// Returns total rows exported.
func ExportParquet(trainingDir, outputDir string) (int, error) {
	if outputDir == "" {
		outputDir = filepath.Join(trainingDir, "parquet")
	}
	if err := os.MkdirAll(outputDir, 0755); err != nil {
		return 0, fmt.Errorf("create output dir: %w", err)
	}

	total := 0
	for _, split := range []string{"train", "valid", "test"} {
		jsonlPath := filepath.Join(trainingDir, split+".jsonl")
		if _, err := os.Stat(jsonlPath); os.IsNotExist(err) {
			continue
		}

		n, err := ExportSplitParquet(jsonlPath, outputDir, split)
		if err != nil {
			return total, fmt.Errorf("export %s: %w", split, err)
		}
		total += n
	}

	return total, nil
}

// ExportSplitParquet reads a chat JSONL file and writes a Parquet file for the
// given split name. Returns the number of rows written.
func ExportSplitParquet(jsonlPath, outputDir, split string) (int, error) {
	f, err := os.Open(jsonlPath)
	if err != nil {
		return 0, fmt.Errorf("open %s: %w", jsonlPath, err)
	}
	defer f.Close()

	var rows []ParquetRow
	scanner := bufio.NewScanner(f)
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	for scanner.Scan() {
		text := strings.TrimSpace(scanner.Text())
		if text == "" {
			continue
		}

		var data struct {
			Messages []ChatMessage `json:"messages"`
		}
		if err := json.Unmarshal([]byte(text), &data); err != nil {
			continue // skip malformed lines rather than aborting the export
		}

		var prompt, response, system string
		for _, m := range data.Messages {
			switch m.Role {
			case "user":
				if prompt == "" {
					prompt = m.Content
				}
			case "assistant":
				if response == "" {
					response = m.Content
				}
			case "system":
				if system == "" {
					system = m.Content
				}
			}
		}

		msgsJSON, _ := json.Marshal(data.Messages)
		rows = append(rows, ParquetRow{
			Prompt:   prompt,
			Response: response,
			System:   system,
			Messages: string(msgsJSON),
		})
	}

	if err := scanner.Err(); err != nil {
		return 0, fmt.Errorf("scan %s: %w", jsonlPath, err)
	}

	if len(rows) == 0 {
		return 0, nil
	}

	outPath := filepath.Join(outputDir, split+".parquet")

	out, err := os.Create(outPath)
	if err != nil {
		return 0, fmt.Errorf("create %s: %w", outPath, err)
	}

	writer := parquet.NewGenericWriter[ParquetRow](out,
		parquet.Compression(&parquet.Snappy),
	)

	if _, err := writer.Write(rows); err != nil {
		out.Close()
		return 0, fmt.Errorf("write parquet rows: %w", err)
	}

	if err := writer.Close(); err != nil {
		out.Close()
		return 0, fmt.Errorf("close parquet writer: %w", err)
	}

	if err := out.Close(); err != nil {
		return 0, fmt.Errorf("close file: %w", err)
	}

	return len(rows), nil
}
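A usage sketch; the training directory is hypothetical:

	n, err := ExportParquet("/data/lem/training", "")
	if err != nil {
		log.Fatalf("parquet export: %v", err)
	}
	log.Printf("wrote %d rows under /data/lem/training/parquet", n)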
212 pkg/ml/status.go Normal file
@@ -0,0 +1,212 @@
package ml

import (
	"fmt"
	"io"
	"sort"
)

// trainingRow holds deduplicated training status + loss for a single model.
type trainingRow struct {
	model      string
	status     string
	iteration  int
	totalIters int
	pct        float64
	loss       float64
	hasLoss    bool
}

// genRow holds deduplicated generation progress for a single worker.
type genRow struct {
	worker    string
	completed int
	target    int
	pct       float64
}

// PrintStatus queries InfluxDB for training and generation progress and writes
// a formatted summary to w.
func PrintStatus(influx *InfluxClient, w io.Writer) error {
	statusRows, err := influx.QuerySQL(
		"SELECT model, run_id, status, iteration, total_iters, pct FROM training_status ORDER BY time DESC LIMIT 10",
	)
	if err != nil {
		statusRows = nil
	}

	lossRows, err := influx.QuerySQL(
		"SELECT model, loss_type, loss, iteration, tokens_per_sec FROM training_loss WHERE loss_type = 'train' ORDER BY time DESC LIMIT 10",
	)
	if err != nil {
		lossRows = nil
	}

	goldenRows, err := influx.QuerySQL(
		"SELECT worker, completed, target, pct FROM golden_gen_progress ORDER BY time DESC LIMIT 5",
	)
	if err != nil {
		goldenRows = nil
	}

	expansionRows, err := influx.QuerySQL(
		"SELECT worker, completed, target, pct FROM expansion_progress ORDER BY time DESC LIMIT 5",
	)
	if err != nil {
		expansionRows = nil
	}

	training := dedupeTraining(statusRows, lossRows)
	golden := dedupeGeneration(goldenRows)
	expansion := dedupeGeneration(expansionRows)

	fmt.Fprintln(w, "Training:")
	if len(training) == 0 {
		fmt.Fprintln(w, " (no data)")
	} else {
		for _, tr := range training {
			progress := fmt.Sprintf("%d/%d", tr.iteration, tr.totalIters)
			pct := fmt.Sprintf("%.1f%%", tr.pct)
			if tr.hasLoss {
				fmt.Fprintf(w, " %-13s %-9s %9s %7s loss=%.3f\n",
					tr.model, tr.status, progress, pct, tr.loss)
			} else {
				fmt.Fprintf(w, " %-13s %-9s %9s %7s\n",
					tr.model, tr.status, progress, pct)
			}
		}
	}

	fmt.Fprintln(w)
	fmt.Fprintln(w, "Generation:")

	hasGenData := false

	if len(golden) > 0 {
		hasGenData = true
		for _, g := range golden {
			progress := fmt.Sprintf("%d/%d", g.completed, g.target)
			pct := fmt.Sprintf("%.1f%%", g.pct)
			fmt.Fprintf(w, " %-13s %11s %7s (%s)\n", "golden", progress, pct, g.worker)
		}
	}

	if len(expansion) > 0 {
		hasGenData = true
		for _, g := range expansion {
			progress := fmt.Sprintf("%d/%d", g.completed, g.target)
			pct := fmt.Sprintf("%.1f%%", g.pct)
			fmt.Fprintf(w, " %-13s %11s %7s (%s)\n", "expansion", progress, pct, g.worker)
		}
	}

	if !hasGenData {
		fmt.Fprintln(w, " (no data)")
	}

	return nil
}

// dedupeTraining merges training status and loss rows, keeping only the first
// (latest) row per model.
func dedupeTraining(statusRows, lossRows []map[string]interface{}) []trainingRow {
	lossMap := make(map[string]float64)
	lossSeenMap := make(map[string]bool)
	for _, row := range lossRows {
		model := strVal(row, "model")
		if model == "" || lossSeenMap[model] {
			continue
		}
		lossSeenMap[model] = true
		lossMap[model] = floatVal(row, "loss")
	}

	seen := make(map[string]bool)
	var rows []trainingRow
	for _, row := range statusRows {
		model := strVal(row, "model")
		if model == "" || seen[model] {
			continue
		}
		seen[model] = true

		tr := trainingRow{
			model:      model,
			status:     strVal(row, "status"),
			iteration:  intVal(row, "iteration"),
			totalIters: intVal(row, "total_iters"),
			pct:        floatVal(row, "pct"),
		}

		if loss, ok := lossMap[model]; ok {
			tr.loss = loss
			tr.hasLoss = true
		}

		rows = append(rows, tr)
	}

	sort.Slice(rows, func(i, j int) bool {
		return rows[i].model < rows[j].model
	})

	return rows
}

// dedupeGeneration deduplicates generation progress rows by worker.
func dedupeGeneration(rows []map[string]interface{}) []genRow {
	seen := make(map[string]bool)
	var result []genRow
	for _, row := range rows {
		worker := strVal(row, "worker")
		if worker == "" || seen[worker] {
			continue
		}
		seen[worker] = true

		result = append(result, genRow{
			worker:    worker,
			completed: intVal(row, "completed"),
			target:    intVal(row, "target"),
			pct:       floatVal(row, "pct"),
		})
	}

	sort.Slice(result, func(i, j int) bool {
		return result[i].worker < result[j].worker
	})

	return result
}

// strVal extracts a string value from a row map.
func strVal(row map[string]interface{}, key string) string {
	v, ok := row[key]
	if !ok {
		return ""
	}
	s, ok := v.(string)
	if !ok {
		return ""
	}
	return s
}

// floatVal extracts a float64 value from a row map.
func floatVal(row map[string]interface{}, key string) float64 {
	v, ok := row[key]
	if !ok {
		return 0
	}
	f, ok := v.(float64)
	if !ok {
		return 0
	}
	return f
}

// intVal extracts an integer value from a row map. InfluxDB JSON returns all
// numbers as float64, so this truncates to int.
func intVal(row map[string]interface{}, key string) int {
	return int(floatVal(row, key))
}
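A usage sketch, writing the summary to stdout with the client defaults:

	influx := NewInfluxClient("", "training")
	if err := PrintStatus(influx, os.Stdout); err != nil {
		log.Fatalf("status: %v", err)
	}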
403 pkg/ml/worker.go Normal file
@@ -0,0 +1,403 @@
package ml

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
)

// WorkerConfig holds the worker's runtime configuration.
type WorkerConfig struct {
	APIBase      string
	WorkerID     string
	Name         string
	APIKey       string
	GPUType      string
	VRAMGb       int
	Languages    []string
	Models       []string
	InferURL     string
	TaskType     string
	BatchSize    int
	PollInterval time.Duration
	OneShot      bool
	DryRun       bool
}

// APITask represents a task from the LEM API.
type APITask struct {
	ID         int    `json:"id"`
	TaskType   string `json:"task_type"`
	Status     string `json:"status"`
	Language   string `json:"language"`
	Domain     string `json:"domain"`
	ModelName  string `json:"model_name"`
	PromptID   string `json:"prompt_id"`
	PromptText string `json:"prompt_text"`
	Config     *struct {
		Temperature float64 `json:"temperature,omitempty"`
		MaxTokens   int     `json:"max_tokens,omitempty"`
	} `json:"config"`
	Priority int `json:"priority"`
}

// RunWorkerLoop is the main worker loop that polls for tasks and processes them.
func RunWorkerLoop(cfg *WorkerConfig) {
	log.Printf("LEM Worker starting")
	log.Printf(" ID: %s", cfg.WorkerID)
	log.Printf(" Name: %s", cfg.Name)
	log.Printf(" API: %s", cfg.APIBase)
	log.Printf(" Infer: %s", cfg.InferURL)
	log.Printf(" GPU: %s (%d GB)", cfg.GPUType, cfg.VRAMGb)
	log.Printf(" Langs: %v", cfg.Languages)
	log.Printf(" Models: %v", cfg.Models)
	log.Printf(" Batch: %d", cfg.BatchSize)
	log.Printf(" Dry-run: %v", cfg.DryRun)

	if err := workerRegister(cfg); err != nil {
		log.Fatalf("Registration failed: %v", err)
	}
	log.Println("Registered with LEM API")

	for {
		processed := workerPoll(cfg)

		if cfg.OneShot {
			log.Printf("One-shot mode: processed %d tasks, exiting", processed)
			return
		}

		if processed == 0 {
			log.Printf("No tasks available, sleeping %v", cfg.PollInterval)
			time.Sleep(cfg.PollInterval)
		}

		workerHeartbeat(cfg)
	}
}

func workerRegister(cfg *WorkerConfig) error {
	body := map[string]interface{}{
		"worker_id": cfg.WorkerID,
		"name":      cfg.Name,
		"version":   "0.1.0",
		"os":        runtime.GOOS,
		"arch":      runtime.GOARCH,
	}
	if cfg.GPUType != "" {
		body["gpu_type"] = cfg.GPUType
	}
	if cfg.VRAMGb > 0 {
		body["vram_gb"] = cfg.VRAMGb
	}
	if len(cfg.Languages) > 0 {
		body["languages"] = cfg.Languages
	}
	if len(cfg.Models) > 0 {
		body["supported_models"] = cfg.Models
	}

	_, err := apiPost(cfg, "/api/lem/workers/register", body)
	return err
}

func workerHeartbeat(cfg *WorkerConfig) {
	body := map[string]interface{}{
		"worker_id": cfg.WorkerID,
	}
	apiPost(cfg, "/api/lem/workers/heartbeat", body) // best-effort; errors are ignored
}

func workerPoll(cfg *WorkerConfig) int {
	url := fmt.Sprintf("/api/lem/tasks/next?worker_id=%s&limit=%d", cfg.WorkerID, cfg.BatchSize)
	if cfg.TaskType != "" {
		url += "&type=" + cfg.TaskType
	}

	resp, err := apiGet(cfg, url)
	if err != nil {
		log.Printf("Error fetching tasks: %v", err)
		return 0
	}

	var result struct {
		Tasks []APITask `json:"tasks"`
		Count int       `json:"count"`
	}
	if err := json.Unmarshal(resp, &result); err != nil {
		log.Printf("Error parsing tasks: %v", err)
		return 0
	}

	if result.Count == 0 {
		return 0
	}

	log.Printf("Got %d tasks", result.Count)
	processed := 0

	for _, task := range result.Tasks {
		if err := workerProcessTask(cfg, task); err != nil {
			log.Printf("Task %d failed: %v", task.ID, err)
			// Release the claim so another worker can pick the task up.
			apiDelete(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
				"worker_id": cfg.WorkerID,
			})
			continue
		}
		processed++
	}

	return processed
}

func workerProcessTask(cfg *WorkerConfig, task APITask) error {
	log.Printf("Processing task %d: %s [%s/%s] %d chars prompt",
		task.ID, task.TaskType, task.Language, task.Domain, len(task.PromptText))

	_, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{
		"worker_id": cfg.WorkerID,
	})
	if err != nil {
		return fmt.Errorf("claim: %w", err)
	}

	apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
		"worker_id": cfg.WorkerID,
		"status":    "in_progress",
	})

	if cfg.DryRun {
		log.Printf(" [DRY-RUN] Would generate response for: %.80s...", task.PromptText)
		return nil
	}

	start := time.Now()
	response, err := workerInfer(cfg, task)
	genTime := time.Since(start)

	if err != nil {
		apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{
			"worker_id": cfg.WorkerID,
			"status":    "abandoned",
		})
		return fmt.Errorf("inference: %w", err)
	}

	modelUsed := task.ModelName
	if modelUsed == "" {
		modelUsed = "default"
	}

	_, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]interface{}{
		"worker_id":     cfg.WorkerID,
		"response_text": response,
		"model_used":    modelUsed,
		"gen_time_ms":   int(genTime.Milliseconds()),
	})
	if err != nil {
		return fmt.Errorf("submit result: %w", err)
	}

	log.Printf(" Completed: %d chars in %v", len(response), genTime.Round(time.Millisecond))
	return nil
}

func workerInfer(cfg *WorkerConfig, task APITask) (string, error) {
	messages := []map[string]string{
		{"role": "user", "content": task.PromptText},
	}

	temp := 0.7
	maxTokens := 2048
	if task.Config != nil {
		if task.Config.Temperature > 0 {
			temp = task.Config.Temperature
		}
		if task.Config.MaxTokens > 0 {
			maxTokens = task.Config.MaxTokens
		}
	}

	reqBody := map[string]interface{}{
		"model":       task.ModelName,
		"messages":    messages,
		"temperature": temp,
		"max_tokens":  maxTokens,
	}

	data, err := json.Marshal(reqBody)
	if err != nil {
		return "", err
	}

	req, err := http.NewRequest(http.MethodPost, cfg.InferURL+"/v1/chat/completions", bytes.NewReader(data))
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 5 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("inference request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("inference HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
	}

	var chatResp struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(body, &chatResp); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}

	if len(chatResp.Choices) == 0 {
		return "", fmt.Errorf("no choices in response")
	}

	content := chatResp.Choices[0].Message.Content
	if len(content) < 10 {
		return "", fmt.Errorf("response too short: %d chars", len(content))
	}

	return content, nil
}

// HTTP helpers for the LEM API.

func apiGet(cfg *WorkerConfig, path string) ([]byte, error) {
	req, err := http.NewRequest(http.MethodGet, cfg.APIBase+path, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)

	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
	}

	return body, nil
}

func apiPost(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) {
	return apiRequest(cfg, http.MethodPost, path, data)
}

func apiPatch(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) {
	return apiRequest(cfg, http.MethodPatch, path, data)
}

func apiDelete(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) {
	return apiRequest(cfg, http.MethodDelete, path, data)
}

func apiRequest(cfg *WorkerConfig, method, path string, data map[string]interface{}) ([]byte, error) {
	jsonData, err := json.Marshal(data)
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequest(method, cfg.APIBase+path, bytes.NewReader(jsonData))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", "Bearer "+cfg.APIKey)
	req.Header.Set("Content-Type", "application/json")

	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode >= 400 {
		return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200))
	}

	return body, nil
}

// MachineID returns the machine ID from /etc/machine-id, falling back to the hostname.
func MachineID() string {
	if data, err := os.ReadFile("/etc/machine-id"); err == nil {
		id := string(bytes.TrimSpace(data))
		if len(id) > 0 {
			return id
		}
	}
	h, _ := os.Hostname()
	return h
}

// Hostname returns the system hostname.
func Hostname() string {
	h, _ := os.Hostname()
	return h
}

// ReadKeyFile reads the LEM API key from ~/.config/lem/api_key.
func ReadKeyFile() string {
	home, _ := os.UserHomeDir()
	path := filepath.Join(home, ".config", "lem", "api_key")
	data, err := os.ReadFile(path)
	if err != nil {
		return ""
	}
	return string(bytes.TrimSpace(data))
}

// SplitComma splits a comma-separated string into trimmed, non-empty parts.
func SplitComma(s string) []string {
	var result []string
	for _, part := range strings.Split(s, ",") {
		trimmed := strings.TrimSpace(part)
		if trimmed != "" {
			result = append(result, trimmed)
		}
	}
	return result
}

func truncStr(s string, n int) string {
	if len(s) <= n {
		return s
	}
	return s[:n] + "..."
}
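A minimal wiring sketch for a worker process; the API endpoint is a placeholder, and the inference URL assumes an OpenAI-compatible /v1/chat/completions server, while everything else uses the helpers defined above:

	cfg := &WorkerConfig{
		APIBase:      "https://lem.example.com", // placeholder endpoint
		WorkerID:     MachineID(),
		Name:         Hostname(),
		APIKey:       ReadKeyFile(),
		InferURL:     "http://localhost:11434",
		Languages:    SplitComma("en,de,fr"),
		BatchSize:    4,
		PollInterval: 30 * time.Second,
	}
	RunWorkerLoop(cfg)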