feat(rag): add Go RAG implementation with Qdrant + Ollama
Add RAG (Retrieval Augmented Generation) tools for storing documentation in Qdrant vector database and querying with semantic search. This replaces the Python tools/rag implementation with a native Go solution. New commands: - core rag ingest [directory] - Ingest markdown files into Qdrant - core rag query [question] - Query vector database with semantic search - core rag collections - List and manage Qdrant collections Features: - Markdown chunking by sections and paragraphs with overlap - UTF-8 safe text handling for international content - Automatic category detection from file paths - Multiple output formats: text, JSON, LLM context injection - Environment variable support for host configuration Dependencies: - github.com/qdrant/go-client (gRPC client) - github.com/ollama/ollama/api (embeddings API) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a06715dc90
commit
b9f44cd03a
15 changed files with 1558 additions and 0 deletions
9
go.mod
9
go.mod
|
|
@ -10,6 +10,8 @@ require (
|
|||
github.com/minio/selfupdate v0.6.0
|
||||
github.com/modelcontextprotocol/go-sdk v1.2.0
|
||||
github.com/oasdiff/oasdiff v1.11.8
|
||||
github.com/ollama/ollama v0.15.4
|
||||
github.com/qdrant/go-client v1.16.2
|
||||
github.com/spf13/cobra v1.10.2
|
||||
github.com/stretchr/testify v1.11.1
|
||||
golang.org/x/mod v0.32.0
|
||||
|
|
@ -27,6 +29,8 @@ require (
|
|||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/TwiN/go-color v1.4.1 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/cloudflare/circl v1.6.3 // indirect
|
||||
github.com/cyphar/filepath-securejoin v0.6.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
|
|
@ -38,6 +42,7 @@ require (
|
|||
github.com/go-openapi/swag/jsonname v0.25.4 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
|
|
@ -60,6 +65,7 @@ require (
|
|||
github.com/ugorji/go/codec v1.3.0 // indirect
|
||||
github.com/ulikunitz/xz v0.5.15 // indirect
|
||||
github.com/wI2L/jsondiff v0.7.0 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/woodsbury/decimal128 v1.4.0 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
github.com/yargevad/filepathx v1.0.0 // indirect
|
||||
|
|
@ -67,5 +73,8 @@ require (
|
|||
golang.org/x/crypto v0.47.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect
|
||||
google.golang.org/grpc v1.76.0 // indirect
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
)
|
||||
|
|
|
|||
38
go.sum
38
go.sum
|
|
@ -18,6 +18,10 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI
|
|||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
|
||||
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8=
|
||||
github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4=
|
||||
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
|
||||
|
|
@ -43,6 +47,10 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj
|
|||
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
|
||||
github.com/go-git/go-git/v5 v5.16.4 h1:7ajIEZHZJULcyJebDLo99bGgS0jRrOxzZG4uCk2Yb2Y=
|
||||
github.com/go-git/go-git/v5 v5.16.4/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-openapi/jsonpointer v0.22.4 h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4=
|
||||
github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80=
|
||||
github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI=
|
||||
|
|
@ -55,10 +63,14 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD
|
|||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
|
||||
|
|
@ -100,6 +112,8 @@ github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//J
|
|||
github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw=
|
||||
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c=
|
||||
github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o=
|
||||
github.com/ollama/ollama v0.15.4 h1:y841GH5lsi5j5BTFyX/E+UOC3Yiw+JBfdjBVRGw+I0M=
|
||||
github.com/ollama/ollama v0.15.4/go.mod h1:4Yn3jw2hZ4VqyJ1XciYawDRE8bzv4RT3JiVZR1kCfwE=
|
||||
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
||||
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
||||
github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s=
|
||||
|
|
@ -111,6 +125,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/qdrant/go-client v1.16.2 h1:UUMJJfvXTByhwhH1DwWdbkhZ2cTdvSqVkXSIfBrVWSg=
|
||||
github.com/qdrant/go-client v1.16.2/go.mod h1:I+EL3h4HRoRTeHtbfOd/4kDXwCukZfkd41j/9wryGkw=
|
||||
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
|
|
@ -146,6 +162,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY=
|
|||
github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14=
|
||||
github.com/wI2L/jsondiff v0.7.0 h1:1lH1G37GhBPqCfp/lrs91rf/2j3DktX6qYAKZkLuCQQ=
|
||||
github.com/wI2L/jsondiff v0.7.0/go.mod h1:KAEIojdQq66oJiHhDyQez2x+sRit0vIzC9KeK0yizxM=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/woodsbury/decimal128 v1.4.0 h1:xJATj7lLu4f2oObouMt2tgGiElE5gO6mSWUjQsBgUlc=
|
||||
github.com/woodsbury/decimal128 v1.4.0/go.mod h1:BP46FUrVjVhdTbKT+XuQh2xfQaGki9LMIRJSFuh6THU=
|
||||
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||
|
|
@ -154,6 +172,18 @@ github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5
|
|||
github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI=
|
||||
go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
|
|
@ -192,6 +222,14 @@ golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
|
|||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
|
||||
golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba h1:UKgtfRM7Yh93Sya0Fo8ZzhDP4qBckrrxEr2oF5UIVb8=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A=
|
||||
google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c=
|
||||
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
|
|
|
|||
86
internal/cmd/rag/cmd_collections.go
Normal file
86
internal/cmd/rag/cmd_collections.go
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/host-uk/core/pkg/cli"
|
||||
"github.com/host-uk/core/pkg/i18n"
|
||||
"github.com/host-uk/core/pkg/rag"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
listCollections bool
|
||||
showStats bool
|
||||
deleteCollection string
|
||||
)
|
||||
|
||||
var collectionsCmd = &cobra.Command{
|
||||
Use: "collections",
|
||||
Short: i18n.T("cmd.rag.collections.short"),
|
||||
Long: i18n.T("cmd.rag.collections.long"),
|
||||
RunE: runCollections,
|
||||
}
|
||||
|
||||
func runCollections(cmd *cobra.Command, args []string) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect to Qdrant
|
||||
qdrantClient, err := rag.NewQdrantClient(rag.QdrantConfig{
|
||||
Host: qdrantHost,
|
||||
Port: qdrantPort,
|
||||
UseTLS: false,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Qdrant: %w", err)
|
||||
}
|
||||
defer qdrantClient.Close()
|
||||
|
||||
// Handle delete
|
||||
if deleteCollection != "" {
|
||||
exists, err := qdrantClient.CollectionExists(ctx, deleteCollection)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !exists {
|
||||
return fmt.Errorf("collection not found: %s", deleteCollection)
|
||||
}
|
||||
if err := qdrantClient.DeleteCollection(ctx, deleteCollection); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Deleted collection: %s\n", deleteCollection)
|
||||
return nil
|
||||
}
|
||||
|
||||
// List collections
|
||||
collections, err := qdrantClient.ListCollections(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(collections) == 0 {
|
||||
fmt.Println("No collections found.")
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Printf("%s\n\n", cli.TitleStyle.Render("Collections"))
|
||||
|
||||
for _, name := range collections {
|
||||
if showStats {
|
||||
info, err := qdrantClient.CollectionInfo(ctx, name)
|
||||
if err != nil {
|
||||
fmt.Printf(" %s (error: %v)\n", name, err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf(" %s\n", cli.ValueStyle.Render(name))
|
||||
fmt.Printf(" Points: %d\n", info.PointsCount)
|
||||
fmt.Printf(" Status: %s\n", info.Status.String())
|
||||
fmt.Println()
|
||||
} else {
|
||||
fmt.Printf(" %s\n", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
25
internal/cmd/rag/cmd_commands.go
Normal file
25
internal/cmd/rag/cmd_commands.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
// Package rag provides RAG (Retrieval Augmented Generation) commands.
|
||||
//
|
||||
// Commands:
|
||||
// - core rag ingest: Ingest markdown files into Qdrant
|
||||
// - core rag query: Query the vector database
|
||||
// - core rag collections: List and manage collections
|
||||
package rag
|
||||
|
||||
import (
|
||||
"github.com/host-uk/core/pkg/cli"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func init() {
|
||||
cli.RegisterCommands(AddRAGCommands)
|
||||
}
|
||||
|
||||
// AddRAGCommands registers the 'rag' command and all subcommands.
|
||||
func AddRAGCommands(root *cobra.Command) {
|
||||
initFlags()
|
||||
ragCmd.AddCommand(ingestCmd)
|
||||
ragCmd.AddCommand(queryCmd)
|
||||
ragCmd.AddCommand(collectionsCmd)
|
||||
root.AddCommand(ragCmd)
|
||||
}
|
||||
178
internal/cmd/rag/cmd_ingest.go
Normal file
178
internal/cmd/rag/cmd_ingest.go
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/host-uk/core/pkg/cli"
|
||||
"github.com/host-uk/core/pkg/i18n"
|
||||
"github.com/host-uk/core/pkg/rag"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
collection string
|
||||
recreate bool
|
||||
chunkSize int
|
||||
chunkOverlap int
|
||||
)
|
||||
|
||||
var ingestCmd = &cobra.Command{
|
||||
Use: "ingest [directory]",
|
||||
Short: i18n.T("cmd.rag.ingest.short"),
|
||||
Long: i18n.T("cmd.rag.ingest.long"),
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
RunE: runIngest,
|
||||
}
|
||||
|
||||
func runIngest(cmd *cobra.Command, args []string) error {
|
||||
directory := "."
|
||||
if len(args) > 0 {
|
||||
directory = args[0]
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect to Qdrant
|
||||
fmt.Printf("Connecting to Qdrant at %s:%d...\n", qdrantHost, qdrantPort)
|
||||
qdrantClient, err := rag.NewQdrantClient(rag.QdrantConfig{
|
||||
Host: qdrantHost,
|
||||
Port: qdrantPort,
|
||||
UseTLS: false,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Qdrant: %w", err)
|
||||
}
|
||||
defer qdrantClient.Close()
|
||||
|
||||
if err := qdrantClient.HealthCheck(ctx); err != nil {
|
||||
return fmt.Errorf("Qdrant health check failed: %w", err)
|
||||
}
|
||||
|
||||
// Connect to Ollama
|
||||
fmt.Printf("Using embedding model: %s (via %s:%d)\n", model, ollamaHost, ollamaPort)
|
||||
ollamaClient, err := rag.NewOllamaClient(rag.OllamaConfig{
|
||||
Host: ollamaHost,
|
||||
Port: ollamaPort,
|
||||
Model: model,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Ollama: %w", err)
|
||||
}
|
||||
|
||||
if err := ollamaClient.VerifyModel(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure ingestion
|
||||
cfg := rag.IngestConfig{
|
||||
Directory: directory,
|
||||
Collection: collection,
|
||||
Recreate: recreate,
|
||||
Verbose: verbose,
|
||||
BatchSize: 100,
|
||||
Chunk: rag.ChunkConfig{
|
||||
Size: chunkSize,
|
||||
Overlap: chunkOverlap,
|
||||
},
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
progress := func(file string, chunks int, total int) {
|
||||
if verbose {
|
||||
fmt.Printf(" Processed: %s (%d chunks total)\n", file, chunks)
|
||||
} else {
|
||||
fmt.Printf("\r %s (%d chunks) ", cli.DimStyle.Render(file), chunks)
|
||||
}
|
||||
}
|
||||
|
||||
// Run ingestion
|
||||
fmt.Printf("\nIngesting from: %s\n", directory)
|
||||
if recreate {
|
||||
fmt.Printf(" (recreating collection: %s)\n", collection)
|
||||
}
|
||||
|
||||
stats, err := rag.Ingest(ctx, qdrantClient, ollamaClient, cfg, progress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Summary
|
||||
fmt.Printf("\n\n%s\n", cli.TitleStyle.Render("Ingestion complete!"))
|
||||
fmt.Printf(" Files processed: %d\n", stats.Files)
|
||||
fmt.Printf(" Chunks created: %d\n", stats.Chunks)
|
||||
if stats.Errors > 0 {
|
||||
fmt.Printf(" Errors: %s\n", cli.ErrorStyle.Render(fmt.Sprintf("%d", stats.Errors)))
|
||||
}
|
||||
fmt.Printf(" Collection: %s\n", collection)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IngestDirectory is exported for use by other packages (e.g., MCP).
|
||||
func IngestDirectory(ctx context.Context, directory, collectionName string, recreateCollection bool) error {
|
||||
qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer qdrantClient.Close()
|
||||
|
||||
if err := qdrantClient.HealthCheck(ctx); err != nil {
|
||||
return fmt.Errorf("Qdrant health check failed: %w", err)
|
||||
}
|
||||
|
||||
ollamaClient, err := rag.NewOllamaClient(rag.DefaultOllamaConfig())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ollamaClient.VerifyModel(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg := rag.DefaultIngestConfig()
|
||||
cfg.Directory = directory
|
||||
cfg.Collection = collectionName
|
||||
cfg.Recreate = recreateCollection
|
||||
|
||||
_, err = rag.Ingest(ctx, qdrantClient, ollamaClient, cfg, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// IngestFile is exported for use by other packages (e.g., MCP).
|
||||
func IngestFile(ctx context.Context, filePath, collectionName string) (int, error) {
|
||||
qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer qdrantClient.Close()
|
||||
|
||||
if err := qdrantClient.HealthCheck(ctx); err != nil {
|
||||
return 0, fmt.Errorf("Qdrant health check failed: %w", err)
|
||||
}
|
||||
|
||||
ollamaClient, err := rag.NewOllamaClient(rag.DefaultOllamaConfig())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := ollamaClient.VerifyModel(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return rag.IngestFile(ctx, qdrantClient, ollamaClient, collectionName, filePath, rag.DefaultChunkConfig())
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Check for environment variable overrides
|
||||
if host := os.Getenv("QDRANT_HOST"); host != "" {
|
||||
qdrantHost = host
|
||||
}
|
||||
if host := os.Getenv("OLLAMA_HOST"); host != "" {
|
||||
ollamaHost = host
|
||||
}
|
||||
if m := os.Getenv("EMBEDDING_MODEL"); m != "" {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
107
internal/cmd/rag/cmd_query.go
Normal file
107
internal/cmd/rag/cmd_query.go
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/host-uk/core/pkg/i18n"
|
||||
"github.com/host-uk/core/pkg/rag"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var (
|
||||
queryCollection string
|
||||
limit int
|
||||
threshold float32
|
||||
category string
|
||||
format string
|
||||
)
|
||||
|
||||
var queryCmd = &cobra.Command{
|
||||
Use: "query [question]",
|
||||
Short: i18n.T("cmd.rag.query.short"),
|
||||
Long: i18n.T("cmd.rag.query.long"),
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: runQuery,
|
||||
}
|
||||
|
||||
func runQuery(cmd *cobra.Command, args []string) error {
|
||||
question := args[0]
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect to Qdrant
|
||||
qdrantClient, err := rag.NewQdrantClient(rag.QdrantConfig{
|
||||
Host: qdrantHost,
|
||||
Port: qdrantPort,
|
||||
UseTLS: false,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Qdrant: %w", err)
|
||||
}
|
||||
defer qdrantClient.Close()
|
||||
|
||||
// Connect to Ollama
|
||||
ollamaClient, err := rag.NewOllamaClient(rag.OllamaConfig{
|
||||
Host: ollamaHost,
|
||||
Port: ollamaPort,
|
||||
Model: model,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to Ollama: %w", err)
|
||||
}
|
||||
|
||||
// Configure query
|
||||
cfg := rag.QueryConfig{
|
||||
Collection: queryCollection,
|
||||
Limit: uint64(limit),
|
||||
Threshold: threshold,
|
||||
Category: category,
|
||||
}
|
||||
|
||||
// Run query
|
||||
results, err := rag.Query(ctx, qdrantClient, ollamaClient, question, cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Format output
|
||||
switch format {
|
||||
case "json":
|
||||
fmt.Println(rag.FormatResultsJSON(results))
|
||||
case "context":
|
||||
fmt.Println(rag.FormatResultsContext(results))
|
||||
default:
|
||||
fmt.Println(rag.FormatResultsText(results))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryDocs is exported for use by other packages (e.g., MCP).
|
||||
func QueryDocs(ctx context.Context, question, collectionName string, topK int) ([]rag.QueryResult, error) {
|
||||
qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer qdrantClient.Close()
|
||||
|
||||
ollamaClient, err := rag.NewOllamaClient(rag.DefaultOllamaConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg := rag.DefaultQueryConfig()
|
||||
cfg.Collection = collectionName
|
||||
cfg.Limit = uint64(topK)
|
||||
|
||||
return rag.Query(ctx, qdrantClient, ollamaClient, question, cfg)
|
||||
}
|
||||
|
||||
// QueryDocsContext is exported and returns context-formatted results.
|
||||
func QueryDocsContext(ctx context.Context, question, collectionName string, topK int) (string, error) {
|
||||
results, err := QueryDocs(ctx, question, collectionName, topK)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return rag.FormatResultsContext(results), nil
|
||||
}
|
||||
54
internal/cmd/rag/cmd_rag.go
Normal file
54
internal/cmd/rag/cmd_rag.go
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
package rag

import (
	"os"

	"github.com/host-uk/core/pkg/i18n"
	"github.com/spf13/cobra"
)

// Shared connection flags used by every rag subcommand.
var (
	qdrantHost string
	qdrantPort int
	ollamaHost string
	ollamaPort int
	model      string
	verbose    bool
)

// ragCmd is the parent `core rag` command; subcommands are attached in
// AddRAGCommands.
var ragCmd = &cobra.Command{
	Use:   "rag",
	Short: i18n.T("cmd.rag.short"),
	Long:  i18n.T("cmd.rag.long"),
}

// envOr returns the value of the environment variable key, or fallback
// when the variable is unset or empty.
func envOr(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// initFlags registers all flags for the rag command tree.
//
// Connection defaults honour the QDRANT_HOST, OLLAMA_HOST and
// EMBEDDING_MODEL environment variables. This must happen here rather
// than only in a package init: StringVar/IntVar assign the flag's default
// value to the bound variable at registration time, so a plain "localhost"
// default would silently overwrite any override applied earlier via
// os.Getenv (see init in cmd_ingest.go).
func initFlags() {
	// Qdrant connection flags (persistent) - defaults to localhost for local development
	ragCmd.PersistentFlags().StringVar(&qdrantHost, "qdrant-host", envOr("QDRANT_HOST", "localhost"), i18n.T("cmd.rag.flag.qdrant_host"))
	ragCmd.PersistentFlags().IntVar(&qdrantPort, "qdrant-port", 6334, i18n.T("cmd.rag.flag.qdrant_port"))

	// Ollama connection flags (persistent) - defaults to localhost for local development
	ragCmd.PersistentFlags().StringVar(&ollamaHost, "ollama-host", envOr("OLLAMA_HOST", "localhost"), i18n.T("cmd.rag.flag.ollama_host"))
	ragCmd.PersistentFlags().IntVar(&ollamaPort, "ollama-port", 11434, i18n.T("cmd.rag.flag.ollama_port"))
	ragCmd.PersistentFlags().StringVar(&model, "model", envOr("EMBEDDING_MODEL", "nomic-embed-text"), i18n.T("cmd.rag.flag.model"))

	// Verbose flag (persistent)
	ragCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, i18n.T("common.flag.verbose"))

	// Ingest command flags
	ingestCmd.Flags().StringVar(&collection, "collection", "hostuk-docs", i18n.T("cmd.rag.ingest.flag.collection"))
	ingestCmd.Flags().BoolVar(&recreate, "recreate", false, i18n.T("cmd.rag.ingest.flag.recreate"))
	ingestCmd.Flags().IntVar(&chunkSize, "chunk-size", 500, i18n.T("cmd.rag.ingest.flag.chunk_size"))
	ingestCmd.Flags().IntVar(&chunkOverlap, "chunk-overlap", 50, i18n.T("cmd.rag.ingest.flag.chunk_overlap"))

	// Query command flags
	queryCmd.Flags().StringVar(&queryCollection, "collection", "hostuk-docs", i18n.T("cmd.rag.query.flag.collection"))
	queryCmd.Flags().IntVar(&limit, "top", 5, i18n.T("cmd.rag.query.flag.top"))
	queryCmd.Flags().Float32Var(&threshold, "threshold", 0.5, i18n.T("cmd.rag.query.flag.threshold"))
	queryCmd.Flags().StringVar(&category, "category", "", i18n.T("cmd.rag.query.flag.category"))
	queryCmd.Flags().StringVar(&format, "format", "text", i18n.T("cmd.rag.query.flag.format"))

	// Collections command flags
	collectionsCmd.Flags().BoolVar(&listCollections, "list", false, i18n.T("cmd.rag.collections.flag.list"))
	collectionsCmd.Flags().BoolVar(&showStats, "stats", false, i18n.T("cmd.rag.collections.flag.stats"))
	collectionsCmd.Flags().StringVar(&deleteCollection, "delete", "", i18n.T("cmd.rag.collections.flag.delete"))
}
|
||||
|
|
@ -20,6 +20,7 @@
|
|||
// - test: Test runner with coverage
|
||||
// - qa: Quality assurance workflows
|
||||
// - monitor: Security monitoring aggregation
|
||||
// - rag: RAG (Retrieval Augmented Generation) tools
|
||||
|
||||
package variants
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ import (
|
|||
_ "github.com/host-uk/core/internal/cmd/php"
|
||||
_ "github.com/host-uk/core/internal/cmd/pkgcmd"
|
||||
_ "github.com/host-uk/core/internal/cmd/qa"
|
||||
_ "github.com/host-uk/core/internal/cmd/rag"
|
||||
_ "github.com/host-uk/core/internal/cmd/sdk"
|
||||
_ "github.com/host-uk/core/internal/cmd/security"
|
||||
_ "github.com/host-uk/core/internal/cmd/setup"
|
||||
|
|
|
|||
|
|
@ -608,6 +608,33 @@
|
|||
"no_findings": "No security findings",
|
||||
"error.no_repos": "No repositories to scan. Use --repo, --all, or run from a git repo",
|
||||
"error.not_git_repo": "Not in a git repository. Use --repo to specify one"
|
||||
},
|
||||
"rag": {
|
||||
"short": "RAG (Retrieval Augmented Generation) tools",
|
||||
"long": "RAG tools for storing documentation in Qdrant vector database and querying with semantic search. Eliminates need to repeatedly remind Claude about project specifics.",
|
||||
"flag.qdrant_host": "Qdrant server hostname",
|
||||
"flag.qdrant_port": "Qdrant gRPC port",
|
||||
"flag.ollama_host": "Ollama server hostname",
|
||||
"flag.ollama_port": "Ollama server port",
|
||||
"flag.model": "Embedding model name",
|
||||
"ingest.short": "Ingest markdown files into Qdrant",
|
||||
"ingest.long": "Ingest markdown files from a directory into Qdrant vector database. Chunks files, generates embeddings via Ollama, and stores for semantic search.",
|
||||
"ingest.flag.collection": "Qdrant collection name",
|
||||
"ingest.flag.recreate": "Delete and recreate collection",
|
||||
"ingest.flag.chunk_size": "Characters per chunk",
|
||||
"ingest.flag.chunk_overlap": "Overlap between chunks",
|
||||
"query.short": "Query the vector database",
|
||||
"query.long": "Search for similar documents using semantic similarity. Returns relevant chunks ranked by score.",
|
||||
"query.flag.collection": "Qdrant collection name",
|
||||
"query.flag.top": "Number of results to return",
|
||||
"query.flag.threshold": "Minimum similarity score (0-1)",
|
||||
"query.flag.category": "Filter by category",
|
||||
"query.flag.format": "Output format (text, json, context)",
|
||||
"collections.short": "List and manage collections",
|
||||
"collections.long": "List available collections, show statistics, or delete collections from Qdrant.",
|
||||
"collections.flag.list": "List all collections",
|
||||
"collections.flag.stats": "Show collection statistics",
|
||||
"collections.flag.delete": "Delete a collection"
|
||||
}
|
||||
},
|
||||
"common": {
|
||||
|
|
|
|||
197
pkg/rag/chunk.go
Normal file
197
pkg/rag/chunk.go
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChunkConfig holds chunking configuration.
type ChunkConfig struct {
	Size    int // Characters per chunk
	Overlap int // Overlap between chunks
}

// DefaultChunkConfig returns the default chunking configuration:
// 500-character chunks with a 50-character overlap.
func DefaultChunkConfig() ChunkConfig {
	return ChunkConfig{Size: 500, Overlap: 50}
}
|
||||
|
||||
// Chunk represents a text chunk with metadata.
type Chunk struct {
	Text    string // chunk contents
	Section string // title of the enclosing ## section, empty when none
	Index   int    // 0-based position of the chunk within the document
}
|
||||
|
||||
// ChunkMarkdown splits markdown text into chunks by sections and paragraphs.
// Preserves context with configurable overlap.
//
// Sections are delimited by "## " headers (see splitBySections). A section
// whose whole text fits in cfg.Size becomes one chunk; larger sections are
// packed paragraph-by-paragraph, carrying the trailing cfg.Overlap runes of
// each flushed chunk into the next one. Chunk.Index counts across the whole
// document, not per section.
//
// NOTE(review): the size comparisons below use len() (bytes) while the
// overlap is taken in runes, so for multi-byte UTF-8 text a "500-char"
// chunk is 500 bytes — confirm this asymmetry is intended.
func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk {
	var chunks []Chunk

	// Split by ## headers
	sections := splitBySections(text)

	chunkIndex := 0
	for _, section := range sections {
		section = strings.TrimSpace(section)
		if section == "" {
			continue
		}

		// Extract section title from the first line when it is a header.
		// Any number of leading '#' characters is accepted here.
		lines := strings.SplitN(section, "\n", 2)
		title := ""
		if strings.HasPrefix(lines[0], "#") {
			title = strings.TrimLeft(lines[0], "#")
			title = strings.TrimSpace(title)
		}

		// If section is small enough, yield as-is
		if len(section) <= cfg.Size {
			chunks = append(chunks, Chunk{
				Text:    section,
				Section: title,
				Index:   chunkIndex,
			})
			chunkIndex++
			continue
		}

		// Otherwise, chunk by paragraphs
		paragraphs := splitByParagraphs(section)
		currentChunk := ""

		for _, para := range paragraphs {
			para = strings.TrimSpace(para)
			if para == "" {
				continue
			}

			// +2 accounts for the "\n\n" separator inserted below.
			if len(currentChunk)+len(para)+2 <= cfg.Size {
				if currentChunk != "" {
					currentChunk += "\n\n" + para
				} else {
					currentChunk = para
				}
			} else {
				// Adding para would overflow: flush the accumulated chunk.
				// A single paragraph longer than cfg.Size still becomes one
				// (oversized) chunk — paragraphs are never split internally.
				if currentChunk != "" {
					chunks = append(chunks, Chunk{
						Text:    strings.TrimSpace(currentChunk),
						Section: title,
						Index:   chunkIndex,
					})
					chunkIndex++
				}
				// Start new chunk with overlap from previous (rune-safe for UTF-8)
				runes := []rune(currentChunk)
				if cfg.Overlap > 0 && len(runes) > cfg.Overlap {
					overlapText := string(runes[len(runes)-cfg.Overlap:])
					currentChunk = overlapText + "\n\n" + para
				} else {
					currentChunk = para
				}
			}
		}

		// Don't forget the last chunk
		if strings.TrimSpace(currentChunk) != "" {
			chunks = append(chunks, Chunk{
				Text:    strings.TrimSpace(currentChunk),
				Section: title,
				Index:   chunkIndex,
			})
			chunkIndex++
		}
	}

	return chunks
}
|
||||
|
||||
// splitBySections splits text into sections at "## " headers, keeping each
// header line together with the content that follows it. Text before the
// first header forms its own section; every line keeps a trailing "\n".
func splitBySections(text string) []string {
	var (
		sections []string
		current  []string
	)

	// flush closes the section accumulated so far, if any.
	flush := func() {
		if len(current) > 0 {
			sections = append(sections, strings.Join(current, "\n")+"\n")
			current = current[:0]
		}
	}

	for _, line := range strings.Split(text, "\n") {
		if strings.HasPrefix(line, "## ") {
			// A new "## " header terminates the previous section.
			flush()
		}
		current = append(current, line)
	}
	flush()

	return sections
}
|
||||
|
||||
// splitByParagraphs splits text on blank lines (double newlines).
// Runs of three or more newlines are first collapsed to exactly two so
// that extra blank-line padding does not yield empty paragraphs.
func splitByParagraphs(text string) []string {
	collapsed := text
	for {
		next := strings.ReplaceAll(collapsed, "\n\n\n", "\n\n")
		if next == collapsed {
			break // fixpoint reached: no triple newlines remain
		}
		collapsed = next
	}
	return strings.Split(collapsed, "\n\n")
}
|
||||
|
||||
// Category determines the document category from its (case-insensitive)
// file path. Rules are checked in order; the first keyword hit wins, and
// paths matching nothing fall back to "documentation".
func Category(path string) string {
	p := strings.ToLower(path)

	// Ordered rule table: earlier rows take precedence.
	rules := []struct {
		keywords []string
		category string
	}{
		{[]string{"flux", "ui/component"}, "ui-component"},
		{[]string{"brand", "mascot"}, "brand"},
		{[]string{"brief"}, "product-brief"},
		{[]string{"help", "draft"}, "help-doc"},
		{[]string{"task", "plan"}, "task"},
		{[]string{"architecture", "migration"}, "architecture"},
	}

	for _, rule := range rules {
		for _, kw := range rule.keywords {
			if strings.Contains(p, kw) {
				return rule.category
			}
		}
	}
	return "documentation"
}
|
||||
|
||||
// ChunkID generates a deterministic unique ID for a chunk, derived from the
// source path, the chunk index, and the first 100 runes of the chunk text
// (rune-safe for UTF-8). MD5 is used purely as a stable fingerprint, not
// for any security purpose.
func ChunkID(path string, index int, text string) string {
	prefix := []rune(text)
	if len(prefix) > 100 {
		prefix = prefix[:100]
	}
	sum := md5.Sum([]byte(fmt.Sprintf("%s:%d:%s", path, index, string(prefix))))
	return fmt.Sprintf("%x", sum)
}
|
||||
|
||||
// FileExtensions returns the file extensions eligible for ingestion.
func FileExtensions() []string {
	return []string{".md", ".markdown", ".txt"}
}

// ShouldProcess reports whether the file at path has an ingestible
// extension. The comparison is case-insensitive.
func ShouldProcess(path string) bool {
	return slices.Contains(FileExtensions(), strings.ToLower(filepath.Ext(path)))
}
|
||||
120
pkg/rag/chunk_test.go
Normal file
120
pkg/rag/chunk_test.go
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestChunkMarkdown_Good_SmallSection verifies that a document smaller than
// the default chunk size comes back as a single chunk containing its text.
func TestChunkMarkdown_Good_SmallSection(t *testing.T) {
	text := `# Title

This is a small section that fits in one chunk.
`
	chunks := ChunkMarkdown(text, DefaultChunkConfig())

	assert.Len(t, chunks, 1)
	assert.Contains(t, chunks[0].Text, "small section")
}

// TestChunkMarkdown_Good_MultipleSections verifies that a document with
// several ## sections produces at least one chunk per section.
func TestChunkMarkdown_Good_MultipleSections(t *testing.T) {
	text := `# Main Title

Introduction paragraph.

## Section One

Content for section one.

## Section Two

Content for section two.
`
	chunks := ChunkMarkdown(text, DefaultChunkConfig())

	assert.GreaterOrEqual(t, len(chunks), 2)
}

// TestChunkMarkdown_Good_LargeSection verifies that a section exceeding the
// configured size is split into multiple non-empty chunks, each tagged with
// the section's title.
func TestChunkMarkdown_Good_LargeSection(t *testing.T) {
	// Create a section larger than chunk size
	text := `## Large Section

` + repeatString("This is a test paragraph with some content. ", 50)

	cfg := ChunkConfig{Size: 200, Overlap: 20}
	chunks := ChunkMarkdown(text, cfg)

	assert.Greater(t, len(chunks), 1)
	for _, chunk := range chunks {
		assert.NotEmpty(t, chunk.Text)
		assert.Equal(t, "Large Section", chunk.Section)
	}
}

// TestChunkMarkdown_Good_ExtractsTitle verifies that the "## " header text
// is stripped of markers and stored as the chunk's Section.
func TestChunkMarkdown_Good_ExtractsTitle(t *testing.T) {
	text := `## My Section Title

Some content here.
`
	chunks := ChunkMarkdown(text, DefaultChunkConfig())

	assert.Len(t, chunks, 1)
	assert.Equal(t, "My Section Title", chunks[0].Section)
}
|
||||
|
||||
// TestCategory_Good_UIComponent table-tests Category's keyword rules,
// including precedence (e.g. "plans/architecture.md" maps to "task" because
// the plan/task rule is checked before the architecture rule) and the
// "documentation" fallback.
func TestCategory_Good_UIComponent(t *testing.T) {
	tests := []struct {
		path     string
		expected string
	}{
		{"docs/flux/button.md", "ui-component"},
		{"ui/components/modal.md", "ui-component"},
		{"brand/vi-personality.md", "brand"},
		{"mascot/expressions.md", "brand"},
		{"product-brief.md", "product-brief"},
		{"tasks/2024-01-15-feature.md", "task"},
		{"plans/architecture.md", "task"},
		{"architecture/migration.md", "architecture"},
		{"docs/api.md", "documentation"},
	}

	for _, tc := range tests {
		t.Run(tc.path, func(t *testing.T) {
			assert.Equal(t, tc.expected, Category(tc.path))
		})
	}
}
|
||||
|
||||
// TestChunkID_Good_Deterministic verifies that identical inputs always hash
// to the same ID.
func TestChunkID_Good_Deterministic(t *testing.T) {
	id1 := ChunkID("test.md", 0, "hello world")
	id2 := ChunkID("test.md", 0, "hello world")

	assert.Equal(t, id1, id2)
}

// TestChunkID_Good_DifferentForDifferentInputs verifies that changing the
// chunk index or the source path changes the ID.
func TestChunkID_Good_DifferentForDifferentInputs(t *testing.T) {
	id1 := ChunkID("test.md", 0, "hello world")
	id2 := ChunkID("test.md", 1, "hello world")
	id3 := ChunkID("other.md", 0, "hello world")

	assert.NotEqual(t, id1, id2)
	assert.NotEqual(t, id1, id3)
}
|
||||
|
||||
// TestShouldProcess_Good_MarkdownFiles verifies that only markdown/text
// extensions are accepted and that extension-less or code files are skipped.
func TestShouldProcess_Good_MarkdownFiles(t *testing.T) {
	assert.True(t, ShouldProcess("doc.md"))
	assert.True(t, ShouldProcess("doc.markdown"))
	assert.True(t, ShouldProcess("doc.txt"))
	assert.False(t, ShouldProcess("doc.go"))
	assert.False(t, ShouldProcess("doc.py"))
	assert.False(t, ShouldProcess("doc"))
}
|
||||
|
||||
// repeatString concatenates n copies of s (test helper, equivalent to
// strings.Repeat but kept local to avoid widening this file's imports).
//
// The buffer is pre-sized once, avoiding the quadratic cost of repeated
// string concatenation in a loop. n <= 0 yields "".
func repeatString(s string, n int) string {
	if n <= 0 {
		return ""
	}
	buf := make([]byte, 0, len(s)*n)
	for i := 0; i < n; i++ {
		buf = append(buf, s...)
	}
	return string(buf)
}
|
||||
214
pkg/rag/ingest.go
Normal file
214
pkg/rag/ingest.go
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IngestConfig holds ingestion configuration.
type IngestConfig struct {
	Directory  string      // root directory to scan for documents
	Collection string      // target Qdrant collection name
	Recreate   bool        // delete and recreate the collection first
	Verbose    bool        // print per-chunk embedding errors to stdout
	BatchSize  int         // points per Qdrant upsert request (<=0 reset to 100 by Ingest)
	Chunk      ChunkConfig // chunking parameters
}

// DefaultIngestConfig returns default ingestion configuration.
// Directory is left empty and must be supplied by the caller.
func DefaultIngestConfig() IngestConfig {
	return IngestConfig{
		Collection: "hostuk-docs",
		BatchSize:  100,
		Chunk:      DefaultChunkConfig(),
	}
}
|
||||
|
||||
// IngestStats holds statistics from ingestion.
type IngestStats struct {
	Files  int // files read and chunked
	Chunks int // chunks successfully embedded
	Errors int // per-file or per-chunk failures (read, rel-path, embed)
}

// IngestProgress is called during ingestion to report progress.
// It receives the file just processed (relative path), the cumulative
// chunk count so far, and the total number of files discovered.
type IngestProgress func(file string, chunks int, total int)
|
||||
|
||||
// Ingest processes a directory of documents and stores them in Qdrant.
//
// Pipeline: validate the directory, ensure the collection exists (optionally
// recreating it), walk for ingestible files (ShouldProcess), chunk each
// file, embed every chunk via Ollama, then upsert all points in batches of
// cfg.BatchSize. Per-file and per-chunk failures are counted in
// stats.Errors and skipped; only setup and upsert failures abort the run.
//
// NOTE(review): every embedded point is buffered in memory before the first
// upsert, so peak memory grows with corpus size — confirm that is
// acceptable for the intended corpora.
//
// progress (may be nil) is invoked once per file with the cumulative chunk
// count. On an upsert error the partially filled stats are returned
// alongside the error.
func Ingest(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, cfg IngestConfig, progress IngestProgress) (*IngestStats, error) {
	stats := &IngestStats{}

	// Validate batch size to prevent infinite loop in the batching code below.
	if cfg.BatchSize <= 0 {
		cfg.BatchSize = 100 // Safe default
	}

	// Resolve directory
	absDir, err := filepath.Abs(cfg.Directory)
	if err != nil {
		return nil, fmt.Errorf("error resolving directory: %w", err)
	}

	info, err := os.Stat(absDir)
	if err != nil {
		return nil, fmt.Errorf("error accessing directory: %w", err)
	}
	if !info.IsDir() {
		return nil, fmt.Errorf("not a directory: %s", absDir)
	}

	// Check/create collection
	exists, err := qdrant.CollectionExists(ctx, cfg.Collection)
	if err != nil {
		return nil, fmt.Errorf("error checking collection: %w", err)
	}

	if cfg.Recreate && exists {
		if err := qdrant.DeleteCollection(ctx, cfg.Collection); err != nil {
			return nil, fmt.Errorf("error deleting collection: %w", err)
		}
		exists = false
	}

	if !exists {
		// Vector dimension must match the embedding model's output size.
		vectorDim := ollama.EmbedDimension()
		if err := qdrant.CreateCollection(ctx, cfg.Collection, vectorDim); err != nil {
			return nil, fmt.Errorf("error creating collection: %w", err)
		}
	}

	// Find markdown files
	var files []string
	err = filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if !d.IsDir() && ShouldProcess(path) {
			files = append(files, path)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("error walking directory: %w", err)
	}

	if len(files) == 0 {
		return nil, fmt.Errorf("no markdown files found in %s", absDir)
	}

	// Process files
	var points []Point
	for _, filePath := range files {
		// Relative paths keep chunk IDs and payloads stable across machines.
		relPath, err := filepath.Rel(absDir, filePath)
		if err != nil {
			stats.Errors++
			continue
		}

		content, err := os.ReadFile(filePath)
		if err != nil {
			stats.Errors++
			continue
		}

		// Empty/whitespace-only files are skipped without counting as errors.
		if len(strings.TrimSpace(string(content))) == 0 {
			continue
		}

		// Chunk the content
		category := Category(relPath)
		chunks := ChunkMarkdown(string(content), cfg.Chunk)

		for _, chunk := range chunks {
			// Generate embedding
			embedding, err := ollama.Embed(ctx, chunk.Text)
			if err != nil {
				stats.Errors++
				if cfg.Verbose {
					fmt.Printf("  Error embedding %s chunk %d: %v\n", relPath, chunk.Index, err)
				}
				continue
			}

			// Create point
			points = append(points, Point{
				ID:     ChunkID(relPath, chunk.Index, chunk.Text),
				Vector: embedding,
				Payload: map[string]any{
					"text":        chunk.Text,
					"source":      relPath,
					"section":     chunk.Section,
					"category":    category,
					"chunk_index": chunk.Index,
				},
			})
			stats.Chunks++
		}

		// A file counts even if some (or all) of its chunks failed to embed.
		stats.Files++
		if progress != nil {
			progress(relPath, stats.Chunks, len(files))
		}
	}

	// Batch upsert to Qdrant
	if len(points) > 0 {
		for i := 0; i < len(points); i += cfg.BatchSize {
			end := i + cfg.BatchSize
			if end > len(points) {
				end = len(points)
			}
			batch := points[i:end]
			if err := qdrant.UpsertPoints(ctx, cfg.Collection, batch); err != nil {
				return stats, fmt.Errorf("error upserting batch %d: %w", i/cfg.BatchSize+1, err)
			}
		}
	}

	return stats, nil
}
|
||||
|
||||
// IngestFile processes a single file and stores it in Qdrant. It returns
// the number of points upserted; an empty or whitespace-only file yields
// (0, nil) without touching Qdrant.
//
// Unlike Ingest, the caller-supplied filePath is used verbatim for both the
// chunk IDs and the "source" payload field, so an absolute path here
// produces different IDs than a directory-level Ingest of the same file —
// NOTE(review): confirm callers account for that.
//
// The collection is assumed to exist already; this function does not
// create it.
func IngestFile(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, collection string, filePath string, chunkCfg ChunkConfig) (int, error) {
	content, err := os.ReadFile(filePath)
	if err != nil {
		return 0, fmt.Errorf("error reading file: %w", err)
	}

	if len(strings.TrimSpace(string(content))) == 0 {
		return 0, nil
	}

	category := Category(filePath)
	chunks := ChunkMarkdown(string(content), chunkCfg)

	var points []Point
	for _, chunk := range chunks {
		// Any embedding failure aborts the whole file (no partial upsert).
		embedding, err := ollama.Embed(ctx, chunk.Text)
		if err != nil {
			return 0, fmt.Errorf("error embedding chunk %d: %w", chunk.Index, err)
		}

		points = append(points, Point{
			ID:     ChunkID(filePath, chunk.Index, chunk.Text),
			Vector: embedding,
			Payload: map[string]any{
				"text":        chunk.Text,
				"source":      filePath,
				"section":     chunk.Section,
				"category":    category,
				"chunk_index": chunk.Index,
			},
		})
	}

	if err := qdrant.UpsertPoints(ctx, collection, points); err != nil {
		return 0, fmt.Errorf("error upserting points: %w", err)
	}

	return len(points), nil
}
|
||||
116
pkg/rag/ollama.go
Normal file
116
pkg/rag/ollama.go
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// OllamaConfig holds Ollama connection configuration.
type OllamaConfig struct {
	Host  string
	Port  int
	Model string
}

// DefaultOllamaConfig returns the configuration used when nothing is
// overridden: a local Ollama instance on the standard port (11434) with
// the nomic-embed-text embedding model.
func DefaultOllamaConfig() OllamaConfig {
	var cfg OllamaConfig
	cfg.Host = "localhost"
	cfg.Port = 11434
	cfg.Model = "nomic-embed-text"
	return cfg
}
|
||||
|
||||
// OllamaClient wraps the Ollama API client for embeddings.
type OllamaClient struct {
	client *api.Client
	config OllamaConfig
}

// NewOllamaClient creates a new Ollama client that talks plain HTTP to
// cfg.Host:cfg.Port using the process-wide http.DefaultClient.
//
// No connectivity check is performed here, so the error return is always
// nil today; it is kept so future dial/validation failures can surface
// without changing the signature. Use VerifyModel to probe the server.
func NewOllamaClient(cfg OllamaConfig) (*OllamaClient, error) {
	baseURL := &url.URL{
		Scheme: "http",
		Host:   fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
	}

	client := api.NewClient(baseURL, http.DefaultClient)

	return &OllamaClient{
		client: client,
		config: cfg,
	}, nil
}
|
||||
|
||||
// EmbedDimension returns the embedding dimension for the configured model.
|
||||
// nomic-embed-text uses 768 dimensions.
|
||||
func (o *OllamaClient) EmbedDimension() uint64 {
|
||||
switch o.config.Model {
|
||||
case "nomic-embed-text":
|
||||
return 768
|
||||
case "mxbai-embed-large":
|
||||
return 1024
|
||||
case "all-minilm":
|
||||
return 384
|
||||
default:
|
||||
return 768 // Default to nomic-embed-text dimension
|
||||
}
|
||||
}
|
||||
|
||||
// Embed generates an embedding vector for the given text using the
// configured model. It returns an error if the Ollama call fails or the
// response carries no embedding data.
func (o *OllamaClient) Embed(ctx context.Context, text string) ([]float32, error) {
	req := &api.EmbedRequest{
		Model: o.config.Model,
		Input: text,
	}

	resp, err := o.client.Embed(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("failed to generate embedding: %w", err)
	}

	if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
		return nil, fmt.Errorf("empty embedding response")
	}

	// Convert float64 to float32 for Qdrant
	// NOTE(review): if this ollama/api version already returns float32 the
	// loop is a plain element-wise copy — harmless, but confirm against the
	// pinned library version.
	embedding := resp.Embeddings[0]
	result := make([]float32, len(embedding))
	for i, v := range embedding {
		result[i] = float32(v)
	}

	return result, nil
}
|
||||
|
||||
// EmbedBatch generates embeddings for multiple texts.
|
||||
func (o *OllamaClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
|
||||
results := make([][]float32, len(texts))
|
||||
for i, text := range texts {
|
||||
embedding, err := o.Embed(ctx, text)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to embed text %d: %w", i, err)
|
||||
}
|
||||
results[i] = embedding
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// VerifyModel checks if the embedding model is available by issuing a
// minimal embedding request. On failure the returned error embeds the
// `ollama pull` command the user needs to run to fetch the model.
func (o *OllamaClient) VerifyModel(ctx context.Context) error {
	_, err := o.Embed(ctx, "test")
	if err != nil {
		return fmt.Errorf("model %s not available: %w (run: ollama pull %s)", o.config.Model, err, o.config.Model)
	}
	return nil
}

// Model returns the configured embedding model name.
func (o *OllamaClient) Model() string {
	return o.config.Model
}
|
||||
224
pkg/rag/qdrant.go
Normal file
224
pkg/rag/qdrant.go
Normal file
|
|
@ -0,0 +1,224 @@
|
|||
// Package rag provides RAG (Retrieval Augmented Generation) functionality
|
||||
// for storing and querying documentation in Qdrant vector database.
|
||||
package rag
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
)
|
||||
|
||||
// QdrantConfig holds Qdrant connection configuration.
type QdrantConfig struct {
	Host   string
	Port   int
	APIKey string
	UseTLS bool
}

// DefaultQdrantConfig returns settings for a local, unauthenticated Qdrant
// instance on the standard gRPC port (6334; the REST port 6333 is not used
// by this client).
func DefaultQdrantConfig() QdrantConfig {
	var cfg QdrantConfig
	cfg.Host = "localhost"
	cfg.Port = 6334 // gRPC port
	// APIKey and UseTLS keep their zero values: no auth, plaintext.
	return cfg
}
|
||||
|
||||
// QdrantClient wraps the Qdrant Go client with convenience methods.
type QdrantClient struct {
	client *qdrant.Client
	config QdrantConfig
}

// NewQdrantClient creates a new Qdrant gRPC client from cfg.
//
// A construction error is wrapped with the target address; use HealthCheck
// to confirm the server is actually reachable.
func NewQdrantClient(cfg QdrantConfig) (*QdrantClient, error) {
	// addr is only used to give error messages useful context.
	addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)

	client, err := qdrant.NewClient(&qdrant.Config{
		Host:   cfg.Host,
		Port:   cfg.Port,
		APIKey: cfg.APIKey,
		UseTLS: cfg.UseTLS,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to Qdrant at %s: %w", addr, err)
	}

	return &QdrantClient{
		client: client,
		config: cfg,
	}, nil
}
||||
|
||||
// Close closes the Qdrant client connection. Call it when the client is no
// longer needed to release the underlying gRPC resources.
func (q *QdrantClient) Close() error {
	return q.client.Close()
}

// HealthCheck verifies the connection to Qdrant by issuing the server's
// health-check RPC; a nil error means the server responded.
func (q *QdrantClient) HealthCheck(ctx context.Context) error {
	_, err := q.client.HealthCheck(ctx)
	return err
}
|
||||
|
||||
// ListCollections returns all collection names.
//
// The slice is a copy of the client library's response, so callers may
// retain or mutate it freely.
func (q *QdrantClient) ListCollections(ctx context.Context) ([]string, error) {
	resp, err := q.client.ListCollections(ctx)
	if err != nil {
		return nil, err
	}
	names := make([]string, len(resp))
	copy(names, resp)
	return names, nil
}
|
||||
|
||||
// CollectionExists checks if a collection exists.
func (q *QdrantClient) CollectionExists(ctx context.Context, name string) (bool, error) {
	return q.client.CollectionExists(ctx, name)
}

// CreateCollection creates a new collection holding vectors of the given
// size, compared with cosine distance (the usual choice for normalized
// text embeddings). vectorSize must match the embedding model's output
// dimension (see OllamaClient.EmbedDimension).
func (q *QdrantClient) CreateCollection(ctx context.Context, name string, vectorSize uint64) error {
	return q.client.CreateCollection(ctx, &qdrant.CreateCollection{
		CollectionName: name,
		VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
			Size:     vectorSize,
			Distance: qdrant.Distance_Cosine,
		}),
	})
}
|
||||
|
||||
// DeleteCollection deletes a collection and all points stored in it.
func (q *QdrantClient) DeleteCollection(ctx context.Context, name string) error {
	return q.client.DeleteCollection(ctx, name)
}

// CollectionInfo returns server-side information about a collection
// (point counts, configuration, status) as reported by Qdrant.
func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*qdrant.CollectionInfo, error) {
	return q.client.GetCollectionInfo(ctx, name)
}
|
||||
|
||||
// Point represents a vector point with payload.
type Point struct {
	ID      string         // point ID; passed to Qdrant via qdrant.NewID (UUID form)
	Vector  []float32      // embedding vector
	Payload map[string]any // arbitrary metadata stored alongside the vector
}
|
||||
|
||||
// UpsertPoints inserts or updates points in a collection.
|
||||
func (q *QdrantClient) UpsertPoints(ctx context.Context, collection string, points []Point) error {
|
||||
if len(points) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
qdrantPoints := make([]*qdrant.PointStruct, len(points))
|
||||
for i, p := range points {
|
||||
qdrantPoints[i] = &qdrant.PointStruct{
|
||||
Id: qdrant.NewID(p.ID),
|
||||
Vectors: qdrant.NewVectors(p.Vector...),
|
||||
Payload: qdrant.NewValueMap(p.Payload),
|
||||
}
|
||||
}
|
||||
|
||||
_, err := q.client.Upsert(ctx, &qdrant.UpsertPoints{
|
||||
CollectionName: collection,
|
||||
Points: qdrantPoints,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// SearchResult represents a search result with score.
type SearchResult struct {
	ID      string         // point ID, stringified (numeric or UUID)
	Score   float32        // similarity score reported by Qdrant
	Payload map[string]any // payload converted to plain Go values
}
|
||||
|
||||
// Search performs a vector similarity search.
//
// limit caps the number of returned points. filter, when non-empty, becomes
// a conjunction ("must" clause) of exact-match conditions on payload
// fields. Payload values in the results are converted from Qdrant's
// protobuf representation to plain Go values via valueToGo.
func (q *QdrantClient) Search(ctx context.Context, collection string, vector []float32, limit uint64, filter map[string]string) ([]SearchResult, error) {
	query := &qdrant.QueryPoints{
		CollectionName: collection,
		Query:          qdrant.NewQuery(vector...),
		Limit:          qdrant.PtrOf(limit),
		WithPayload:    qdrant.NewWithPayload(true),
	}

	// Add filter if provided
	if len(filter) > 0 {
		conditions := make([]*qdrant.Condition, 0, len(filter))
		for k, v := range filter {
			conditions = append(conditions, qdrant.NewMatch(k, v))
		}
		query.Filter = &qdrant.Filter{
			Must: conditions,
		}
	}

	resp, err := q.client.Query(ctx, query)
	if err != nil {
		return nil, err
	}

	results := make([]SearchResult, len(resp))
	for i, p := range resp {
		payload := make(map[string]any)
		for k, v := range p.Payload {
			payload[k] = valueToGo(v)
		}
		results[i] = SearchResult{
			ID:      pointIDToString(p.Id),
			Score:   p.Score,
			Payload: payload,
		}
	}
	return results, nil
}
|
||||
|
||||
// pointIDToString converts a Qdrant point ID to string.
|
||||
func pointIDToString(id *qdrant.PointId) string {
|
||||
if id == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := id.PointIdOptions.(type) {
|
||||
case *qdrant.PointId_Num:
|
||||
return fmt.Sprintf("%d", v.Num)
|
||||
case *qdrant.PointId_Uuid:
|
||||
return v.Uuid
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// valueToGo converts a Qdrant protobuf value to a plain Go value,
// recursing into lists and structs. Unknown kinds (including null) and nil
// inputs map to nil.
func valueToGo(v *qdrant.Value) any {
	if v == nil {
		return nil
	}
	switch val := v.Kind.(type) {
	case *qdrant.Value_StringValue:
		return val.StringValue
	case *qdrant.Value_IntegerValue:
		return val.IntegerValue // int64
	case *qdrant.Value_DoubleValue:
		return val.DoubleValue // float64
	case *qdrant.Value_BoolValue:
		return val.BoolValue
	case *qdrant.Value_ListValue:
		// Recursively convert list elements.
		list := make([]any, len(val.ListValue.Values))
		for i, item := range val.ListValue.Values {
			list[i] = valueToGo(item)
		}
		return list
	case *qdrant.Value_StructValue:
		// Recursively convert nested struct fields.
		m := make(map[string]any)
		for k, item := range val.StructValue.Fields {
			m[k] = valueToGo(item)
		}
		return m
	default:
		return nil
	}
}
|
||||
161
pkg/rag/query.go
Normal file
161
pkg/rag/query.go
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
package rag
|
||||
|
||||
import (
	"context"
	"encoding/json"
	"fmt"
	"html"
	"strings"
)
|
||||
|
||||
// QueryConfig holds query configuration.
type QueryConfig struct {
	Collection string
	Limit      uint64
	Threshold  float32 // Minimum similarity score (0-1)
	Category   string  // Filter by category
}

// DefaultQueryConfig returns the standard query settings: the hostuk-docs
// collection, top 5 results, a 0.5 score floor, and no category filter.
func DefaultQueryConfig() QueryConfig {
	var cfg QueryConfig
	cfg.Collection = "hostuk-docs"
	cfg.Limit = 5
	cfg.Threshold = 0.5
	return cfg
}
|
||||
|
||||
// QueryResult represents a query result with metadata.
// Fields mirror the payload written at ingestion time (see Ingest).
type QueryResult struct {
	Text       string  // chunk text
	Source     string  // source file path recorded at ingestion
	Section    string  // markdown section title ("" if none)
	Category   string  // category assigned by Category()
	ChunkIndex int     // chunk position within the source document
	Score      float32 // similarity score from Qdrant
}
|
||||
|
||||
// Query searches for similar documents in Qdrant.
//
// The query string is embedded with Ollama, searched against cfg.Collection
// (optionally restricted to cfg.Category), and results scoring below
// cfg.Threshold are dropped. Payload fields with unexpected types are
// silently left at their zero values.
func Query(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, query string, cfg QueryConfig) ([]QueryResult, error) {
	// Generate embedding for query
	embedding, err := ollama.Embed(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("error generating query embedding: %w", err)
	}

	// Build filter
	var filter map[string]string
	if cfg.Category != "" {
		filter = map[string]string{"category": cfg.Category}
	}

	// Search Qdrant
	results, err := qdrant.Search(ctx, cfg.Collection, embedding, cfg.Limit, filter)
	if err != nil {
		return nil, fmt.Errorf("error searching: %w", err)
	}

	// Convert and filter by threshold
	var queryResults []QueryResult
	for _, r := range results {
		if r.Score < cfg.Threshold {
			continue
		}

		qr := QueryResult{
			Score: r.Score,
		}

		// Extract payload fields; type assertions guard against schema drift.
		if text, ok := r.Payload["text"].(string); ok {
			qr.Text = text
		}
		if source, ok := r.Payload["source"].(string); ok {
			qr.Source = source
		}
		if section, ok := r.Payload["section"].(string); ok {
			qr.Section = section
		}
		if category, ok := r.Payload["category"].(string); ok {
			qr.Category = category
		}
		// Handle chunk_index from various types (JSON unmarshaling produces float64)
		switch idx := r.Payload["chunk_index"].(type) {
		case int64:
			qr.ChunkIndex = int(idx)
		case float64:
			qr.ChunkIndex = int(idx)
		case int:
			qr.ChunkIndex = idx
		}

		queryResults = append(queryResults, qr)
	}

	return queryResults, nil
}
|
||||
|
||||
// FormatResultsText formats query results as plain text.
|
||||
func FormatResultsText(results []QueryResult) string {
|
||||
if len(results) == 0 {
|
||||
return "No results found."
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
for i, r := range results {
|
||||
sb.WriteString(fmt.Sprintf("\n--- Result %d (score: %.2f) ---\n", i+1, r.Score))
|
||||
sb.WriteString(fmt.Sprintf("Source: %s\n", r.Source))
|
||||
if r.Section != "" {
|
||||
sb.WriteString(fmt.Sprintf("Section: %s\n", r.Section))
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("Category: %s\n\n", r.Category))
|
||||
sb.WriteString(r.Text)
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// FormatResultsContext formats query results for LLM context injection,
// wrapping each chunk in a <document> element inside <retrieved_context>.
// Returns "" when there are no results so callers can inject nothing.
func FormatResultsContext(results []QueryResult) string {
	if len(results) == 0 {
		return ""
	}

	var sb strings.Builder
	sb.WriteString("<retrieved_context>\n")
	for _, r := range results {
		// Escape XML special characters to prevent malformed output
		// (html.EscapeString also escapes quotes, keeping attributes safe).
		fmt.Fprintf(&sb, "<document source=\"%s\" section=\"%s\" category=\"%s\">\n",
			html.EscapeString(r.Source),
			html.EscapeString(r.Section),
			html.EscapeString(r.Category))
		sb.WriteString(html.EscapeString(r.Text))
		sb.WriteString("\n</document>\n\n")
	}
	sb.WriteString("</retrieved_context>")
	return sb.String()
}
|
||||
|
||||
// FormatResultsJSON formats query results as JSON-like output.
|
||||
func FormatResultsJSON(results []QueryResult) string {
|
||||
if len(results) == 0 {
|
||||
return "[]"
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[\n")
|
||||
for i, r := range results {
|
||||
sb.WriteString(" {\n")
|
||||
sb.WriteString(fmt.Sprintf(" \"source\": %q,\n", r.Source))
|
||||
sb.WriteString(fmt.Sprintf(" \"section\": %q,\n", r.Section))
|
||||
sb.WriteString(fmt.Sprintf(" \"category\": %q,\n", r.Category))
|
||||
sb.WriteString(fmt.Sprintf(" \"score\": %.4f,\n", r.Score))
|
||||
sb.WriteString(fmt.Sprintf(" \"text\": %q\n", r.Text))
|
||||
if i < len(results)-1 {
|
||||
sb.WriteString(" },\n")
|
||||
} else {
|
||||
sb.WriteString(" }\n")
|
||||
}
|
||||
}
|
||||
sb.WriteString("]")
|
||||
return sb.String()
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue