diff --git a/go.mod b/go.mod index ecb6d9c..4d23859 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,8 @@ require ( github.com/Snider/Enchantrix v0.0.2 // indirect github.com/TwiN/go-color v1.4.1 // indirect github.com/adrg/xdg v0.5.3 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/apache/arrow-go/v18 v18.1.0 // indirect github.com/aws/aws-sdk-go-v2 v1.41.1 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.17 // indirect @@ -71,9 +73,11 @@ require ( github.com/go-openapi/jsonpointer v0.22.4 // indirect github.com/go-openapi/swag/jsonname v0.25.4 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/goccy/go-json v0.10.5 // indirect github.com/godbus/dbus/v5 v5.2.2 // indirect github.com/gofrs/flock v0.12.1 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect + github.com/google/flatbuffers v25.1.24+incompatible // indirect github.com/google/go-github/v39 v39.2.0 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/jsonschema-go v0.4.2 // indirect @@ -85,11 +89,13 @@ require ( github.com/jchv/go-winloader v0.0.0-20250406163304-c1995be93bd1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/kevinburke/ssh_config v1.4.0 // indirect + github.com/klauspost/compress v1.18.3 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leaanthony/go-ansi-parser v1.6.1 // indirect github.com/leaanthony/u v1.1.1 // indirect github.com/lmittmann/tint v1.1.2 // indirect github.com/mailru/easyjson v0.9.1 // indirect + github.com/marcboeker/go-duckdb v1.8.5 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect @@ -97,8 +103,12 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 // indirect github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 // indirect + github.com/parquet-go/bitpack v1.0.0 // indirect + github.com/parquet-go/jsonlite v1.0.0 // indirect + github.com/parquet-go/parquet-go v0.27.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/pjbgf/sha1cd v0.5.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -119,9 +129,9 @@ require ( github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/twpayne/go-geom v1.6.1 // indirect github.com/ugorji/go/codec v1.3.0 // indirect github.com/ulikunitz/xz v0.5.15 // indirect - github.com/unpoller/unifi/v5 v5.17.0 // indirect github.com/wI2L/jsondiff v0.7.0 // indirect github.com/wailsapp/go-webview2 v1.0.23 // indirect github.com/wailsapp/wails/v3 v3.0.0-alpha.64 // indirect @@ -130,10 +140,14 @@ require ( github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/yargevad/filepathx v1.0.0 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect + github.com/zeebo/xxh3 v1.1.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect + golang.org/x/telemetry 
v0.0.0-20260109210033-bd525da824e2 // indirect + golang.org/x/tools v0.41.0 // indirect + golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect google.golang.org/grpc v1.76.0 // indirect google.golang.org/protobuf v1.36.10 // indirect diff --git a/go.sum b/go.sum index 8b441bb..0ed29f0 100644 --- a/go.sum +++ b/go.sum @@ -18,12 +18,17 @@ github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBi github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= github.com/Snider/Borg v0.2.0 h1:iCyDhY4WTXi39+FexRwXbn2YpZ2U9FUXVXDZk9xRCXQ= github.com/Snider/Borg v0.2.0/go.mod h1:TqlKnfRo9okioHbgrZPfWjQsztBV0Nfskz4Om1/vdMY= +github.com/Snider/Enchantrix v0.0.2/go.mod h1:CtFcLAvnDT1KcuF1JBb/DJj0KplY8jHryO06KzQ1hsQ= github.com/TwiN/go-color v1.4.1 h1:mqG0P/KBgHKVqmtL5ye7K0/Gr4l6hTksPgTgMk3mUzc= github.com/TwiN/go-color v1.4.1/go.mod h1:WcPf/jtiW95WBIsEeY1Lc/b8aaWoiqQpu5cf8WFxu+s= github.com/adrg/xdg v0.5.3 h1:xRnxJXne7+oWDatRhR1JLnvuccuIeCoBu2rtuLqQB78= github.com/adrg/xdg v0.5.3/go.mod h1:nlTsY+NNiCBGCK2tpm09vRqfVzrc2fLmXGpBLF0zlTQ= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/apache/arrow-go/v18 v18.1.0 h1:agLwJUiVuwXZdwPYVrlITfx7bndULJ/dggbnLFgDp/Y= +github.com/apache/arrow-go/v18 v18.1.0/go.mod h1:tigU/sIgKNXaesf5d7Y95jBBKS5KsxTqYBKXFsvKzo0= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/aws/aws-sdk-go-v2 v1.41.1 h1:ABlyEARCDLN034NhxlRUSZr4l71mh+T5KAeGh6cerhU= @@ -115,6 +120,8 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= +github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ= github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= @@ -127,6 +134,8 @@ github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/flatbuffers v25.1.24+incompatible h1:4wPqL3K7GzBd1CwyhSd3usxLKOaJN/AC6puCca6Jm7o= +github.com/google/flatbuffers v25.1.24+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -153,6 +162,8 @@ github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8Hm github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ= github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M= +github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= +github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kluctl/go-embed-python v0.0.0-3.13.1-20241219-1 h1:x1cSEj4Ug5mpuZgUHLvUmlc5r//KHFn6iYiRSrRcVy4= @@ -179,6 +190,8 @@ github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w= github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/mailru/easyjson v0.9.1 h1:LbtsOm5WAswyWbvTEOqhypdPeZzHavpZx96/n553mR8= github.com/mailru/easyjson v0.9.1/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/marcboeker/go-duckdb v1.8.5 h1:tkYp+TANippy0DaIOP5OEfBEwbUINqiFqgwMQ44jME0= +github.com/marcboeker/go-duckdb v1.8.5/go.mod h1:6mK7+WQE4P4u5AFLvVBmhFxY5fvhymFptghgJX6B+/8= github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= github.com/matryer/is v1.4.1 h1:55ehd8zaGABKLXQUe2awZ99BD/PTc2ls+KV/dXphgEQ= github.com/matryer/is v1.4.1/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= @@ -206,10 +219,18 @@ github.com/ollama/ollama v0.15.4 h1:y841GH5lsi5j5BTFyX/E+UOC3Yiw+JBfdjBVRGw+I0M= github.com/ollama/ollama v0.15.4/go.mod h1:4Yn3jw2hZ4VqyJ1XciYawDRE8bzv4RT3JiVZR1kCfwE= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +github.com/parquet-go/bitpack v1.0.0 h1:AUqzlKzPPXf2bCdjfj4sTeacrUwsT7NlcYDMUQxPcQA= +github.com/parquet-go/bitpack v1.0.0/go.mod h1:XnVk9TH+O40eOOmvpAVZ7K2ocQFrQwysLMnc6M/8lgs= +github.com/parquet-go/jsonlite v1.0.0 h1:87QNdi56wOfsE5bdgas0vRzHPxfJgzrXGml1zZdd7VU= +github.com/parquet-go/jsonlite v1.0.0/go.mod h1:nDjpkpL4EOtqs6NQugUsi0Rleq9sW/OtC1NnZEnxzF0= +github.com/parquet-go/parquet-go v0.27.0 h1:vHWK2xaHbj+v1DYps03yDRpEsdtOeKbhiXUaixoPb3g= +github.com/parquet-go/parquet-go v0.27.0/go.mod h1:navtkAYr2LGoJVp141oXPlO/sxLvaOe3la2JEoD8+rg= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= +github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= +github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0= github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -274,6 +295,8 @@ github.com/tidwall/pretty 
v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twpayne/go-geom v1.6.1 h1:iLE+Opv0Ihm/ABIcvQFGIiFBXd76oBIar9drAwHFhR4= +github.com/twpayne/go-geom v1.6.1/go.mod h1:Kr+Nly6BswFsKM5sd31YaoWS5PeDDH2NftJTK7Gd028= github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= @@ -292,10 +315,13 @@ github.com/woodsbury/decimal128 v1.4.0 h1:xJATj7lLu4f2oObouMt2tgGiElE5gO6mSWUjQs github.com/woodsbury/decimal128 v1.4.0/go.mod h1:BP46FUrVjVhdTbKT+XuQh2xfQaGki9LMIRJSFuh6THU= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5FYc= github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= +github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= @@ -314,6 +340,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= @@ -323,6 +350,7 @@ golang.org/x/exp v0.0.0-20260112195511-716be5621a96/go.mod h1:nzimsREAkjBCIEFtHi golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.49.0 
h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= @@ -346,11 +374,14 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2 h1:O1cMQHRfwNpDfDJerqRoE2oD+AFlyid87D40L/OkkJo= +golang.org/x/telemetry v0.0.0-20260109210033-bd525da824e2/go.mod h1:b7fPSJ0pKZ3ccUh8gnTONJxhn3c/PS6tyzQvyqw4iA8= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= @@ -359,6 +390,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= diff --git a/pkg/ml/agent.go b/pkg/ml/agent.go new file mode 100644 index 0000000..8e13832 --- /dev/null +++ b/pkg/ml/agent.go @@ -0,0 +1,1070 @@ +package ml + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "regexp" + "sort" + "strings" + "time" +) + +// AgentConfig holds scoring agent configuration. +type AgentConfig struct { + M3Host string + M3User string + M3SSHKey string + M3AdapterBase string + InfluxURL string + InfluxDB string + DBPath string + APIURL string + JudgeURL string + JudgeModel string + Model string + BaseModel string + PollInterval int + WorkDir string + Filter string + Force bool + OneShot bool + DryRun bool +} + +// Checkpoint represents a discovered adapter checkpoint on M3. +type Checkpoint struct { + RemoteDir string + Filename string + Dirname string + Iteration int + ModelTag string + Label string + RunID string +} + +// ProbeResult holds the result of running all probes against a checkpoint. +type ProbeResult struct { + Accuracy float64 `json:"accuracy"` + Correct int `json:"correct"` + Total int `json:"total"` + ByCategory map[string]CategoryResult `json:"by_category"` + Probes map[string]SingleProbeResult `json:"probes"` +} + +// CategoryResult holds pass/fail counts for a probe category. 
+type CategoryResult struct { + Correct int `json:"correct"` + Total int `json:"total"` +} + +// SingleProbeResult holds the result of a single probe. +type SingleProbeResult struct { + Passed bool `json:"passed"` + Response string `json:"response"` +} + +// bufferEntry is a JSONL-buffered result for when InfluxDB is down. +type bufferEntry struct { + Checkpoint Checkpoint `json:"checkpoint"` + Results ProbeResult `json:"results"` + Timestamp string `json:"timestamp"` +} + +// BaseModelMap maps model tags to their HuggingFace/local model paths. +var BaseModelMap = map[string]string{ + "gemma-3-1b": "mlx-community/gemma-3-1b-it-4bit", + "gemma-3-4b": "mlx-community/gemma-3-4b-it-4bit", + "gemma-3-12b": "mlx-community/gemma-3-12b-it-4bit", + "gemma-3-27b": "mlx-community/gemma-3-27b-it-qat-4bit", + "gpt-oss-20b": "/Volumes/Data/lem/models/gpt-oss-20b-mlx", +} + +// ModelFamilies identifies known model families from adapter directory names. +var ModelFamilies = []struct { + DirPrefix string + Tag string + Short string +}{ + {"deepseek-r1-7b", "deepseek-r1-7b", "R1"}, + {"27b-", "gemma-3-27b", "G27"}, + {"27b", "gemma-3-27b", "G27"}, + {"15k/gemma-3-27b", "gemma-3-27b", "G27"}, + {"15k/gemma-3-12b", "gemma-3-12b", "G12"}, + {"15k/gemma-3-1b", "gemma-3-1b", "G1"}, + {"12b", "gemma-3-12b", "G12"}, + {"1b-", "gemma-3-1b", "G1"}, + {"1b", "gemma-3-1b", "G1"}, + {"4b", "gemma-3-4b", "G4"}, + {"vi-12b", "gemma-3-12b", "Vi12"}, + {"vi", "gemma-3-1b", "Vi1"}, + {"gpt-oss", "gpt-oss-20b", "GPT"}, + {"lem-gpt-oss", "gpt-oss-20b", "LGPT"}, + {"bench-1b", "gemma-3-1b", "B1"}, + {"book", "gemma-3-27b", "Book"}, + {"cross", "gemma-3-12b", "Cross"}, +} + +// AdapterMeta maps an adapter directory name to (model_tag, label_prefix, run_id_stem). +func AdapterMeta(dirname string) (string, string, string) { + name := strings.TrimPrefix(dirname, "adapters-") + + for _, fam := range ModelFamilies { + if strings.HasPrefix(name, fam.DirPrefix) { + variant := strings.TrimPrefix(name, fam.DirPrefix) + variant = strings.TrimLeft(variant, "-") + if variant == "" { + variant = "base" + } + short := fam.Short + "-" + variant + if variant == "base" { + short = fam.Short + } + stem := strings.ReplaceAll(name, "/", "-") + return fam.Tag, short, stem + } + } + + short := name + if len(short) > 10 { + short = short[:10] + } + return name, short, name +} + +// RunAgentLoop is the main scoring agent loop. 
+func RunAgentLoop(cfg *AgentConfig) { + log.Println(strings.Repeat("=", 60)) + log.Println("ROCm Scoring Agent — Go Edition") + log.Printf("M3: %s@%s", cfg.M3User, cfg.M3Host) + log.Printf("Inference API: %s", cfg.APIURL) + log.Printf("Judge API: %s (%s)", cfg.JudgeURL, cfg.JudgeModel) + log.Printf("InfluxDB: %s/%s", cfg.InfluxURL, cfg.InfluxDB) + if cfg.DBPath != "" { + log.Printf("DuckDB: %s", cfg.DBPath) + } + log.Printf("Poll interval: %ds", cfg.PollInterval) + log.Println(strings.Repeat("=", 60)) + + influx := NewInfluxClient(cfg.InfluxURL, cfg.InfluxDB) + os.MkdirAll(cfg.WorkDir, 0755) + + for { + ReplayInfluxBuffer(cfg.WorkDir, influx) + + log.Println("Discovering checkpoints on M3...") + checkpoints, err := DiscoverCheckpoints(cfg) + if err != nil { + log.Printf("Discovery failed: %v", err) + if cfg.OneShot { + return + } + time.Sleep(time.Duration(cfg.PollInterval) * time.Second) + continue + } + log.Printf("Found %d total checkpoints", len(checkpoints)) + + var unscored []Checkpoint + if cfg.Force { + unscored = checkpoints + log.Printf("Force mode: scoring all %d checkpoints", len(unscored)) + } else { + scored, err := GetScoredLabels(influx) + if err != nil { + log.Printf("InfluxDB query failed: %v", err) + } + log.Printf("Already scored: %d (run_id, label) pairs", len(scored)) + unscored = FindUnscored(checkpoints, scored) + log.Printf("Unscored: %d checkpoints", len(unscored)) + } + + if len(unscored) == 0 { + log.Printf("Nothing to score. Sleeping %ds...", cfg.PollInterval) + if cfg.OneShot { + return + } + time.Sleep(time.Duration(cfg.PollInterval) * time.Second) + continue + } + + targets := unscored + if !cfg.Force { + targets = unscored[:1] + } + + for i, target := range targets { + log.Printf("Grabbed: %s (%s) [%d/%d]", target.Label, target.Dirname, i+1, len(targets)) + + if cfg.DryRun { + log.Printf("[DRY RUN] Would process: %s/%s", target.Dirname, target.Filename) + continue + } + + if err := ProcessOne(cfg, influx, target); err != nil { + log.Printf("Error processing %s: %v", target.Label, err) + } + time.Sleep(5 * time.Second) + } + + if cfg.DryRun || cfg.OneShot { + return + } + } +} + +// DiscoverCheckpoints lists all adapter directories and checkpoint files on M3 via SSH. 
+func DiscoverCheckpoints(cfg *AgentConfig) ([]Checkpoint, error) {
+	pattern := "adapters-*"
+	if cfg.Filter != "" {
+		pattern = "adapters-" + cfg.Filter + "*"
+	}
+	out, err := SSHCommand(cfg, fmt.Sprintf("ls -d %s/%s 2>/dev/null", cfg.M3AdapterBase, pattern))
+	if err != nil {
+		return nil, fmt.Errorf("list adapter dirs: %w", err)
+	}
+
+	var checkpoints []Checkpoint
+	iterRe := regexp.MustCompile(`(\d+)`)
+
+	var adapterDirs []string
+	for _, dirpath := range strings.Split(strings.TrimSpace(out), "\n") {
+		if dirpath == "" {
+			continue
+		}
+		subOut, subErr := SSHCommand(cfg, fmt.Sprintf("ls -d %s/gemma-3-* 2>/dev/null", dirpath))
+		if subErr == nil && strings.TrimSpace(subOut) != "" {
+			for _, sub := range strings.Split(strings.TrimSpace(subOut), "\n") {
+				if sub != "" {
+					adapterDirs = append(adapterDirs, sub)
+				}
+			}
+		} else {
+			adapterDirs = append(adapterDirs, dirpath)
+		}
+	}
+
+	for _, dirpath := range adapterDirs {
+		dirname := strings.TrimPrefix(dirpath, cfg.M3AdapterBase+"/")
+
+		filesOut, err := SSHCommand(cfg, fmt.Sprintf("ls %s/*_adapters.safetensors 2>/dev/null", dirpath))
+		if err != nil {
+			continue
+		}
+
+		for _, fp := range strings.Split(strings.TrimSpace(filesOut), "\n") {
+			if fp == "" {
+				continue
+			}
+			filename := fileBase(fp)
+
+			match := iterRe.FindStringSubmatch(filename)
+			if len(match) < 2 {
+				continue
+			}
+			iteration := 0
+			fmt.Sscanf(match[1], "%d", &iteration)
+
+			modelTag, labelPrefix, stem := AdapterMeta(dirname)
+			label := fmt.Sprintf("%s @%s", labelPrefix, match[1])
+			runID := fmt.Sprintf("%s-capability-auto", stem)
+
+			checkpoints = append(checkpoints, Checkpoint{
+				RemoteDir: dirpath,
+				Filename:  filename,
+				Dirname:   dirname,
+				Iteration: iteration,
+				ModelTag:  modelTag,
+				Label:     label,
+				RunID:     runID,
+			})
+		}
+	}
+
+	return checkpoints, nil
+}
+
+// GetScoredLabels returns all (run_id, label) pairs already scored in InfluxDB.
+func GetScoredLabels(influx *InfluxClient) (map[[2]string]bool, error) {
+	rows, err := influx.QuerySQL("SELECT DISTINCT run_id, label FROM capability_score")
+	if err != nil {
+		return nil, err
+	}
+
+	scored := make(map[[2]string]bool)
+	for _, row := range rows {
+		runID, _ := row["run_id"].(string)
+		label, _ := row["label"].(string)
+		if runID != "" && label != "" {
+			scored[[2]string{runID, label}] = true
+		}
+	}
+	return scored, nil
+}
+
+// FindUnscored filters checkpoints to only unscored ones, sorted by (dirname, iteration).
+func FindUnscored(checkpoints []Checkpoint, scored map[[2]string]bool) []Checkpoint {
+	var unscored []Checkpoint
+	for _, c := range checkpoints {
+		if !scored[[2]string{c.RunID, c.Label}] {
+			unscored = append(unscored, c)
+		}
+	}
+	sort.Slice(unscored, func(i, j int) bool {
+		if unscored[i].Dirname != unscored[j].Dirname {
+			return unscored[i].Dirname < unscored[j].Dirname
+		}
+		return unscored[i].Iteration < unscored[j].Iteration
+	})
+	return unscored
+}
+
+// isMLXNative reports whether this model family has a known Ollama base image,
+// so the checkpoint can be scored by converting its adapter to PEFT and loading
+// it into a temporary Ollama model (see processMLXNative), rather than the
+// generic conversion path that targets a pre-existing server at cfg.APIURL.
+func isMLXNative(modelTag string) bool {
+	return strings.HasPrefix(modelTag, "gemma-3-") || strings.HasPrefix(modelTag, "gpt-oss")
+}
+
+// ProcessOne fetches, converts, scores, and pushes one checkpoint.
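+// Model families with a known Ollama base image (see isMLXNative) are scored
+// through a temporary Ollama model at cfg.JudgeURL; all other families go
+// through processWithConversion against the server at cfg.APIURL.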
+func ProcessOne(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) error { + log.Println(strings.Repeat("=", 60)) + log.Printf("Processing: %s / %s [%s]", cp.Dirname, cp.Filename, cp.ModelTag) + log.Println(strings.Repeat("=", 60)) + + if isMLXNative(cp.ModelTag) { + return processMLXNative(cfg, influx, cp) + } + return processWithConversion(cfg, influx, cp) +} + +// processMLXNative scores a checkpoint using Ollama on M3. +func processMLXNative(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) error { + ollamaBase, ok := OllamaBaseModelMap[cp.ModelTag] + if !ok { + return fmt.Errorf("unknown Ollama model for tag %s", cp.ModelTag) + } + hfBase := HFBaseModelMap[cp.ModelTag] + if hfBase == "" { + hfBase = ollamaBase + } + + tempModel := fmt.Sprintf("lem-%s-%d", cp.ModelTag, cp.Iteration) + localAdapterDir := filepath.Join(cfg.WorkDir, "adapter-"+cp.Dirname) + peftDir := filepath.Join(cfg.WorkDir, "peft-"+cp.Dirname) + + os.MkdirAll(localAdapterDir, 0755) + + defer func() { + os.RemoveAll(localAdapterDir) + os.RemoveAll(peftDir) + OllamaDeleteModel(cfg.JudgeURL, tempModel) + }() + + log.Printf("Fetching adapter from M3 (%s)...", cp.Filename) + remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename) + remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir) + localSF := filepath.Join(localAdapterDir, cp.Filename) + localCfg := filepath.Join(localAdapterDir, "adapter_config.json") + + if err := SCPFrom(cfg, remoteSF, localSF); err != nil { + return fmt.Errorf("scp safetensors: %w", err) + } + if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil { + return fmt.Errorf("scp config: %w", err) + } + + log.Println("Converting MLX → PEFT format...") + if err := ConvertMLXtoPEFT(localSF, localCfg, peftDir, hfBase); err != nil { + return fmt.Errorf("convert adapter: %w", err) + } + + log.Printf("Creating Ollama model %s (base: %s)...", tempModel, ollamaBase) + if err := OllamaCreateModel(cfg.JudgeURL, tempModel, ollamaBase, peftDir); err != nil { + return fmt.Errorf("ollama create: %w", err) + } + log.Printf("Ollama model %s ready", tempModel) + + ctx := context.Background() + probeBackend := NewHTTPBackend(cfg.JudgeURL, tempModel) + + const baseTS int64 = 1739577600 + results, fullResponses := RunCapabilityProbesFull(ctx, probeBackend, func(probeID, category string, passed bool, response string, correct, total int) { + passedInt := 0 + if passed { + passedInt = 1 + } + ts := (baseTS + int64(cp.Iteration)*1000 + int64(total+100)) * 1_000_000_000 + line := fmt.Sprintf( + "probe_score,model=%s,run_id=%s,label=%s,probe_id=%s passed=%di,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), EscapeLp(probeID), + passedInt, cp.Iteration, ts, + ) + if err := influx.WriteLp([]string{line}); err != nil { + log.Printf(" [%s] InfluxDB stream failed: %v", probeID, err) + } + }) + + log.Printf("Capability: %s -- %.1f%% (%d/%d)", + cp.Label, results.Accuracy, results.Correct, results.Total) + + if err := PushCapabilitySummary(influx, cp, results); err != nil { + log.Printf("InfluxDB summary push failed, buffering: %v", err) + BufferInfluxResult(cfg.WorkDir, cp, results) + } + PushCapabilityResultsDB(cfg.DBPath, cp, results) + + judgeBackend := NewHTTPBackend(cfg.JudgeURL, cfg.JudgeModel) + judge := NewJudge(judgeBackend) + + log.Println("Judging 23 capability responses (0-10 quality scoring)...") + ScoreCapabilityAndPush(ctx, judge, influx, cp, fullResponses) + + log.Println("Running 6 content probes (0-10 judge scoring)...") + contentResponses := 
RunContentProbesViaAPI(ctx, probeBackend) + if len(contentResponses) > 0 { + contentRunID := strings.Replace(cp.RunID, "-capability-", "-content-", 1) + ScoreContentAndPush(ctx, judge, influx, cp, contentRunID, contentResponses) + } + + return nil +} + +// processWithConversion fetches adapter locally, converts MLX→PEFT, and scores. +func processWithConversion(cfg *AgentConfig, influx *InfluxClient, cp Checkpoint) error { + localAdapterDir := filepath.Join(cfg.WorkDir, cp.Dirname) + os.MkdirAll(localAdapterDir, 0755) + + localSF := filepath.Join(localAdapterDir, cp.Filename) + localCfg := filepath.Join(localAdapterDir, "adapter_config.json") + + defer func() { + os.Remove(localSF) + os.Remove(localCfg) + peftDir := filepath.Join(cfg.WorkDir, fmt.Sprintf("peft_%07d", cp.Iteration)) + os.RemoveAll(peftDir) + }() + + log.Println("Fetching adapter from M3...") + remoteSF := fmt.Sprintf("%s/%s", cp.RemoteDir, cp.Filename) + remoteCfg := fmt.Sprintf("%s/adapter_config.json", cp.RemoteDir) + + if err := SCPFrom(cfg, remoteSF, localSF); err != nil { + return fmt.Errorf("scp safetensors: %w", err) + } + if err := SCPFrom(cfg, remoteCfg, localCfg); err != nil { + return fmt.Errorf("scp config: %w", err) + } + + log.Println("Converting MLX to PEFT format...") + peftDir := filepath.Join(cfg.WorkDir, fmt.Sprintf("peft_%07d", cp.Iteration)) + if err := ConvertMLXtoPEFT(localSF, localCfg, peftDir, cfg.BaseModel); err != nil { + return fmt.Errorf("convert adapter: %w", err) + } + + log.Println("Running 23 capability probes...") + ctx := context.Background() + modelName := cfg.Model + if modelName == "" { + modelName = cp.ModelTag + } + backend := NewHTTPBackend(cfg.APIURL, modelName) + + results := RunCapabilityProbes(ctx, backend) + + log.Printf("Result: %s -- %.1f%% (%d/%d)", + cp.Label, results.Accuracy, results.Correct, results.Total) + + if err := PushCapabilityResults(influx, cp, results); err != nil { + log.Printf("InfluxDB push failed, buffering: %v", err) + BufferInfluxResult(cfg.WorkDir, cp, results) + } + PushCapabilityResultsDB(cfg.DBPath, cp, results) + + return nil +} + +// ProbeCallback is called after each probe completes for real-time streaming. +type ProbeCallback func(probeID, category string, passed bool, response string, correct, total int) + +// RunCapabilityProbes runs all 23 probes against a backend. 
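+// Each probe is generated at temperature 0.1 with a 500-token cap, checked
+// after think-block stripping, and stored truncated to 300 characters.
+// A sketch of direct use (backend endpoint and model name are illustrative):
+//
+//	backend := NewHTTPBackend("http://127.0.0.1:11434", "lem-gemma-3-12b-1200")
+//	res := RunCapabilityProbes(context.Background(), backend)
+//	log.Printf("%.1f%% (%d/%d)", res.Accuracy, res.Correct, res.Total)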
+func RunCapabilityProbes(ctx context.Context, backend Backend) ProbeResult { + results := ProbeResult{ + ByCategory: make(map[string]CategoryResult), + Probes: make(map[string]SingleProbeResult), + } + + correct := 0 + total := 0 + + for _, probe := range CapabilityProbes { + response, err := backend.Generate(ctx, probe.Prompt, GenOpts{Temperature: 0.1, MaxTokens: 500}) + if err != nil { + log.Printf(" [%s] ERROR: %v", probe.ID, err) + results.Probes[probe.ID] = SingleProbeResult{Passed: false, Response: err.Error()} + total++ + cat := results.ByCategory[probe.Category] + cat.Total++ + results.ByCategory[probe.Category] = cat + continue + } + + clean := StripThinkBlocks(response) + passed := probe.Check(clean) + total++ + if passed { + correct++ + } + + cat := results.ByCategory[probe.Category] + cat.Total++ + if passed { + cat.Correct++ + } + results.ByCategory[probe.Category] = cat + + stored := clean + if len(stored) > 300 { + stored = stored[:300] + } + results.Probes[probe.ID] = SingleProbeResult{Passed: passed, Response: stored} + + status := "FAIL" + if passed { + status = "PASS" + } + log.Printf(" [%s] %s (expected: %s)", probe.ID, status, probe.Answer) + } + + if total > 0 { + results.Accuracy = float64(correct) / float64(total) * 100 + } + results.Correct = correct + results.Total = total + + return results +} + +// CapResponseEntry holds a capability probe response with its metadata for judge scoring. +type CapResponseEntry struct { + ProbeID string + Category string + Prompt string + Answer string + Response string + Passed bool +} + +// RunCapabilityProbesFull runs all probes via a backend and returns both +// aggregate results and full responses for judge scoring. +func RunCapabilityProbesFull(ctx context.Context, backend Backend, onProbe ProbeCallback) (ProbeResult, []CapResponseEntry) { + results := ProbeResult{ + ByCategory: make(map[string]CategoryResult), + Probes: make(map[string]SingleProbeResult), + } + var fullResponses []CapResponseEntry + + correct := 0 + total := 0 + + for _, probe := range CapabilityProbes { + response, err := backend.Generate(ctx, probe.Prompt, GenOpts{Temperature: 0.1, MaxTokens: 500}) + if err != nil { + log.Printf(" [%s] ERROR: %v", probe.ID, err) + response = fmt.Sprintf("ERROR: %v", err) + } + + clean := StripThinkBlocks(response) + passed := probe.Check(clean) + total++ + if passed { + correct++ + } + + cat := results.ByCategory[probe.Category] + cat.Total++ + if passed { + cat.Correct++ + } + results.ByCategory[probe.Category] = cat + + stored := clean + if len(stored) > 300 { + stored = stored[:300] + } + results.Probes[probe.ID] = SingleProbeResult{Passed: passed, Response: stored} + + fullResponses = append(fullResponses, CapResponseEntry{ + ProbeID: probe.ID, + Category: probe.Category, + Prompt: probe.Prompt, + Answer: probe.Answer, + Response: clean, + Passed: passed, + }) + + status := "FAIL" + if passed { + status = "PASS" + } + log.Printf(" [%s] %s (expected: %s)", probe.ID, status, probe.Answer) + + if onProbe != nil { + onProbe(probe.ID, probe.Category, passed, stored, correct, total) + } + } + + if total > 0 { + results.Accuracy = float64(correct) / float64(total) * 100 + } + results.Correct = correct + results.Total = total + + return results, fullResponses +} + +// ContentResponse holds a content probe response for later judging. +type ContentResponse struct { + Probe ContentProbe + Response string +} + +// RunContentProbesViaAPI runs content probes via a backend. 
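+// Generation uses temperature 0.7 with a 1000-token cap; probes whose
+// generation fails are skipped, so the returned slice may be shorter than
+// ContentProbes.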
+func RunContentProbesViaAPI(ctx context.Context, backend Backend) []ContentResponse { + var responses []ContentResponse + + for _, probe := range ContentProbes { + reply, err := backend.Generate(ctx, probe.Prompt, GenOpts{Temperature: 0.7, MaxTokens: 1000}) + if err != nil { + log.Printf(" [content:%s] ERROR: %v", probe.ID, err) + continue + } + + reply = StripThinkBlocks(reply) + log.Printf(" [content:%s] got %d chars", probe.ID, len(reply)) + + responses = append(responses, ContentResponse{ + Probe: probe, + Response: reply, + }) + } + + return responses +} + +// RunContentProbesViaRunner sends content probes through an SSH probe runner. +func RunContentProbesViaRunner(stdin io.WriteCloser, scanner *bufio.Scanner) []ContentResponse { + var responses []ContentResponse + + for _, probe := range ContentProbes { + req := map[string]interface{}{ + "prompt": probe.Prompt, + "max_tokens": 1000, + "temp": 0.7, + } + reqJSON, _ := json.Marshal(req) + fmt.Fprintf(stdin, "%s\n", reqJSON) + + var response string + if scanner.Scan() { + var resp probeRunnerResponse + if err := json.Unmarshal(scanner.Bytes(), &resp); err != nil { + log.Printf(" [content:%s] parse error: %v", probe.ID, err) + continue + } else if resp.Error != "" { + log.Printf(" [content:%s] ERROR: %s", probe.ID, resp.Error) + continue + } else { + response = resp.Response + } + } else { + log.Printf(" [content:%s] no response from runner", probe.ID) + continue + } + + response = StripThinkBlocks(response) + log.Printf(" [content:%s] got %d chars", probe.ID, len(response)) + + responses = append(responses, ContentResponse{ + Probe: probe, + Response: response, + }) + } + + return responses +} + +// probeRunnerResponse is the JSON response from the Python probe runner. +type probeRunnerResponse struct { + Response string `json:"response"` + Error string `json:"error"` + Elapsed float64 `json:"elapsed"` +} + +// ScoreCapabilityAndPush judges each capability response via LLM and pushes scores to InfluxDB. +func ScoreCapabilityAndPush(ctx context.Context, judge *Judge, influx *InfluxClient, cp Checkpoint, responses []CapResponseEntry) { + const baseTS int64 = 1739577600 + var lines []string + + for i, cr := range responses { + scores, err := judge.ScoreCapability(ctx, cr.Prompt, cr.Answer, cr.Response) + if err != nil { + log.Printf(" [%s] judge error: %v", cr.ProbeID, err) + continue + } + + avg := (scores.Reasoning + scores.Correctness + scores.Clarity) / 3.0 + log.Printf(" [%s] judge: R=%.1f C=%.1f Cl=%.1f avg=%.2f", + cr.ProbeID, scores.Reasoning, scores.Correctness, scores.Clarity, avg) + + ts := (baseTS + int64(cp.Iteration)*1000 + int64(i)) * 1_000_000_000 + line := fmt.Sprintf( + "capability_judge,model=%s,run_id=%s,label=%s,probe_id=%s,category=%s reasoning=%.2f,correctness=%.2f,clarity=%.2f,avg=%.2f,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), + EscapeLp(cr.ProbeID), EscapeLp(cr.Category), + scores.Reasoning, scores.Correctness, scores.Clarity, avg, cp.Iteration, ts, + ) + lines = append(lines, line) + } + + if len(lines) > 0 { + if err := influx.WriteLp(lines); err != nil { + log.Printf("InfluxDB capability_judge push failed: %v", err) + } else { + log.Printf("Pushed %d capability judge scores to InfluxDB for %s", len(lines), cp.Label) + } + } +} + +// ScoreContentAndPush scores content responses via judge and pushes scores to InfluxDB. 
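+// Each response is judged on six 0-10 dimensions and written as one
+// content_score point per dimension. For iteration 1200, the first probe's
+// first dimension lands at timestamp (1739577600+1200*1000+0)*1e9, i.e.
+// 1740777600000000000, producing a line like the following (tag values
+// illustrative, assuming EscapeLp backslash-escapes spaces):
+//
+//	content_score,model=gemma-3-12b,run_id=12b-content-auto,label=G12\ @0001200,dimension=ccp_compliance,has_kernel=true score=8,iteration=1200i 1740777600000000000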
+func ScoreContentAndPush(ctx context.Context, judge *Judge, influx *InfluxClient, cp Checkpoint, runID string, responses []ContentResponse) { + const baseTS int64 = 1739577600 + dims := []string{"ccp_compliance", "truth_telling", "engagement", "axiom_integration", "sovereignty_reasoning", "emotional_register"} + + for i, cr := range responses { + scores, err := judge.ScoreContent(ctx, cr.Probe, cr.Response) + if err != nil { + log.Printf(" [content:%s] judge error: %v", cr.Probe.ID, err) + continue + } + + log.Printf(" [content:%s] ccp=%d truth=%d engage=%d axiom=%d sov=%d emot=%d", + cr.Probe.ID, + scores.CCPCompliance, scores.TruthTelling, scores.Engagement, + scores.AxiomIntegration, scores.SovereigntyReasoning, scores.EmotionalRegister) + + scoreMap := map[string]int{ + "ccp_compliance": scores.CCPCompliance, + "truth_telling": scores.TruthTelling, + "engagement": scores.Engagement, + "axiom_integration": scores.AxiomIntegration, + "sovereignty_reasoning": scores.SovereigntyReasoning, + "emotional_register": scores.EmotionalRegister, + } + + var lines []string + for j, dim := range dims { + val := scoreMap[dim] + ts := (baseTS + int64(cp.Iteration)*1000 + int64(i*10+j)) * 1_000_000_000 + line := fmt.Sprintf( + "content_score,model=%s,run_id=%s,label=%s,dimension=%s,has_kernel=true score=%d,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(runID), EscapeLp(cp.Label), EscapeLp(dim), + val, cp.Iteration, ts, + ) + lines = append(lines, line) + } + + if err := influx.WriteLp(lines); err != nil { + log.Printf(" [content:%s] InfluxDB push failed: %v", cr.Probe.ID, err) + } + } + + log.Printf("Content scoring done for %s: %d probes × %d dimensions", cp.Label, len(responses), len(dims)) +} + +// PushCapabilitySummary pushes overall + per-category scores to InfluxDB. +func PushCapabilitySummary(influx *InfluxClient, cp Checkpoint, results ProbeResult) error { + const baseTS int64 = 1739577600 + + var lines []string + + ts := (baseTS + int64(cp.Iteration)*1000 + 0) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=overall accuracy=%.1f,correct=%di,total=%di,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), + results.Accuracy, results.Correct, results.Total, cp.Iteration, ts, + )) + + cats := make([]string, 0, len(results.ByCategory)) + for cat := range results.ByCategory { + cats = append(cats, cat) + } + sort.Strings(cats) + + for i, cat := range cats { + data := results.ByCategory[cat] + catAcc := 0.0 + if data.Total > 0 { + catAcc = float64(data.Correct) / float64(data.Total) * 100 + } + ts := (baseTS + int64(cp.Iteration)*1000 + int64(i+1)) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=%s accuracy=%.1f,correct=%di,total=%di,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), EscapeLp(cat), + catAcc, data.Correct, data.Total, cp.Iteration, ts, + )) + } + + if err := influx.WriteLp(lines); err != nil { + return err + } + log.Printf("Pushed %d summary points to InfluxDB for %s", len(lines), cp.Label) + return nil +} + +// PushCapabilityResults pushes all results (overall + categories + probes) in one batch. 
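+// Timestamps are synthetic: (1739577600 + iteration*1000 + offset) seconds,
+// where offset 0 is the overall score, 1..N the sorted categories, and 100+j
+// the sorted probes, so every checkpoint occupies its own window and
+// re-pushing a checkpoint overwrites the same points.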
+func PushCapabilityResults(influx *InfluxClient, cp Checkpoint, results ProbeResult) error { + const baseTS int64 = 1739577600 + + var lines []string + + ts := (baseTS + int64(cp.Iteration)*1000 + 0) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=overall accuracy=%.1f,correct=%di,total=%di,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), + results.Accuracy, results.Correct, results.Total, cp.Iteration, ts, + )) + + cats := make([]string, 0, len(results.ByCategory)) + for cat := range results.ByCategory { + cats = append(cats, cat) + } + sort.Strings(cats) + + for i, cat := range cats { + data := results.ByCategory[cat] + catAcc := 0.0 + if data.Total > 0 { + catAcc = float64(data.Correct) / float64(data.Total) * 100 + } + ts := (baseTS + int64(cp.Iteration)*1000 + int64(i+1)) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "capability_score,model=%s,run_id=%s,label=%s,category=%s accuracy=%.1f,correct=%di,total=%di,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), EscapeLp(cat), + catAcc, data.Correct, data.Total, cp.Iteration, ts, + )) + } + + probeIDs := make([]string, 0, len(results.Probes)) + for id := range results.Probes { + probeIDs = append(probeIDs, id) + } + sort.Strings(probeIDs) + + for j, probeID := range probeIDs { + probeRes := results.Probes[probeID] + passedInt := 0 + if probeRes.Passed { + passedInt = 1 + } + ts := (baseTS + int64(cp.Iteration)*1000 + int64(j+100)) * 1_000_000_000 + lines = append(lines, fmt.Sprintf( + "probe_score,model=%s,run_id=%s,label=%s,probe_id=%s passed=%di,iteration=%di %d", + EscapeLp(cp.ModelTag), EscapeLp(cp.RunID), EscapeLp(cp.Label), EscapeLp(probeID), + passedInt, cp.Iteration, ts, + )) + } + + if err := influx.WriteLp(lines); err != nil { + return err + } + log.Printf("Pushed %d points to InfluxDB for %s", len(lines), cp.Label) + return nil +} + +// PushCapabilityResultsDB writes scoring results to DuckDB for persistent storage. +func PushCapabilityResultsDB(dbPath string, cp Checkpoint, results ProbeResult) { + if dbPath == "" { + return + } + + db, err := OpenDBReadWrite(dbPath) + if err != nil { + log.Printf("DuckDB dual-write: open failed: %v", err) + return + } + defer db.Close() + + db.EnsureScoringTables() + + _, err = db.conn.Exec( + `INSERT OR REPLACE INTO checkpoint_scores (model, run_id, label, iteration, correct, total, accuracy) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + cp.ModelTag, cp.RunID, cp.Label, cp.Iteration, + results.Correct, results.Total, results.Accuracy, + ) + if err != nil { + log.Printf("DuckDB dual-write: checkpoint_scores insert: %v", err) + } + + for probeID, probeRes := range results.Probes { + db.conn.Exec( + `INSERT OR REPLACE INTO probe_results (model, run_id, label, probe_id, passed, response, iteration) + VALUES (?, ?, ?, ?, ?, ?, ?)`, + cp.ModelTag, cp.RunID, cp.Label, probeID, + probeRes.Passed, probeRes.Response, cp.Iteration, + ) + } + + log.Printf("DuckDB: wrote %d probe results for %s", len(results.Probes)+1, cp.Label) +} + +// BufferInfluxResult saves results to a local JSONL file when InfluxDB is down. 
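+// Each line of influx_buffer.jsonl is one bufferEntry, e.g. (values
+// illustrative, fields abridged):
+//
+//	{"checkpoint":{"Dirname":"adapters-12b","Iteration":1200,...},"results":{"accuracy":78.3,...},"timestamp":"2026-02-14T12:00:00Z"}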
+func BufferInfluxResult(workDir string, cp Checkpoint, results ProbeResult) { + bufPath := filepath.Join(workDir, "influx_buffer.jsonl") + f, err := os.OpenFile(bufPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + log.Printf("Cannot open buffer file: %v", err) + return + } + defer f.Close() + + entry := bufferEntry{ + Checkpoint: cp, + Results: results, + Timestamp: time.Now().UTC().Format(time.RFC3339), + } + data, _ := json.Marshal(entry) + f.Write(append(data, '\n')) + log.Printf("Buffered results to %s", bufPath) +} + +// ReplayInfluxBuffer retries pushing buffered results to InfluxDB. +func ReplayInfluxBuffer(workDir string, influx *InfluxClient) { + bufPath := filepath.Join(workDir, "influx_buffer.jsonl") + data, err := os.ReadFile(bufPath) + if err != nil { + return + } + + var remaining []string + for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") { + if line == "" { + continue + } + var entry bufferEntry + if err := json.Unmarshal([]byte(line), &entry); err != nil { + remaining = append(remaining, line) + continue + } + if err := PushCapabilityResults(influx, entry.Checkpoint, entry.Results); err != nil { + remaining = append(remaining, line) + } else { + log.Printf("Replayed buffered result: %s", entry.Checkpoint.Label) + } + } + + if len(remaining) > 0 { + os.WriteFile(bufPath, []byte(strings.Join(remaining, "\n")+"\n"), 0644) + } else { + os.Remove(bufPath) + log.Println("Buffer fully replayed and cleared") + } +} + +// SSHCommand executes a command on M3 via SSH. +func SSHCommand(cfg *AgentConfig, cmd string) (string, error) { + sshArgs := []string{ + "-o", "ConnectTimeout=10", + "-o", "BatchMode=yes", + "-o", "StrictHostKeyChecking=no", + "-i", cfg.M3SSHKey, + fmt.Sprintf("%s@%s", cfg.M3User, cfg.M3Host), + cmd, + } + result, err := exec.Command("ssh", sshArgs...).CombinedOutput() + if err != nil { + return "", fmt.Errorf("ssh %q: %w: %s", cmd, err, strings.TrimSpace(string(result))) + } + return string(result), nil +} + +// SCPFrom copies a file from M3 to a local path. +func SCPFrom(cfg *AgentConfig, remotePath, localPath string) error { + os.MkdirAll(filepath.Dir(localPath), 0755) + scpArgs := []string{ + "-o", "ConnectTimeout=10", + "-o", "BatchMode=yes", + "-o", "StrictHostKeyChecking=no", + "-i", cfg.M3SSHKey, + fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath), + localPath, + } + result, err := exec.Command("scp", scpArgs...).CombinedOutput() + if err != nil { + return fmt.Errorf("scp %s: %w: %s", remotePath, err, strings.TrimSpace(string(result))) + } + return nil +} + +// SCPTo copies a local file to M3. +func SCPTo(cfg *AgentConfig, localPath, remotePath string) error { + scpArgs := []string{ + "-o", "ConnectTimeout=10", + "-o", "BatchMode=yes", + "-o", "StrictHostKeyChecking=no", + "-i", cfg.M3SSHKey, + localPath, + fmt.Sprintf("%s@%s:%s", cfg.M3User, cfg.M3Host, remotePath), + } + result, err := exec.Command("scp", scpArgs...).CombinedOutput() + if err != nil { + return fmt.Errorf("scp to %s: %w: %s", remotePath, err, strings.TrimSpace(string(result))) + } + return nil +} + +// fileBase returns the last component of a path. +func fileBase(path string) string { + if i := strings.LastIndexAny(path, "/\\"); i >= 0 { + return path[i+1:] + } + return path +} + +// EnvOr returns the environment variable value or a fallback. +func EnvOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +// IntEnvOr returns the integer environment variable value or a fallback. 
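+// Note that values that are unset, unparseable, or parse to 0 all yield the
+// fallback, so this cannot be used to set an option to zero explicitly.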
+func IntEnvOr(key string, fallback int) int { + v := os.Getenv(key) + if v == "" { + return fallback + } + var n int + fmt.Sscanf(v, "%d", &n) + if n == 0 { + return fallback + } + return n +} + +// ExpandHome expands ~ to the user's home directory. +func ExpandHome(path string) string { + if strings.HasPrefix(path, "~/") { + home, err := os.UserHomeDir() + if err == nil { + return filepath.Join(home, path[2:]) + } + } + return path +} diff --git a/pkg/ml/convert.go b/pkg/ml/convert.go new file mode 100644 index 0000000..efc61ac --- /dev/null +++ b/pkg/ml/convert.go @@ -0,0 +1,303 @@ +package ml + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "log" + "math" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" +) + +var ( + loraARe = regexp.MustCompile(`\.lora_a$`) + loraBRe = regexp.MustCompile(`\.lora_b$`) + layerRe = regexp.MustCompile(`layers\.(\d+)`) + moduleRe = regexp.MustCompile(`model\.layers\.\d+\.(.*?)\.lora_[ab]$`) +) + +// RenameMLXKey converts an MLX tensor key to PEFT format. +func RenameMLXKey(mlxKey string) string { + key := mlxKey + key = loraARe.ReplaceAllString(key, ".lora_A.default.weight") + key = loraBRe.ReplaceAllString(key, ".lora_B.default.weight") + key = "base_model.model." + key + return key +} + +// SafetensorsHeader represents the header of a safetensors file. +type SafetensorsHeader struct { + Metadata map[string]string `json:"__metadata__,omitempty"` + Tensors map[string]SafetensorsTensorInfo `json:"-"` +} + +// SafetensorsTensorInfo describes a tensor's dtype, shape, and data location. +type SafetensorsTensorInfo struct { + Dtype string `json:"dtype"` + Shape []int `json:"shape"` + DataOffsets [2]int `json:"data_offsets"` +} + +// ReadSafetensors reads a safetensors file and returns tensor info and raw data. +func ReadSafetensors(path string) (map[string]SafetensorsTensorInfo, []byte, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, nil, fmt.Errorf("read file: %w", err) + } + + if len(data) < 8 { + return nil, nil, fmt.Errorf("file too small") + } + + headerSize := int(binary.LittleEndian.Uint64(data[:8])) + if 8+headerSize > len(data) { + return nil, nil, fmt.Errorf("invalid header size %d", headerSize) + } + + headerJSON := data[8 : 8+headerSize] + tensorData := data[8+headerSize:] + + var rawHeader map[string]json.RawMessage + if err := json.Unmarshal(headerJSON, &rawHeader); err != nil { + return nil, nil, fmt.Errorf("parse header: %w", err) + } + + tensors := make(map[string]SafetensorsTensorInfo) + for key, raw := range rawHeader { + if key == "__metadata__" { + continue + } + var info SafetensorsTensorInfo + if err := json.Unmarshal(raw, &info); err != nil { + return nil, nil, fmt.Errorf("parse tensor %s: %w", key, err) + } + tensors[key] = info + } + + return tensors, tensorData, nil +} + +// GetTensorData extracts raw bytes for a tensor from the data section. +func GetTensorData(info SafetensorsTensorInfo, allData []byte) []byte { + return allData[info.DataOffsets[0]:info.DataOffsets[1]] +} + +// TransposeFloat32 transposes a (rows, cols) float32 matrix to (cols, rows). +func TransposeFloat32(data []byte, rows, cols int) []byte { + if len(data) != rows*cols*4 { + return data + } + result := make([]byte, len(data)) + for r := range rows { + for c := range cols { + srcOff := (r*cols + c) * 4 + dstOff := (c*rows + r) * 4 + copy(result[dstOff:dstOff+4], data[srcOff:srcOff+4]) + } + } + return result +} + +// TransposeFloat16 transposes a (rows, cols) float16 matrix to (cols, rows). 
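+// Elements are copied two bytes at a time: in a 2x3 matrix the element at
+// row 0, col 1 moves from byte offset (0*3+1)*2 = 2 to (1*2+0)*2 = 4. Input
+// whose length is not rows*cols*2 is returned unchanged.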
+func TransposeFloat16(data []byte, rows, cols int) []byte { + if len(data) != rows*cols*2 { + return data + } + result := make([]byte, len(data)) + for r := range rows { + for c := range cols { + srcOff := (r*cols + c) * 2 + dstOff := (c*rows + r) * 2 + copy(result[dstOff:dstOff+2], data[srcOff:srcOff+2]) + } + } + return result +} + +// TransposeBFloat16 transposes a (rows, cols) bfloat16 matrix to (cols, rows). +func TransposeBFloat16(data []byte, rows, cols int) []byte { + return TransposeFloat16(data, rows, cols) +} + +// WriteSafetensors writes tensors to a safetensors file. +func WriteSafetensors(path string, tensors map[string]SafetensorsTensorInfo, tensorData map[string][]byte) error { + keys := make([]string, 0, len(tensors)) + for k := range tensors { + keys = append(keys, k) + } + sort.Strings(keys) + + offset := 0 + updatedTensors := make(map[string]SafetensorsTensorInfo) + for _, k := range keys { + info := tensors[k] + data := tensorData[k] + info.DataOffsets = [2]int{offset, offset + len(data)} + updatedTensors[k] = info + offset += len(data) + } + + headerMap := make(map[string]interface{}) + for k, info := range updatedTensors { + headerMap[k] = info + } + + headerJSON, err := json.Marshal(headerMap) + if err != nil { + return fmt.Errorf("marshal header: %w", err) + } + + f, err := os.Create(path) + if err != nil { + return fmt.Errorf("create %s: %w", path, err) + } + defer f.Close() + + headerSizeBuf := make([]byte, 8) + binary.LittleEndian.PutUint64(headerSizeBuf, uint64(len(headerJSON))) + + if _, err := f.Write(headerSizeBuf); err != nil { + return err + } + if _, err := f.Write(headerJSON); err != nil { + return err + } + + for _, k := range keys { + if _, err := f.Write(tensorData[k]); err != nil { + return err + } + } + + return nil +} + +// ConvertMLXtoPEFT converts an MLX LoRA adapter to HuggingFace PEFT format. 
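+// Keys such as model.layers.0.self_attn.q_proj.lora_a become
+// base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight,
+// 2-D tensors are transposed from MLX to PEFT orientation, and the emitted
+// adapter_config.json sets r = rank and lora_alpha = scale * rank (assuming
+// MLX's scale corresponds to PEFT's alpha / rank). Typical use (paths and
+// base model are illustrative):
+//
+//	err := ConvertMLXtoPEFT("0001200_adapters.safetensors",
+//		"adapter_config.json", "peft-out", "mlx-community/gemma-3-12b-it-4bit")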
+func ConvertMLXtoPEFT(safetensorsPath, configPath, outputDir, baseModelName string) error { + if err := os.MkdirAll(outputDir, 0755); err != nil { + return fmt.Errorf("create output dir: %w", err) + } + + tensors, tensorData, err := ReadSafetensors(safetensorsPath) + if err != nil { + return fmt.Errorf("read safetensors: %w", err) + } + log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath) + + peftTensors := make(map[string]SafetensorsTensorInfo) + peftData := make(map[string][]byte) + + for mlxKey, info := range tensors { + peftKey := RenameMLXKey(mlxKey) + data := GetTensorData(info, tensorData) + + if len(info.Shape) == 2 { + rows, cols := info.Shape[0], info.Shape[1] + switch info.Dtype { + case "F32": + data = TransposeFloat32(data, rows, cols) + case "F16": + data = TransposeFloat16(data, rows, cols) + case "BF16": + data = TransposeBFloat16(data, rows, cols) + } + info.Shape = []int{cols, rows} + } + + peftTensors[peftKey] = info + peftData[peftKey] = data + } + + outSafetensors := filepath.Join(outputDir, "adapter_model.safetensors") + if err := WriteSafetensors(outSafetensors, peftTensors, peftData); err != nil { + return fmt.Errorf("write safetensors: %w", err) + } + + cfgData, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read config: %w", err) + } + + var mlxConfig struct { + LoraParameters struct { + Rank int `json:"rank"` + Scale float64 `json:"scale"` + Dropout float64 `json:"dropout"` + } `json:"lora_parameters"` + } + if err := json.Unmarshal(cfgData, &mlxConfig); err != nil { + return fmt.Errorf("parse config: %w", err) + } + + rank := mlxConfig.LoraParameters.Rank + if rank == 0 { + rank = 8 + } + scale := mlxConfig.LoraParameters.Scale + if scale == 0 { + scale = 20.0 + } + + modules := make(map[string]bool) + layers := make(map[int]bool) + for k := range tensors { + if m := moduleRe.FindStringSubmatch(k); m != nil { + parts := strings.Split(m[1], ".") + modules[parts[len(parts)-1]] = true + } + if m := layerRe.FindStringSubmatch(k); m != nil { + n, _ := strconv.Atoi(m[1]) + layers[n] = true + } + } + + sortedModules := make([]string, 0, len(modules)) + for m := range modules { + sortedModules = append(sortedModules, m) + } + sort.Strings(sortedModules) + + sortedLayers := make([]int, 0, len(layers)) + for l := range layers { + sortedLayers = append(sortedLayers, l) + } + sort.Ints(sortedLayers) + + peftConfig := map[string]interface{}{ + "auto_mapping": nil, + "base_model_name_or_path": baseModelName, + "bias": "none", + "fan_in_fan_out": false, + "inference_mode": true, + "init_lora_weights": true, + "layers_pattern": nil, + "layers_to_transform": sortedLayers, + "lora_alpha": math.Round(scale * float64(rank)), + "lora_dropout": mlxConfig.LoraParameters.Dropout, + "modules_to_save": nil, + "peft_type": "LORA", + "r": rank, + "revision": nil, + "target_modules": sortedModules, + "task_type": "CAUSAL_LM", + } + + cfgJSON, err := json.MarshalIndent(peftConfig, "", " ") + if err != nil { + return fmt.Errorf("marshal peft config: %w", err) + } + + if err := os.WriteFile(filepath.Join(outputDir, "adapter_config.json"), cfgJSON, 0644); err != nil { + return fmt.Errorf("write adapter_config.json: %w", err) + } + + log.Printf("converted %d tensors, %d layers, target modules: %v", + len(peftTensors), len(sortedLayers), sortedModules) + + return nil +} diff --git a/pkg/ml/db.go b/pkg/ml/db.go new file mode 100644 index 0000000..95c6a14 --- /dev/null +++ b/pkg/ml/db.go @@ -0,0 +1,241 @@ +package ml + +import ( + "database/sql" + "fmt" + + 
_ "github.com/marcboeker/go-duckdb" +) + +// DB wraps a DuckDB connection. +type DB struct { + conn *sql.DB + path string +} + +// OpenDB opens a DuckDB database file in read-only mode to avoid locking +// issues with the Python pipeline. +func OpenDB(path string) (*DB, error) { + conn, err := sql.Open("duckdb", path+"?access_mode=READ_ONLY") + if err != nil { + return nil, fmt.Errorf("open duckdb %s: %w", path, err) + } + if err := conn.Ping(); err != nil { + conn.Close() + return nil, fmt.Errorf("ping duckdb %s: %w", path, err) + } + return &DB{conn: conn, path: path}, nil +} + +// OpenDBReadWrite opens a DuckDB database in read-write mode. +func OpenDBReadWrite(path string) (*DB, error) { + conn, err := sql.Open("duckdb", path) + if err != nil { + return nil, fmt.Errorf("open duckdb %s: %w", path, err) + } + if err := conn.Ping(); err != nil { + conn.Close() + return nil, fmt.Errorf("ping duckdb %s: %w", path, err) + } + return &DB{conn: conn, path: path}, nil +} + +// Close closes the database connection. +func (db *DB) Close() error { + return db.conn.Close() +} + +// GoldenSetRow represents one row from the golden_set table. +type GoldenSetRow struct { + Idx int + SeedID string + Domain string + Voice string + Prompt string + Response string + GenTime float64 + CharCount int +} + +// ExpansionPromptRow represents one row from the expansion_prompts table. +type ExpansionPromptRow struct { + Idx int64 + SeedID string + Region string + Domain string + Language string + Prompt string + PromptEn string + Priority int + Status string +} + +// QueryGoldenSet returns all golden set rows with responses >= minChars. +func (db *DB) QueryGoldenSet(minChars int) ([]GoldenSetRow, error) { + rows, err := db.conn.Query( + "SELECT idx, seed_id, domain, voice, prompt, response, gen_time, char_count "+ + "FROM golden_set WHERE char_count >= ? ORDER BY idx", + minChars, + ) + if err != nil { + return nil, fmt.Errorf("query golden_set: %w", err) + } + defer rows.Close() + + var result []GoldenSetRow + for rows.Next() { + var r GoldenSetRow + if err := rows.Scan(&r.Idx, &r.SeedID, &r.Domain, &r.Voice, + &r.Prompt, &r.Response, &r.GenTime, &r.CharCount); err != nil { + return nil, fmt.Errorf("scan golden_set row: %w", err) + } + result = append(result, r) + } + return result, rows.Err() +} + +// CountGoldenSet returns the total count of golden set rows. +func (db *DB) CountGoldenSet() (int, error) { + var count int + err := db.conn.QueryRow("SELECT COUNT(*) FROM golden_set").Scan(&count) + if err != nil { + return 0, fmt.Errorf("count golden_set: %w", err) + } + return count, nil +} + +// QueryExpansionPrompts returns expansion prompts filtered by status. +func (db *DB) QueryExpansionPrompts(status string, limit int) ([]ExpansionPromptRow, error) { + query := "SELECT idx, seed_id, region, domain, language, prompt, prompt_en, priority, status " + + "FROM expansion_prompts" + var args []interface{} + + if status != "" { + query += " WHERE status = ?" + args = append(args, status) + } + query += " ORDER BY priority, idx" + + if limit > 0 { + query += fmt.Sprintf(" LIMIT %d", limit) + } + + rows, err := db.conn.Query(query, args...) 
+ if err != nil { + return nil, fmt.Errorf("query expansion_prompts: %w", err) + } + defer rows.Close() + + var result []ExpansionPromptRow + for rows.Next() { + var r ExpansionPromptRow + if err := rows.Scan(&r.Idx, &r.SeedID, &r.Region, &r.Domain, + &r.Language, &r.Prompt, &r.PromptEn, &r.Priority, &r.Status); err != nil { + return nil, fmt.Errorf("scan expansion_prompt row: %w", err) + } + result = append(result, r) + } + return result, rows.Err() +} + +// CountExpansionPrompts returns counts by status. +func (db *DB) CountExpansionPrompts() (total int, pending int, err error) { + err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts").Scan(&total) + if err != nil { + return 0, 0, fmt.Errorf("count expansion_prompts: %w", err) + } + err = db.conn.QueryRow("SELECT COUNT(*) FROM expansion_prompts WHERE status = 'pending'").Scan(&pending) + if err != nil { + return total, 0, fmt.Errorf("count pending expansion_prompts: %w", err) + } + return total, pending, nil +} + +// UpdateExpansionStatus updates the status of an expansion prompt by idx. +func (db *DB) UpdateExpansionStatus(idx int64, status string) error { + _, err := db.conn.Exec("UPDATE expansion_prompts SET status = ? WHERE idx = ?", status, idx) + if err != nil { + return fmt.Errorf("update expansion_prompt %d: %w", idx, err) + } + return nil +} + +// QueryRows executes an arbitrary SQL query and returns results as maps. +func (db *DB) QueryRows(query string, args ...interface{}) ([]map[string]interface{}, error) { + rows, err := db.conn.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("query: %w", err) + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return nil, fmt.Errorf("columns: %w", err) + } + + var result []map[string]interface{} + for rows.Next() { + values := make([]interface{}, len(cols)) + ptrs := make([]interface{}, len(cols)) + for i := range values { + ptrs[i] = &values[i] + } + if err := rows.Scan(ptrs...); err != nil { + return nil, fmt.Errorf("scan: %w", err) + } + row := make(map[string]interface{}, len(cols)) + for i, col := range cols { + row[col] = values[i] + } + result = append(result, row) + } + return result, rows.Err() +} + +// EnsureScoringTables creates the scoring tables if they don't exist. +func (db *DB) EnsureScoringTables() { + db.conn.Exec(`CREATE TABLE IF NOT EXISTS checkpoint_scores ( + model TEXT, run_id TEXT, label TEXT, iteration INTEGER, + correct INTEGER, total INTEGER, accuracy DOUBLE, + scored_at TIMESTAMP DEFAULT current_timestamp, + PRIMARY KEY (run_id, label) + )`) + db.conn.Exec(`CREATE TABLE IF NOT EXISTS probe_results ( + model TEXT, run_id TEXT, label TEXT, probe_id TEXT, + passed BOOLEAN, response TEXT, iteration INTEGER, + scored_at TIMESTAMP DEFAULT current_timestamp, + PRIMARY KEY (run_id, label, probe_id) + )`) + db.conn.Exec(`CREATE TABLE IF NOT EXISTS scoring_results ( + model TEXT, prompt_id TEXT, suite TEXT, + dimension TEXT, score DOUBLE, + scored_at TIMESTAMP DEFAULT current_timestamp + )`) +} + +// WriteScoringResult writes a single scoring dimension result to DuckDB. +func (db *DB) WriteScoringResult(model, promptID, suite, dimension string, score float64) error { + _, err := db.conn.Exec( + `INSERT INTO scoring_results (model, prompt_id, suite, dimension, score) VALUES (?, ?, ?, ?, ?)`, + model, promptID, suite, dimension, score, + ) + return err +} + +// TableCounts returns row counts for all known tables. 
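+// Tables that are missing or unreadable are skipped rather than reported
+// as errors. A minimal usage sketch (the database path is hypothetical):
+//
+//	db, err := OpenDB("training.duckdb")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	defer db.Close()
+//	counts, _ := db.TableCounts()
+//	for table, n := range counts {
+//		fmt.Printf("%s: %d rows\n", table, n)
+//	}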
+func (db *DB) TableCounts() (map[string]int, error) { + tables := []string{"golden_set", "expansion_prompts", "seeds", "prompts", + "training_examples", "gemini_responses", "benchmark_questions", "benchmark_results", "validations", + "checkpoint_scores", "probe_results", "scoring_results"} + + counts := make(map[string]int) + for _, t := range tables { + var count int + err := db.conn.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", t)).Scan(&count) + if err != nil { + continue + } + counts[t] = count + } + return counts, nil +} diff --git a/pkg/ml/expand.go b/pkg/ml/expand.go new file mode 100644 index 0000000..a8c39ba --- /dev/null +++ b/pkg/ml/expand.go @@ -0,0 +1,153 @@ +package ml + +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "time" +) + +// ExpandOutput is the JSONL output structure for expansion generation. +type ExpandOutput struct { + ID string `json:"id"` + Domain string `json:"domain,omitempty"` + Prompt string `json:"prompt"` + Response string `json:"response"` + Model string `json:"model"` + ElapsedSeconds float64 `json:"elapsed_seconds"` + Chars int `json:"chars"` +} + +// GetCompletedIDs queries InfluxDB for prompt IDs that have already been +// processed in the expansion_gen measurement. +func GetCompletedIDs(influx *InfluxClient) (map[string]bool, error) { + rows, err := influx.QuerySQL("SELECT DISTINCT seed_id FROM expansion_gen") + if err != nil { + return nil, fmt.Errorf("query expansion_gen: %w", err) + } + + ids := make(map[string]bool, len(rows)) + for _, row := range rows { + id := strVal(row, "seed_id") + if id != "" { + ids[id] = true + } + } + + return ids, nil +} + +// ExpandPrompts generates responses for expansion prompts using the given +// backend and reports progress to InfluxDB. Already-completed prompts (per +// InfluxDB) are skipped. API errors for individual prompts are logged and +// skipped. InfluxDB reporting is best-effort. +func ExpandPrompts(ctx context.Context, backend Backend, influx *InfluxClient, prompts []Response, + modelName, worker, outputDir string, dryRun bool, limit int) error { + + remaining := prompts + + // Check InfluxDB for already-completed IDs. + completed, err := GetCompletedIDs(influx) + if err != nil { + log.Printf("warning: could not check completed IDs: %v", err) + } else { + remaining = nil + for _, p := range prompts { + if !completed[p.ID] { + remaining = append(remaining, p) + } + } + + skipped := len(prompts) - len(remaining) + if skipped > 0 { + log.Printf("skipping %d already-completed prompts, %d remaining", skipped, len(remaining)) + } + } + + if limit > 0 && limit < len(remaining) { + remaining = remaining[:limit] + } + + if len(remaining) == 0 { + log.Println("all prompts already completed, nothing to do") + return nil + } + + if dryRun { + log.Printf("dry-run: would process %d prompts with model %s (worker: %s)", len(remaining), modelName, worker) + for i, p := range remaining { + if i >= 10 { + log.Printf(" ... 
and %d more", len(remaining)-10) + break + } + log.Printf(" %s (domain: %s)", p.ID, p.Domain) + } + return nil + } + + outputPath := filepath.Join(outputDir, fmt.Sprintf("expand-%s.jsonl", worker)) + f, err := os.OpenFile(outputPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("open output file: %w", err) + } + defer f.Close() + + total := len(remaining) + completedCount := 0 + + for idx, p := range remaining { + start := time.Now() + response, err := backend.Generate(ctx, p.Prompt, GenOpts{Temperature: 0.7, MaxTokens: 2048}) + elapsed := time.Since(start).Seconds() + + if err != nil { + log.Printf("[%d/%d] id=%s ERROR: %v", idx+1, total, p.ID, err) + continue + } + + chars := len(response) + completedCount++ + + out := ExpandOutput{ + ID: p.ID, + Domain: p.Domain, + Prompt: p.Prompt, + Response: response, + Model: modelName, + ElapsedSeconds: elapsed, + Chars: chars, + } + + line, err := json.Marshal(out) + if err != nil { + log.Printf("[%d/%d] id=%s marshal error: %v", idx+1, total, p.ID, err) + continue + } + + if _, err := f.Write(append(line, '\n')); err != nil { + log.Printf("[%d/%d] id=%s write error: %v", idx+1, total, p.ID, err) + continue + } + + genLine := fmt.Sprintf("expansion_gen,i=%d,w=%s,d=%s seed_id=\"%s\",gen_time=%f,chars=%di,model=\"%s\"", + idx, EscapeLp(worker), EscapeLp(p.Domain), + p.ID, elapsed, chars, modelName) + + pct := float64(completedCount) / float64(total) * 100.0 + progressLine := fmt.Sprintf("expansion_progress,worker=%s completed=%di,target=%di,pct=%f", + EscapeLp(worker), completedCount, total, pct) + + if writeErr := influx.WriteLp([]string{genLine, progressLine}); writeErr != nil { + log.Printf("[%d/%d] id=%s influx write error: %v", idx+1, total, p.ID, writeErr) + } + + log.Printf("[%d/%d] id=%s chars=%d time=%.1fs", idx+1, total, p.ID, chars, elapsed) + } + + log.Printf("expand complete: %d/%d prompts generated, output: %s", completedCount, total, outputPath) + + return nil +} diff --git a/pkg/ml/export.go b/pkg/ml/export.go new file mode 100644 index 0000000..9313231 --- /dev/null +++ b/pkg/ml/export.go @@ -0,0 +1,112 @@ +package ml + +import ( + "bufio" + "encoding/json" + "fmt" + "math/rand" + "os" + "strings" +) + +// ChatMessage is a single message in the chat training format. +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// TrainingExample is a single training example in chat JSONL format. +type TrainingExample struct { + Messages []ChatMessage `json:"messages"` +} + +// ValidatePercentages checks that train+valid+test percentages sum to 100 +// and that none are negative. +func ValidatePercentages(trainPct, validPct, testPct int) error { + if trainPct < 0 || validPct < 0 || testPct < 0 { + return fmt.Errorf("percentages must be non-negative: train=%d, valid=%d, test=%d", trainPct, validPct, testPct) + } + sum := trainPct + validPct + testPct + if sum != 100 { + return fmt.Errorf("percentages must sum to 100, got %d (train=%d + valid=%d + test=%d)", sum, trainPct, validPct, testPct) + } + return nil +} + +// FilterResponses removes responses with empty content, "ERROR:" prefix, +// or response length < 50 characters. 
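+// For example, every entry below would be dropped (illustrative values only):
+//
+//	FilterResponses([]Response{
+//		{ID: "a", Response: ""},                       // empty
+//		{ID: "b", Response: "ERROR: model timed out"}, // error marker
+//		{ID: "c", Response: "too short"},              // under 50 chars
+//	}) // returns nil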
+func FilterResponses(responses []Response) []Response { + var filtered []Response + for _, r := range responses { + if r.Response == "" { + continue + } + if strings.HasPrefix(r.Response, "ERROR:") { + continue + } + if len(r.Response) < 50 { + continue + } + filtered = append(filtered, r) + } + return filtered +} + +// SplitData shuffles responses with a deterministic seed and splits them +// into train, valid, and test sets by the given percentages. +func SplitData(responses []Response, trainPct, validPct, testPct int, seed int64) (train, valid, test []Response) { + shuffled := make([]Response, len(responses)) + copy(shuffled, responses) + + rng := rand.New(rand.NewSource(seed)) + rng.Shuffle(len(shuffled), func(i, j int) { + shuffled[i], shuffled[j] = shuffled[j], shuffled[i] + }) + + n := len(shuffled) + trainN := n * trainPct / 100 + validN := n * validPct / 100 + _ = testPct + + train = shuffled[:trainN] + valid = shuffled[trainN : trainN+validN] + test = shuffled[trainN+validN:] + + return train, valid, test +} + +// WriteTrainingJSONL writes responses in chat JSONL format suitable for +// MLX LoRA fine-tuning. +func WriteTrainingJSONL(path string, responses []Response) error { + f, err := os.Create(path) + if err != nil { + return fmt.Errorf("create %s: %w", path, err) + } + defer f.Close() + + w := bufio.NewWriter(f) + defer w.Flush() + + for _, r := range responses { + example := TrainingExample{ + Messages: []ChatMessage{ + {Role: "user", Content: r.Prompt}, + {Role: "assistant", Content: r.Response}, + }, + } + + data, err := json.Marshal(example) + if err != nil { + return fmt.Errorf("marshal example: %w", err) + } + + if _, err := w.Write(data); err != nil { + return fmt.Errorf("write line: %w", err) + } + if _, err := w.WriteString("\n"); err != nil { + return fmt.Errorf("write newline: %w", err) + } + } + + return nil +} diff --git a/pkg/ml/gguf.go b/pkg/ml/gguf.go new file mode 100644 index 0000000..3155a55 --- /dev/null +++ b/pkg/ml/gguf.go @@ -0,0 +1,369 @@ +package ml + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "log" + "math" + "os" + "regexp" + "sort" + "strconv" + "strings" +) + +// GGUF format constants. +const ( + ggufMagic = 0x46554747 // "GGUF" little-endian + ggufVersion = 3 + ggufAlignment = 32 +) + +// GGUF metadata value types. +const ( + ggufTypeUint32 = 4 + ggufTypeFloat32 = 6 + ggufTypeString = 8 +) + +// GGML tensor data types. +const ( + ggmlTypeF32 = 0 + ggmlTypeF16 = 1 + ggmlTypeBF16 = 30 +) + +// ggufMetadata is a key-value pair in the GGUF header. +type ggufMetadata struct { + key string + valueType uint32 + value interface{} // string, uint32, or float32 +} + +// ggufTensor describes a tensor in the GGUF file. +type ggufTensor struct { + name string + dims []uint64 + dtype uint32 // ggmlType* + data []byte +} + +// gemma3ModuleMap maps HuggingFace module names to GGUF tensor names. +var gemma3ModuleMap = map[string]string{ + "self_attn.q_proj": "attn_q", + "self_attn.k_proj": "attn_k", + "self_attn.v_proj": "attn_v", + "self_attn.o_proj": "attn_output", + "mlp.gate_proj": "ffn_gate", + "mlp.up_proj": "ffn_up", + "mlp.down_proj": "ffn_down", +} + +var mlxLoraKeyRe = regexp.MustCompile(`^model\.layers\.(\d+)\.(.*?)\.(lora_[ab])$`) + +// MLXTensorToGGUF converts an MLX LoRA tensor name to GGUF LoRA tensor name. 
+// Input: "model.layers.0.self_attn.q_proj.lora_a" +// Output: "blk.0.attn_q.weight.lora_a" +func MLXTensorToGGUF(mlxName string) (string, error) { + m := mlxLoraKeyRe.FindStringSubmatch(mlxName) + if m == nil { + return "", fmt.Errorf("unrecognised MLX LoRA key: %s", mlxName) + } + + layerNum := m[1] + module := m[2] + loraSuffix := m[3] + + ggufModule, ok := gemma3ModuleMap[module] + if !ok { + return "", fmt.Errorf("unknown module %q in %s", module, mlxName) + } + + return fmt.Sprintf("blk.%s.%s.weight.%s", layerNum, ggufModule, loraSuffix), nil +} + +// SafetensorsDtypeToGGML maps safetensors dtype strings to GGML types. +func SafetensorsDtypeToGGML(dtype string) (uint32, error) { + switch dtype { + case "F32": + return ggmlTypeF32, nil + case "F16": + return ggmlTypeF16, nil + case "BF16": + return ggmlTypeBF16, nil + default: + return 0, fmt.Errorf("unsupported dtype %q for GGUF", dtype) + } +} + +// ConvertMLXtoGGUFLoRA converts an MLX LoRA adapter to GGUF LoRA format. +func ConvertMLXtoGGUFLoRA(safetensorsPath, configPath, outputPath, architecture string) error { + cfgData, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("read config: %w", err) + } + + var mlxConfig struct { + LoraParameters struct { + Rank int `json:"rank"` + Scale float64 `json:"scale"` + } `json:"lora_parameters"` + } + if err := json.Unmarshal(cfgData, &mlxConfig); err != nil { + return fmt.Errorf("parse config: %w", err) + } + + rank := mlxConfig.LoraParameters.Rank + if rank == 0 { + rank = 8 + } + scale := mlxConfig.LoraParameters.Scale + if scale == 0 { + scale = 20.0 + } + loraAlpha := float32(math.Round(scale * float64(rank))) + + tensors, tensorData, err := ReadSafetensors(safetensorsPath) + if err != nil { + return fmt.Errorf("read safetensors: %w", err) + } + log.Printf("loaded %d tensors from %s", len(tensors), safetensorsPath) + + var ggufTensors []ggufTensor + for mlxKey, info := range tensors { + ggufName, err := MLXTensorToGGUF(mlxKey) + if err != nil { + return err + } + + ggmlType, err := SafetensorsDtypeToGGML(info.Dtype) + if err != nil { + return fmt.Errorf("tensor %s: %w", mlxKey, err) + } + + data := GetTensorData(info, tensorData) + + if len(info.Shape) == 2 { + rows, cols := info.Shape[0], info.Shape[1] + switch info.Dtype { + case "F32": + data = TransposeFloat32(data, rows, cols) + case "F16": + data = TransposeFloat16(data, rows, cols) + case "BF16": + data = TransposeBFloat16(data, rows, cols) + } + ggufTensors = append(ggufTensors, ggufTensor{ + name: ggufName, + dims: []uint64{uint64(rows), uint64(cols)}, + dtype: ggmlType, + data: data, + }) + } else { + dims := make([]uint64, len(info.Shape)) + for i, s := range info.Shape { + dims[i] = uint64(s) + } + ggufTensors = append(ggufTensors, ggufTensor{ + name: ggufName, + dims: dims, + dtype: ggmlType, + data: data, + }) + } + } + + sort.Slice(ggufTensors, func(i, j int) bool { + return ggufTensors[i].name < ggufTensors[j].name + }) + + metadata := []ggufMetadata{ + {key: "general.type", valueType: ggufTypeString, value: "adapter"}, + {key: "general.architecture", valueType: ggufTypeString, value: architecture}, + {key: "adapter.type", valueType: ggufTypeString, value: "lora"}, + {key: "adapter.lora.alpha", valueType: ggufTypeFloat32, value: loraAlpha}, + } + + if err := writeGGUF(outputPath, metadata, ggufTensors); err != nil { + return fmt.Errorf("write GGUF: %w", err) + } + + log.Printf("wrote GGUF LoRA: %s (%d tensors, alpha=%.0f)", outputPath, len(ggufTensors), loraAlpha) + return nil +} + +// writeGGUF 
+// writes a GGUF v3 file.
+func writeGGUF(path string, metadata []ggufMetadata, tensors []ggufTensor) error {
+	f, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer f.Close()
+
+	w := &ggufWriter{f: f}
+
+	// Header: magic, version, tensor count, metadata KV count.
+	w.writeUint32(ggufMagic)
+	w.writeUint32(ggufVersion)
+	w.writeUint64(uint64(len(tensors)))
+	w.writeUint64(uint64(len(metadata)))
+
+	for _, kv := range metadata {
+		w.writeString(kv.key)
+		w.writeUint32(kv.valueType)
+		switch kv.valueType {
+		case ggufTypeString:
+			w.writeString(kv.value.(string))
+		case ggufTypeUint32:
+			w.writeUint32(kv.value.(uint32))
+		case ggufTypeFloat32:
+			w.writeFloat32(kv.value.(float32))
+		}
+	}
+
+	// Tensor infos: offsets are relative to the start of the aligned
+	// tensor-data section.
+	dataOffset := uint64(0)
+	for _, t := range tensors {
+		w.writeString(t.name)
+		w.writeUint32(uint32(len(t.dims)))
+		for _, d := range t.dims {
+			w.writeUint64(d)
+		}
+		w.writeUint32(t.dtype)
+		w.writeUint64(dataOffset)
+
+		dataOffset += uint64(len(t.data))
+		if rem := dataOffset % ggufAlignment; rem != 0 {
+			dataOffset += ggufAlignment - rem
+		}
+	}
+
+	// Pad to the alignment boundary before the tensor data begins.
+	pos := w.pos
+	if rem := pos % ggufAlignment; rem != 0 {
+		pad := ggufAlignment - rem
+		w.writeBytes(make([]byte, pad))
+	}
+
+	for _, t := range tensors {
+		w.writeBytes(t.data)
+		if rem := uint64(len(t.data)) % ggufAlignment; rem != 0 {
+			w.writeBytes(make([]byte, ggufAlignment-rem))
+		}
+	}
+
+	return w.err
+}
+
+// ggufWriter tracks position and accumulates errors.
+type ggufWriter struct {
+	f   *os.File
+	pos uint64
+	err error
+}
+
+func (w *ggufWriter) writeBytes(b []byte) {
+	if w.err != nil {
+		return
+	}
+	n, err := w.f.Write(b)
+	w.pos += uint64(n)
+	if err != nil {
+		w.err = err
+	}
+}
+
+func (w *ggufWriter) writeUint32(v uint32) {
+	b := make([]byte, 4)
+	binary.LittleEndian.PutUint32(b, v)
+	w.writeBytes(b)
+}
+
+func (w *ggufWriter) writeUint64(v uint64) {
+	b := make([]byte, 8)
+	binary.LittleEndian.PutUint64(b, v)
+	w.writeBytes(b)
+}
+
+func (w *ggufWriter) writeFloat32(v float32) {
+	w.writeUint32(math.Float32bits(v))
+}
+
+func (w *ggufWriter) writeString(s string) {
+	w.writeUint64(uint64(len(s)))
+	w.writeBytes([]byte(s))
+}
+
+// DetectArchFromConfig returns the GGUF architecture for an adapter config.
+// The MLX adapter config does not record the base model architecture, so
+// this currently always returns "gemma3".
+func DetectArchFromConfig(configPath string) string {
+	return "gemma3"
+}
+
+// ArchitectureGGUFMap maps model tags to GGUF architecture names.
+var ArchitectureGGUFMap = map[string]string{
+	"gemma-3-1b":  "gemma3",
+	"gemma-3-4b":  "gemma3",
+	"gemma-3-12b": "gemma3",
+	"gemma-3-27b": "gemma3",
+}
+
+// ModelTagToGGUFArch returns the GGUF architecture for a model tag.
+func ModelTagToGGUFArch(modelTag string) string {
+	if arch, ok := ArchitectureGGUFMap[modelTag]; ok {
+		return arch
+	}
+	return "gemma3"
+}
+
+// GGUFModelBlobPath returns the path to the GGUF model blob in Ollama's store.
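+// Ollama stores models as content-addressed blobs: the manifest for
+// "family:tag" lists layers, and the layer with media type
+// application/vnd.ollama.image.model is the GGUF file itself. A sketch,
+// assuming the default install location:
+//
+//	path, err := GGUFModelBlobPath(os.ExpandEnv("$HOME/.ollama/models"), "gemma3:4b")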
+func GGUFModelBlobPath(ollamaModelsDir, modelName string) (string, error) { + parts := strings.SplitN(modelName, ":", 2) + family := parts[0] + tag := "latest" + if len(parts) > 1 { + tag = parts[1] + } + + manifestPath := fmt.Sprintf("%s/manifests/registry.ollama.ai/library/%s/%s", ollamaModelsDir, family, tag) + data, err := os.ReadFile(manifestPath) + if err != nil { + return "", fmt.Errorf("read manifest %s: %w", manifestPath, err) + } + + var manifest struct { + Layers []struct { + MediaType string `json:"mediaType"` + Digest string `json:"digest"` + } `json:"layers"` + } + if err := json.Unmarshal(data, &manifest); err != nil { + return "", fmt.Errorf("parse manifest: %w", err) + } + + for _, layer := range manifest.Layers { + if layer.MediaType == "application/vnd.ollama.image.model" { + blobName := strings.Replace(layer.Digest, ":", "-", 1) + return fmt.Sprintf("%s/blobs/%s", ollamaModelsDir, blobName), nil + } + } + + return "", fmt.Errorf("no model layer found in manifest for %s", modelName) +} + +// ParseLayerFromTensorName extracts the layer number from a GGUF tensor name. +func ParseLayerFromTensorName(name string) (int, error) { + re := regexp.MustCompile(`blk\.(\d+)\.`) + m := re.FindStringSubmatch(name) + if m == nil { + return 0, fmt.Errorf("no layer number in %s", name) + } + return strconv.Atoi(m[1]) +} diff --git a/pkg/ml/influx.go b/pkg/ml/influx.go new file mode 100644 index 0000000..6ec9c1b --- /dev/null +++ b/pkg/ml/influx.go @@ -0,0 +1,132 @@ +package ml + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" +) + +// InfluxClient talks to an InfluxDB v3 instance. +type InfluxClient struct { + url string + db string + token string +} + +// NewInfluxClient creates an InfluxClient for the given URL and database. +// Reads token from INFLUX_TOKEN env var first, then ~/.influx_token file. +// If url is empty, defaults to "http://10.69.69.165:8181". +// If db is empty, defaults to "training". +func NewInfluxClient(url, db string) *InfluxClient { + if url == "" { + url = "http://10.69.69.165:8181" + } + if db == "" { + db = "training" + } + + token := os.Getenv("INFLUX_TOKEN") + if token == "" { + home, err := os.UserHomeDir() + if err == nil { + data, err := os.ReadFile(filepath.Join(home, ".influx_token")) + if err == nil { + token = strings.TrimSpace(string(data)) + } + } + } + + return &InfluxClient{ + url: url, + db: db, + token: token, + } +} + +// WriteLp writes line protocol data to InfluxDB. +func (c *InfluxClient) WriteLp(lines []string) error { + body := strings.Join(lines, "\n") + + url := fmt.Sprintf("%s/api/v3/write_lp?db=%s", c.url, c.db) + + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader(body)) + if err != nil { + return fmt.Errorf("create write request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Content-Type", "text/plain") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("write request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("write failed %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// QuerySQL runs a SQL query against InfluxDB and returns the result rows. 
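+// Each row is a map from column name to value, as InfluxDB v3 returns
+// results as a JSON array of objects. A sketch against the training_status
+// measurement used elsewhere in this package:
+//
+//	c := NewInfluxClient("", "")
+//	rows, err := c.QuerySQL("SELECT model, pct FROM training_status LIMIT 5")
+//	if err == nil {
+//		for _, row := range rows {
+//			fmt.Println(row["model"], row["pct"])
+//		}
+//	}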
+func (c *InfluxClient) QuerySQL(sql string) ([]map[string]interface{}, error) { + reqBody := map[string]string{ + "db": c.db, + "q": sql, + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal query request: %w", err) + } + + url := c.url + "/api/v3/query_sql" + + req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(jsonBody)) + if err != nil { + return nil, fmt.Errorf("create query request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+c.token) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("query request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read query response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("query failed %d: %s", resp.StatusCode, string(respBody)) + } + + var rows []map[string]interface{} + if err := json.Unmarshal(respBody, &rows); err != nil { + return nil, fmt.Errorf("unmarshal query response: %w", err) + } + + return rows, nil +} + +// EscapeLp escapes spaces, commas, and equals signs for InfluxDB line protocol +// tag values. +func EscapeLp(s string) string { + s = strings.ReplaceAll(s, `,`, `\,`) + s = strings.ReplaceAll(s, `=`, `\=`) + s = strings.ReplaceAll(s, ` `, `\ `) + return s +} diff --git a/pkg/ml/ollama.go b/pkg/ml/ollama.go new file mode 100644 index 0000000..66069f8 --- /dev/null +++ b/pkg/ml/ollama.go @@ -0,0 +1,152 @@ +package ml + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "time" +) + +// OllamaBaseModelMap maps model tags to Ollama model names. +var OllamaBaseModelMap = map[string]string{ + "gemma-3-1b": "gemma3:1b", + "gemma-3-4b": "gemma3:4b", + "gemma-3-12b": "gemma3:12b", + "gemma-3-27b": "gemma3:27b", +} + +// HFBaseModelMap maps model tags to HuggingFace model IDs. +var HFBaseModelMap = map[string]string{ + "gemma-3-1b": "google/gemma-3-1b-it", + "gemma-3-4b": "google/gemma-3-4b-it", + "gemma-3-12b": "google/gemma-3-12b-it", + "gemma-3-27b": "google/gemma-3-27b-it", +} + +// ollamaUploadBlob uploads a local file to Ollama's blob store. +// Returns the sha256 digest string (e.g. "sha256:abc123..."). 
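+// The upload is idempotent: a HEAD request first checks whether Ollama
+// already has a blob with this digest, and the POST is skipped if so.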
+func ollamaUploadBlob(ollamaURL, filePath string) (string, error) { + data, err := os.ReadFile(filePath) + if err != nil { + return "", fmt.Errorf("read %s: %w", filePath, err) + } + + hash := sha256.Sum256(data) + digest := "sha256:" + hex.EncodeToString(hash[:]) + + headReq, _ := http.NewRequest(http.MethodHead, ollamaURL+"/api/blobs/"+digest, nil) + client := &http.Client{Timeout: 5 * time.Minute} + headResp, err := client.Do(headReq) + if err == nil && headResp.StatusCode == http.StatusOK { + headResp.Body.Close() + return digest, nil + } + if headResp != nil { + headResp.Body.Close() + } + + req, err := http.NewRequest(http.MethodPost, ollamaURL+"/api/blobs/"+digest, bytes.NewReader(data)) + if err != nil { + return "", fmt.Errorf("blob request: %w", err) + } + req.Header.Set("Content-Type", "application/octet-stream") + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("blob upload: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("blob upload HTTP %d: %s", resp.StatusCode, string(body)) + } + return digest, nil +} + +// OllamaCreateModel creates a temporary Ollama model with a LoRA adapter. +// peftDir is a local directory containing adapter_model.safetensors and adapter_config.json. +func OllamaCreateModel(ollamaURL, modelName, baseModel, peftDir string) error { + sfPath := peftDir + "/adapter_model.safetensors" + cfgPath := peftDir + "/adapter_config.json" + + sfDigest, err := ollamaUploadBlob(ollamaURL, sfPath) + if err != nil { + return fmt.Errorf("upload adapter safetensors: %w", err) + } + + cfgDigest, err := ollamaUploadBlob(ollamaURL, cfgPath) + if err != nil { + return fmt.Errorf("upload adapter config: %w", err) + } + + reqBody, _ := json.Marshal(map[string]interface{}{ + "model": modelName, + "from": baseModel, + "adapters": map[string]string{ + "adapter_model.safetensors": sfDigest, + "adapter_config.json": cfgDigest, + }, + }) + + client := &http.Client{Timeout: 10 * time.Minute} + resp, err := client.Post(ollamaURL+"/api/create", "application/json", bytes.NewReader(reqBody)) + if err != nil { + return fmt.Errorf("ollama create: %w", err) + } + defer resp.Body.Close() + + decoder := json.NewDecoder(resp.Body) + for decoder.More() { + var status struct { + Status string `json:"status"` + Error string `json:"error"` + } + if err := decoder.Decode(&status); err != nil { + if err == io.EOF { + break + } + return fmt.Errorf("ollama create decode: %w", err) + } + if status.Error != "" { + return fmt.Errorf("ollama create: %s", status.Error) + } + if status.Status == "success" { + return nil + } + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("ollama create: HTTP %d", resp.StatusCode) + } + return nil +} + +// OllamaDeleteModel removes a temporary Ollama model. 
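+// Typical pairing with OllamaCreateModel (a sketch; model names are
+// assumptions):
+//
+//	if err := OllamaCreateModel(url, "eval-tmp", "gemma3:4b", peftDir); err == nil {
+//		defer OllamaDeleteModel(url, "eval-tmp")
+//		// ... evaluate against the temporary model ...
+//	}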
+func OllamaDeleteModel(ollamaURL, modelName string) error { + body, _ := json.Marshal(map[string]string{"model": modelName}) + + req, err := http.NewRequest(http.MethodDelete, ollamaURL+"/api/delete", bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("ollama delete request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("ollama delete: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("ollama delete %d: %s", resp.StatusCode, string(respBody)) + } + return nil +} diff --git a/pkg/ml/parquet.go b/pkg/ml/parquet.go new file mode 100644 index 0000000..13d8a14 --- /dev/null +++ b/pkg/ml/parquet.go @@ -0,0 +1,137 @@ +package ml + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/parquet-go/parquet-go" +) + +// ParquetRow is the schema for exported Parquet files. +type ParquetRow struct { + Prompt string `parquet:"prompt"` + Response string `parquet:"response"` + System string `parquet:"system"` + Messages string `parquet:"messages"` +} + +// ExportParquet reads JSONL training splits (train.jsonl, valid.jsonl, test.jsonl) +// from trainingDir and writes Parquet files with snappy compression to outputDir. +// Returns total rows exported. +func ExportParquet(trainingDir, outputDir string) (int, error) { + if outputDir == "" { + outputDir = filepath.Join(trainingDir, "parquet") + } + if err := os.MkdirAll(outputDir, 0755); err != nil { + return 0, fmt.Errorf("create output dir: %w", err) + } + + total := 0 + for _, split := range []string{"train", "valid", "test"} { + jsonlPath := filepath.Join(trainingDir, split+".jsonl") + if _, err := os.Stat(jsonlPath); os.IsNotExist(err) { + continue + } + + n, err := ExportSplitParquet(jsonlPath, outputDir, split) + if err != nil { + return total, fmt.Errorf("export %s: %w", split, err) + } + total += n + } + + return total, nil +} + +// ExportSplitParquet reads a chat JSONL file and writes a Parquet file for the +// given split name. Returns the number of rows written. 
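+// Malformed JSONL lines are skipped rather than treated as errors, and no
+// Parquet file is written when the split is empty.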
+func ExportSplitParquet(jsonlPath, outputDir, split string) (int, error) { + f, err := os.Open(jsonlPath) + if err != nil { + return 0, fmt.Errorf("open %s: %w", jsonlPath, err) + } + defer f.Close() + + var rows []ParquetRow + scanner := bufio.NewScanner(f) + scanner.Buffer(make([]byte, 1024*1024), 1024*1024) + + for scanner.Scan() { + text := strings.TrimSpace(scanner.Text()) + if text == "" { + continue + } + + var data struct { + Messages []ChatMessage `json:"messages"` + } + if err := json.Unmarshal([]byte(text), &data); err != nil { + continue + } + + var prompt, response, system string + for _, m := range data.Messages { + switch m.Role { + case "user": + if prompt == "" { + prompt = m.Content + } + case "assistant": + if response == "" { + response = m.Content + } + case "system": + if system == "" { + system = m.Content + } + } + } + + msgsJSON, _ := json.Marshal(data.Messages) + rows = append(rows, ParquetRow{ + Prompt: prompt, + Response: response, + System: system, + Messages: string(msgsJSON), + }) + } + + if err := scanner.Err(); err != nil { + return 0, fmt.Errorf("scan %s: %w", jsonlPath, err) + } + + if len(rows) == 0 { + return 0, nil + } + + outPath := filepath.Join(outputDir, split+".parquet") + + out, err := os.Create(outPath) + if err != nil { + return 0, fmt.Errorf("create %s: %w", outPath, err) + } + + writer := parquet.NewGenericWriter[ParquetRow](out, + parquet.Compression(&parquet.Snappy), + ) + + if _, err := writer.Write(rows); err != nil { + out.Close() + return 0, fmt.Errorf("write parquet rows: %w", err) + } + + if err := writer.Close(); err != nil { + out.Close() + return 0, fmt.Errorf("close parquet writer: %w", err) + } + + if err := out.Close(); err != nil { + return 0, fmt.Errorf("close file: %w", err) + } + + return len(rows), nil +} diff --git a/pkg/ml/status.go b/pkg/ml/status.go new file mode 100644 index 0000000..d61a0a2 --- /dev/null +++ b/pkg/ml/status.go @@ -0,0 +1,212 @@ +package ml + +import ( + "fmt" + "io" + "sort" +) + +// trainingRow holds deduplicated training status + loss for a single model. +type trainingRow struct { + model string + status string + iteration int + totalIters int + pct float64 + loss float64 + hasLoss bool +} + +// genRow holds deduplicated generation progress for a single worker. +type genRow struct { + worker string + completed int + target int + pct float64 +} + +// PrintStatus queries InfluxDB for training and generation progress and writes +// a formatted summary to w. 
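+// All queries are best-effort: a failed query renders as "(no data)" rather
+// than aborting the summary. A sketch:
+//
+//	influx := NewInfluxClient("", "")
+//	if err := PrintStatus(influx, os.Stdout); err != nil {
+//		log.Fatal(err)
+//	}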
+func PrintStatus(influx *InfluxClient, w io.Writer) error { + statusRows, err := influx.QuerySQL( + "SELECT model, run_id, status, iteration, total_iters, pct FROM training_status ORDER BY time DESC LIMIT 10", + ) + if err != nil { + statusRows = nil + } + + lossRows, err := influx.QuerySQL( + "SELECT model, loss_type, loss, iteration, tokens_per_sec FROM training_loss WHERE loss_type = 'train' ORDER BY time DESC LIMIT 10", + ) + if err != nil { + lossRows = nil + } + + goldenRows, err := influx.QuerySQL( + "SELECT worker, completed, target, pct FROM golden_gen_progress ORDER BY time DESC LIMIT 5", + ) + if err != nil { + goldenRows = nil + } + + expansionRows, err := influx.QuerySQL( + "SELECT worker, completed, target, pct FROM expansion_progress ORDER BY time DESC LIMIT 5", + ) + if err != nil { + expansionRows = nil + } + + training := dedupeTraining(statusRows, lossRows) + golden := dedupeGeneration(goldenRows) + expansion := dedupeGeneration(expansionRows) + + fmt.Fprintln(w, "Training:") + if len(training) == 0 { + fmt.Fprintln(w, " (no data)") + } else { + for _, tr := range training { + progress := fmt.Sprintf("%d/%d", tr.iteration, tr.totalIters) + pct := fmt.Sprintf("%.1f%%", tr.pct) + if tr.hasLoss { + fmt.Fprintf(w, " %-13s %-9s %9s %7s loss=%.3f\n", + tr.model, tr.status, progress, pct, tr.loss) + } else { + fmt.Fprintf(w, " %-13s %-9s %9s %7s\n", + tr.model, tr.status, progress, pct) + } + } + } + + fmt.Fprintln(w) + fmt.Fprintln(w, "Generation:") + + hasGenData := false + + if len(golden) > 0 { + hasGenData = true + for _, g := range golden { + progress := fmt.Sprintf("%d/%d", g.completed, g.target) + pct := fmt.Sprintf("%.1f%%", g.pct) + fmt.Fprintf(w, " %-13s %11s %7s (%s)\n", "golden", progress, pct, g.worker) + } + } + + if len(expansion) > 0 { + hasGenData = true + for _, g := range expansion { + progress := fmt.Sprintf("%d/%d", g.completed, g.target) + pct := fmt.Sprintf("%.1f%%", g.pct) + fmt.Fprintf(w, " %-13s %11s %7s (%s)\n", "expansion", progress, pct, g.worker) + } + } + + if !hasGenData { + fmt.Fprintln(w, " (no data)") + } + + return nil +} + +// dedupeTraining merges training status and loss rows, keeping only the first +// (latest) row per model. +func dedupeTraining(statusRows, lossRows []map[string]interface{}) []trainingRow { + lossMap := make(map[string]float64) + lossSeenMap := make(map[string]bool) + for _, row := range lossRows { + model := strVal(row, "model") + if model == "" || lossSeenMap[model] { + continue + } + lossSeenMap[model] = true + lossMap[model] = floatVal(row, "loss") + } + + seen := make(map[string]bool) + var rows []trainingRow + for _, row := range statusRows { + model := strVal(row, "model") + if model == "" || seen[model] { + continue + } + seen[model] = true + + tr := trainingRow{ + model: model, + status: strVal(row, "status"), + iteration: intVal(row, "iteration"), + totalIters: intVal(row, "total_iters"), + pct: floatVal(row, "pct"), + } + + if loss, ok := lossMap[model]; ok { + tr.loss = loss + tr.hasLoss = true + } + + rows = append(rows, tr) + } + + sort.Slice(rows, func(i, j int) bool { + return rows[i].model < rows[j].model + }) + + return rows +} + +// dedupeGeneration deduplicates generation progress rows by worker. 
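+// Rows arrive newest-first (ORDER BY time DESC), so keeping the first row
+// seen per worker keeps its latest sample; results are sorted by worker name.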
+func dedupeGeneration(rows []map[string]interface{}) []genRow { + seen := make(map[string]bool) + var result []genRow + for _, row := range rows { + worker := strVal(row, "worker") + if worker == "" || seen[worker] { + continue + } + seen[worker] = true + + result = append(result, genRow{ + worker: worker, + completed: intVal(row, "completed"), + target: intVal(row, "target"), + pct: floatVal(row, "pct"), + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].worker < result[j].worker + }) + + return result +} + +// strVal extracts a string value from a row map. +func strVal(row map[string]interface{}, key string) string { + v, ok := row[key] + if !ok { + return "" + } + s, ok := v.(string) + if !ok { + return "" + } + return s +} + +// floatVal extracts a float64 value from a row map. +func floatVal(row map[string]interface{}, key string) float64 { + v, ok := row[key] + if !ok { + return 0 + } + f, ok := v.(float64) + if !ok { + return 0 + } + return f +} + +// intVal extracts an integer value from a row map. InfluxDB JSON returns all +// numbers as float64, so this truncates to int. +func intVal(row map[string]interface{}, key string) int { + return int(floatVal(row, key)) +} diff --git a/pkg/ml/worker.go b/pkg/ml/worker.go new file mode 100644 index 0000000..ac0678d --- /dev/null +++ b/pkg/ml/worker.go @@ -0,0 +1,403 @@ +package ml + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "runtime" + "time" +) + +// WorkerConfig holds the worker's runtime configuration. +type WorkerConfig struct { + APIBase string + WorkerID string + Name string + APIKey string + GPUType string + VRAMGb int + Languages []string + Models []string + InferURL string + TaskType string + BatchSize int + PollInterval time.Duration + OneShot bool + DryRun bool +} + +// APITask represents a task from the LEM API. +type APITask struct { + ID int `json:"id"` + TaskType string `json:"task_type"` + Status string `json:"status"` + Language string `json:"language"` + Domain string `json:"domain"` + ModelName string `json:"model_name"` + PromptID string `json:"prompt_id"` + PromptText string `json:"prompt_text"` + Config *struct { + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + } `json:"config"` + Priority int `json:"priority"` +} + +// RunWorkerLoop is the main worker loop that polls for tasks and processes them. 
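+// A minimal invocation (endpoint values are placeholders):
+//
+//	RunWorkerLoop(&WorkerConfig{
+//		APIBase:      "https://lem.example.com",
+//		WorkerID:     MachineID(),
+//		Name:         Hostname(),
+//		APIKey:       ReadKeyFile(),
+//		InferURL:     "http://localhost:11434",
+//		BatchSize:    4,
+//		PollInterval: 30 * time.Second,
+//	})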
+func RunWorkerLoop(cfg *WorkerConfig) { + log.Printf("LEM Worker starting") + log.Printf(" ID: %s", cfg.WorkerID) + log.Printf(" Name: %s", cfg.Name) + log.Printf(" API: %s", cfg.APIBase) + log.Printf(" Infer: %s", cfg.InferURL) + log.Printf(" GPU: %s (%d GB)", cfg.GPUType, cfg.VRAMGb) + log.Printf(" Langs: %v", cfg.Languages) + log.Printf(" Models: %v", cfg.Models) + log.Printf(" Batch: %d", cfg.BatchSize) + log.Printf(" Dry-run: %v", cfg.DryRun) + + if err := workerRegister(cfg); err != nil { + log.Fatalf("Registration failed: %v", err) + } + log.Println("Registered with LEM API") + + for { + processed := workerPoll(cfg) + + if cfg.OneShot { + log.Printf("One-shot mode: processed %d tasks, exiting", processed) + return + } + + if processed == 0 { + log.Printf("No tasks available, sleeping %v", cfg.PollInterval) + time.Sleep(cfg.PollInterval) + } + + workerHeartbeat(cfg) + } +} + +func workerRegister(cfg *WorkerConfig) error { + body := map[string]interface{}{ + "worker_id": cfg.WorkerID, + "name": cfg.Name, + "version": "0.1.0", + "os": runtime.GOOS, + "arch": runtime.GOARCH, + } + if cfg.GPUType != "" { + body["gpu_type"] = cfg.GPUType + } + if cfg.VRAMGb > 0 { + body["vram_gb"] = cfg.VRAMGb + } + if len(cfg.Languages) > 0 { + body["languages"] = cfg.Languages + } + if len(cfg.Models) > 0 { + body["supported_models"] = cfg.Models + } + + _, err := apiPost(cfg, "/api/lem/workers/register", body) + return err +} + +func workerHeartbeat(cfg *WorkerConfig) { + body := map[string]interface{}{ + "worker_id": cfg.WorkerID, + } + apiPost(cfg, "/api/lem/workers/heartbeat", body) +} + +func workerPoll(cfg *WorkerConfig) int { + url := fmt.Sprintf("/api/lem/tasks/next?worker_id=%s&limit=%d", cfg.WorkerID, cfg.BatchSize) + if cfg.TaskType != "" { + url += "&type=" + cfg.TaskType + } + + resp, err := apiGet(cfg, url) + if err != nil { + log.Printf("Error fetching tasks: %v", err) + return 0 + } + + var result struct { + Tasks []APITask `json:"tasks"` + Count int `json:"count"` + } + if err := json.Unmarshal(resp, &result); err != nil { + log.Printf("Error parsing tasks: %v", err) + return 0 + } + + if result.Count == 0 { + return 0 + } + + log.Printf("Got %d tasks", result.Count) + processed := 0 + + for _, task := range result.Tasks { + if err := workerProcessTask(cfg, task); err != nil { + log.Printf("Task %d failed: %v", task.ID, err) + apiDelete(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{ + "worker_id": cfg.WorkerID, + }) + continue + } + processed++ + } + + return processed +} + +func workerProcessTask(cfg *WorkerConfig, task APITask) error { + log.Printf("Processing task %d: %s [%s/%s] %d chars prompt", + task.ID, task.TaskType, task.Language, task.Domain, len(task.PromptText)) + + _, err := apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/claim", task.ID), map[string]interface{}{ + "worker_id": cfg.WorkerID, + }) + if err != nil { + return fmt.Errorf("claim: %w", err) + } + + apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{ + "worker_id": cfg.WorkerID, + "status": "in_progress", + }) + + if cfg.DryRun { + log.Printf(" [DRY-RUN] Would generate response for: %.80s...", task.PromptText) + return nil + } + + start := time.Now() + response, err := workerInfer(cfg, task) + genTime := time.Since(start) + + if err != nil { + apiPatch(cfg, fmt.Sprintf("/api/lem/tasks/%d/status", task.ID), map[string]interface{}{ + "worker_id": cfg.WorkerID, + "status": "abandoned", + }) + return fmt.Errorf("inference: %w", err) + } + + modelUsed := 
task.ModelName + if modelUsed == "" { + modelUsed = "default" + } + + _, err = apiPost(cfg, fmt.Sprintf("/api/lem/tasks/%d/result", task.ID), map[string]interface{}{ + "worker_id": cfg.WorkerID, + "response_text": response, + "model_used": modelUsed, + "gen_time_ms": int(genTime.Milliseconds()), + }) + if err != nil { + return fmt.Errorf("submit result: %w", err) + } + + log.Printf(" Completed: %d chars in %v", len(response), genTime.Round(time.Millisecond)) + return nil +} + +func workerInfer(cfg *WorkerConfig, task APITask) (string, error) { + messages := []map[string]string{ + {"role": "user", "content": task.PromptText}, + } + + temp := 0.7 + maxTokens := 2048 + if task.Config != nil { + if task.Config.Temperature > 0 { + temp = task.Config.Temperature + } + if task.Config.MaxTokens > 0 { + maxTokens = task.Config.MaxTokens + } + } + + reqBody := map[string]interface{}{ + "model": task.ModelName, + "messages": messages, + "temperature": temp, + "max_tokens": maxTokens, + } + + data, err := json.Marshal(reqBody) + if err != nil { + return "", err + } + + req, err := http.NewRequest("POST", cfg.InferURL+"/v1/chat/completions", bytes.NewReader(data)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 5 * time.Minute} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("inference request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != 200 { + return "", fmt.Errorf("inference HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200)) + } + + var chatResp struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + if err := json.Unmarshal(body, &chatResp); err != nil { + return "", fmt.Errorf("parse response: %w", err) + } + + if len(chatResp.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + + content := chatResp.Choices[0].Message.Content + if len(content) < 10 { + return "", fmt.Errorf("response too short: %d chars", len(content)) + } + + return content, nil +} + +// HTTP helpers for the LEM API. 
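+// All of them send a Bearer token from cfg.APIKey, use a 15-second timeout,
+// and treat any status >= 400 as an error carrying a truncated response body.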
+ +func apiGet(cfg *WorkerConfig, path string) ([]byte, error) { + req, err := http.NewRequest("GET", cfg.APIBase+path, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+cfg.APIKey) + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200)) + } + + return body, nil +} + +func apiPost(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) { + return apiRequest(cfg, "POST", path, data) +} + +func apiPatch(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) { + return apiRequest(cfg, "PATCH", path, data) +} + +func apiDelete(cfg *WorkerConfig, path string, data map[string]interface{}) ([]byte, error) { + return apiRequest(cfg, "DELETE", path, data) +} + +func apiRequest(cfg *WorkerConfig, method, path string, data map[string]interface{}) ([]byte, error) { + jsonData, err := json.Marshal(data) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(method, cfg.APIBase+path, bytes.NewReader(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+cfg.APIKey) + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: 15 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncStr(string(body), 200)) + } + + return body, nil +} + +// MachineID returns the machine ID from /etc/machine-id or hostname fallback. +func MachineID() string { + if data, err := os.ReadFile("/etc/machine-id"); err == nil { + id := string(bytes.TrimSpace(data)) + if len(id) > 0 { + return id + } + } + h, _ := os.Hostname() + return h +} + +// Hostname returns the system hostname. +func Hostname() string { + h, _ := os.Hostname() + return h +} + +// ReadKeyFile reads the LEM API key from ~/.config/lem/api_key. +func ReadKeyFile() string { + home, _ := os.UserHomeDir() + path := filepath.Join(home, ".config", "lem", "api_key") + data, err := os.ReadFile(path) + if err != nil { + return "" + } + return string(bytes.TrimSpace(data)) +} + +// SplitComma splits a comma-separated string into trimmed parts. +func SplitComma(s string) []string { + var result []string + for _, part := range bytes.Split([]byte(s), []byte(",")) { + trimmed := bytes.TrimSpace(part) + if len(trimmed) > 0 { + result = append(result, string(trimmed)) + } + } + return result +} + +func truncStr(s string, n int) string { + if len(s) <= n { + return s + } + return s[:n] + "..." +}