diff --git a/go.mod b/go.mod index b8665c65..82e47f5a 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,8 @@ require ( github.com/minio/selfupdate v0.6.0 github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/oasdiff/oasdiff v1.11.8 + github.com/ollama/ollama v0.15.4 + github.com/qdrant/go-client v1.16.2 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 golang.org/x/mod v0.32.0 @@ -27,6 +29,8 @@ require ( github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.3.0 // indirect github.com/TwiN/go-color v1.4.1 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cloudflare/circl v1.6.3 // indirect github.com/cyphar/filepath-securejoin v0.6.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect @@ -38,6 +42,7 @@ require ( github.com/go-openapi/swag/jsonname v0.25.4 // indirect github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect github.com/google/jsonschema-go v0.4.2 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/josharian/intern v1.0.0 // indirect @@ -60,6 +65,7 @@ require ( github.com/ugorji/go/codec v1.3.0 // indirect github.com/ulikunitz/xz v0.5.15 // indirect github.com/wI2L/jsondiff v0.7.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/woodsbury/decimal128 v1.4.0 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/yargevad/filepathx v1.0.0 // indirect @@ -67,5 +73,8 @@ require ( golang.org/x/crypto v0.47.0 // indirect golang.org/x/exp v0.0.0-20260112195511-716be5621a96 // indirect golang.org/x/sys v0.40.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect + google.golang.org/grpc v1.76.0 // indirect + google.golang.org/protobuf v1.36.10 // 
indirect gopkg.in/warnings.v0 v0.1.2 // indirect ) diff --git a/go.sum b/go.sum index 3fb17e72..8d2c3c67 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,10 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= @@ -43,6 +47,10 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= github.com/go-git/go-git/v5 v5.16.4 h1:7ajIEZHZJULcyJebDLo99bGgS0jRrOxzZG4uCk2Yb2Y= github.com/go-git/go-git/v5 v5.16.4/go.mod h1:4Ge4alE/5gPs30F2H1esi2gPd69R0C39lolkucHBOp8= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-openapi/jsonpointer v0.22.4 
h1:dZtK82WlNpVLDW2jlA1YCiVJFVqkED1MegOUy9kR5T4= github.com/go-openapi/jsonpointer v0.22.4/go.mod h1:elX9+UgznpFhgBuaMQ7iu4lvvX1nvNsesQ3oxmYTw80= github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= @@ -55,10 +63,14 @@ github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeD github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8= github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= @@ -100,6 +112,8 @@ github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037 h1:G7ERwszslrBzRxj//J github.com/oasdiff/yaml v0.0.0-20250309154309-f31be36b4037/go.mod h1:2bpvgLBZEtENV5scfDFEtB/5+1M4hkQhDQrccEJ/qGw= github.com/oasdiff/yaml3 v0.0.0-20250309153720-d2182401db90 h1:bQx3WeLcUWy+RletIKwUIt4x3t8n2SxavmoclizMb8c= github.com/oasdiff/yaml3 
v0.0.0-20250309153720-d2182401db90/go.mod h1:y5+oSEHCPT/DGrS++Wc/479ERge0zTFxaF8PbGKcg2o= +github.com/ollama/ollama v0.15.4 h1:y841GH5lsi5j5BTFyX/E+UOC3Yiw+JBfdjBVRGw+I0M= +github.com/ollama/ollama v0.15.4/go.mod h1:4Yn3jw2hZ4VqyJ1XciYawDRE8bzv4RT3JiVZR1kCfwE= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX50IvK2s= @@ -111,6 +125,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/qdrant/go-client v1.16.2 h1:UUMJJfvXTByhwhH1DwWdbkhZ2cTdvSqVkXSIfBrVWSg= +github.com/qdrant/go-client v1.16.2/go.mod h1:I+EL3h4HRoRTeHtbfOd/4kDXwCukZfkd41j/9wryGkw= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -146,6 +162,8 @@ github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/wI2L/jsondiff v0.7.0 h1:1lH1G37GhBPqCfp/lrs91rf/2j3DktX6qYAKZkLuCQQ= github.com/wI2L/jsondiff v0.7.0/go.mod h1:KAEIojdQq66oJiHhDyQez2x+sRit0vIzC9KeK0yizxM= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/woodsbury/decimal128 v1.4.0 
h1:xJATj7lLu4f2oObouMt2tgGiElE5gO6mSWUjQsBgUlc= github.com/woodsbury/decimal128 v1.4.0/go.mod h1:BP46FUrVjVhdTbKT+XuQh2xfQaGki9LMIRJSFuh6THU= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= @@ -154,6 +172,18 @@ github.com/yargevad/filepathx v1.0.0 h1:SYcT+N3tYGi+NvazubCNlvgIPbzAk7i7y2dwg3I5 github.com/yargevad/filepathx v1.0.0/go.mod h1:BprfX/gpYNJHJfc35GjRRpVcwWXS89gGulUIU5tK3tA= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= @@ 
-192,6 +222,14 @@ golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba h1:UKgtfRM7Yh93Sya0Fo8ZzhDP4qBckrrxEr2oF5UIVb8= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= +google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/internal/cmd/rag/cmd_collections.go b/internal/cmd/rag/cmd_collections.go new file mode 100644 index 00000000..b21d45c4 --- /dev/null +++ b/internal/cmd/rag/cmd_collections.go @@ -0,0 +1,86 @@ +package rag + +import ( + "context" + "fmt" + + "github.com/host-uk/core/pkg/cli" + "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/rag" + "github.com/spf13/cobra" +) + +var ( + listCollections bool + showStats bool + deleteCollection string +) + +var collectionsCmd = &cobra.Command{ + Use: "collections", + Short: 
i18n.T("cmd.rag.collections.short"), + Long: i18n.T("cmd.rag.collections.long"), + RunE: runCollections, +} + +func runCollections(cmd *cobra.Command, args []string) error { + ctx := context.Background() + + // Connect to Qdrant + qdrantClient, err := rag.NewQdrantClient(rag.QdrantConfig{ + Host: qdrantHost, + Port: qdrantPort, + UseTLS: false, + }) + if err != nil { + return fmt.Errorf("failed to connect to Qdrant: %w", err) + } + defer qdrantClient.Close() + + // Handle delete + if deleteCollection != "" { + exists, err := qdrantClient.CollectionExists(ctx, deleteCollection) + if err != nil { + return err + } + if !exists { + return fmt.Errorf("collection not found: %s", deleteCollection) + } + if err := qdrantClient.DeleteCollection(ctx, deleteCollection); err != nil { + return err + } + fmt.Printf("Deleted collection: %s\n", deleteCollection) + return nil + } + + // List collections + collections, err := qdrantClient.ListCollections(ctx) + if err != nil { + return err + } + + if len(collections) == 0 { + fmt.Println("No collections found.") + return nil + } + + fmt.Printf("%s\n\n", cli.TitleStyle.Render("Collections")) + + for _, name := range collections { + if showStats { + info, err := qdrantClient.CollectionInfo(ctx, name) + if err != nil { + fmt.Printf(" %s (error: %v)\n", name, err) + continue + } + fmt.Printf(" %s\n", cli.ValueStyle.Render(name)) + fmt.Printf(" Points: %d\n", info.PointsCount) + fmt.Printf(" Status: %s\n", info.Status.String()) + fmt.Println() + } else { + fmt.Printf(" %s\n", name) + } + } + + return nil +} diff --git a/internal/cmd/rag/cmd_commands.go b/internal/cmd/rag/cmd_commands.go new file mode 100644 index 00000000..b32bfc66 --- /dev/null +++ b/internal/cmd/rag/cmd_commands.go @@ -0,0 +1,25 @@ +// Package rag provides RAG (Retrieval Augmented Generation) commands. 
+// +// Commands: +// - core rag ingest: Ingest markdown files into Qdrant +// - core rag query: Query the vector database +// - core rag collections: List and manage collections +package rag + +import ( + "github.com/host-uk/core/pkg/cli" + "github.com/spf13/cobra" +) + +func init() { + cli.RegisterCommands(AddRAGCommands) +} + +// AddRAGCommands registers the 'rag' command and all subcommands. +func AddRAGCommands(root *cobra.Command) { + initFlags() + ragCmd.AddCommand(ingestCmd) + ragCmd.AddCommand(queryCmd) + ragCmd.AddCommand(collectionsCmd) + root.AddCommand(ragCmd) +} diff --git a/internal/cmd/rag/cmd_ingest.go b/internal/cmd/rag/cmd_ingest.go new file mode 100644 index 00000000..077a9310 --- /dev/null +++ b/internal/cmd/rag/cmd_ingest.go @@ -0,0 +1,178 @@ +package rag + +import ( + "context" + "fmt" + "os" + + "github.com/host-uk/core/pkg/cli" + "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/rag" + "github.com/spf13/cobra" +) + +var ( + collection string + recreate bool + chunkSize int + chunkOverlap int +) + +var ingestCmd = &cobra.Command{ + Use: "ingest [directory]", + Short: i18n.T("cmd.rag.ingest.short"), + Long: i18n.T("cmd.rag.ingest.long"), + Args: cobra.MaximumNArgs(1), + RunE: runIngest, +} + +func runIngest(cmd *cobra.Command, args []string) error { + directory := "." 
+ if len(args) > 0 { + directory = args[0] + } + + ctx := context.Background() + + // Connect to Qdrant + fmt.Printf("Connecting to Qdrant at %s:%d...\n", qdrantHost, qdrantPort) + qdrantClient, err := rag.NewQdrantClient(rag.QdrantConfig{ + Host: qdrantHost, + Port: qdrantPort, + UseTLS: false, + }) + if err != nil { + return fmt.Errorf("failed to connect to Qdrant: %w", err) + } + defer qdrantClient.Close() + + if err := qdrantClient.HealthCheck(ctx); err != nil { + return fmt.Errorf("Qdrant health check failed: %w", err) + } + + // Connect to Ollama + fmt.Printf("Using embedding model: %s (via %s:%d)\n", model, ollamaHost, ollamaPort) + ollamaClient, err := rag.NewOllamaClient(rag.OllamaConfig{ + Host: ollamaHost, + Port: ollamaPort, + Model: model, + }) + if err != nil { + return fmt.Errorf("failed to connect to Ollama: %w", err) + } + + if err := ollamaClient.VerifyModel(ctx); err != nil { + return err + } + + // Configure ingestion + cfg := rag.IngestConfig{ + Directory: directory, + Collection: collection, + Recreate: recreate, + Verbose: verbose, + BatchSize: 100, + Chunk: rag.ChunkConfig{ + Size: chunkSize, + Overlap: chunkOverlap, + }, + } + + // Progress callback + progress := func(file string, chunks int, total int) { + if verbose { + fmt.Printf(" Processed: %s (%d chunks total)\n", file, chunks) + } else { + fmt.Printf("\r %s (%d chunks) ", cli.DimStyle.Render(file), chunks) + } + } + + // Run ingestion + fmt.Printf("\nIngesting from: %s\n", directory) + if recreate { + fmt.Printf(" (recreating collection: %s)\n", collection) + } + + stats, err := rag.Ingest(ctx, qdrantClient, ollamaClient, cfg, progress) + if err != nil { + return err + } + + // Summary + fmt.Printf("\n\n%s\n", cli.TitleStyle.Render("Ingestion complete!")) + fmt.Printf(" Files processed: %d\n", stats.Files) + fmt.Printf(" Chunks created: %d\n", stats.Chunks) + if stats.Errors > 0 { + fmt.Printf(" Errors: %s\n", cli.ErrorStyle.Render(fmt.Sprintf("%d", stats.Errors))) + } + fmt.Printf(" 
Collection: %s\n", collection) + + return nil +} + +// IngestDirectory is exported for use by other packages (e.g., MCP). +func IngestDirectory(ctx context.Context, directory, collectionName string, recreateCollection bool) error { + qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig()) + if err != nil { + return err + } + defer qdrantClient.Close() + + if err := qdrantClient.HealthCheck(ctx); err != nil { + return fmt.Errorf("Qdrant health check failed: %w", err) + } + + ollamaClient, err := rag.NewOllamaClient(rag.DefaultOllamaConfig()) + if err != nil { + return err + } + + if err := ollamaClient.VerifyModel(ctx); err != nil { + return err + } + + cfg := rag.DefaultIngestConfig() + cfg.Directory = directory + cfg.Collection = collectionName + cfg.Recreate = recreateCollection + + _, err = rag.Ingest(ctx, qdrantClient, ollamaClient, cfg, nil) + return err +} + +// IngestFile is exported for use by other packages (e.g., MCP). +func IngestFile(ctx context.Context, filePath, collectionName string) (int, error) { + qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig()) + if err != nil { + return 0, err + } + defer qdrantClient.Close() + + if err := qdrantClient.HealthCheck(ctx); err != nil { + return 0, fmt.Errorf("Qdrant health check failed: %w", err) + } + + ollamaClient, err := rag.NewOllamaClient(rag.DefaultOllamaConfig()) + if err != nil { + return 0, err + } + + if err := ollamaClient.VerifyModel(ctx); err != nil { + return 0, err + } + + return rag.IngestFile(ctx, qdrantClient, ollamaClient, collectionName, filePath, rag.DefaultChunkConfig()) +} + +func init() { + // Check for environment variable overrides + if host := os.Getenv("QDRANT_HOST"); host != "" { + qdrantHost = host + } + if host := os.Getenv("OLLAMA_HOST"); host != "" { + ollamaHost = host + } + if m := os.Getenv("EMBEDDING_MODEL"); m != "" { + model = m + } +} diff --git a/internal/cmd/rag/cmd_query.go b/internal/cmd/rag/cmd_query.go new file mode 100644 index 
00000000..69b2b9aa --- /dev/null +++ b/internal/cmd/rag/cmd_query.go @@ -0,0 +1,107 @@ +package rag + +import ( + "context" + "fmt" + + "github.com/host-uk/core/pkg/i18n" + "github.com/host-uk/core/pkg/rag" + "github.com/spf13/cobra" +) + +var ( + queryCollection string + limit int + threshold float32 + category string + format string +) + +var queryCmd = &cobra.Command{ + Use: "query [question]", + Short: i18n.T("cmd.rag.query.short"), + Long: i18n.T("cmd.rag.query.long"), + Args: cobra.ExactArgs(1), + RunE: runQuery, +} + +func runQuery(cmd *cobra.Command, args []string) error { + question := args[0] + ctx := context.Background() + + // Connect to Qdrant + qdrantClient, err := rag.NewQdrantClient(rag.QdrantConfig{ + Host: qdrantHost, + Port: qdrantPort, + UseTLS: false, + }) + if err != nil { + return fmt.Errorf("failed to connect to Qdrant: %w", err) + } + defer qdrantClient.Close() + + // Connect to Ollama + ollamaClient, err := rag.NewOllamaClient(rag.OllamaConfig{ + Host: ollamaHost, + Port: ollamaPort, + Model: model, + }) + if err != nil { + return fmt.Errorf("failed to connect to Ollama: %w", err) + } + + // Configure query + cfg := rag.QueryConfig{ + Collection: queryCollection, + Limit: uint64(limit), + Threshold: threshold, + Category: category, + } + + // Run query + results, err := rag.Query(ctx, qdrantClient, ollamaClient, question, cfg) + if err != nil { + return err + } + + // Format output + switch format { + case "json": + fmt.Println(rag.FormatResultsJSON(results)) + case "context": + fmt.Println(rag.FormatResultsContext(results)) + default: + fmt.Println(rag.FormatResultsText(results)) + } + + return nil +} + +// QueryDocs is exported for use by other packages (e.g., MCP). 
+func QueryDocs(ctx context.Context, question, collectionName string, topK int) ([]rag.QueryResult, error) { + qdrantClient, err := rag.NewQdrantClient(rag.DefaultQdrantConfig()) + if err != nil { + return nil, err + } + defer qdrantClient.Close() + + ollamaClient, err := rag.NewOllamaClient(rag.DefaultOllamaConfig()) + if err != nil { + return nil, err + } + + cfg := rag.DefaultQueryConfig() + cfg.Collection = collectionName + cfg.Limit = uint64(topK) + + return rag.Query(ctx, qdrantClient, ollamaClient, question, cfg) +} + +// QueryDocsContext is exported and returns context-formatted results. +func QueryDocsContext(ctx context.Context, question, collectionName string, topK int) (string, error) { + results, err := QueryDocs(ctx, question, collectionName, topK) + if err != nil { + return "", err + } + return rag.FormatResultsContext(results), nil +} diff --git a/internal/cmd/rag/cmd_rag.go b/internal/cmd/rag/cmd_rag.go new file mode 100644 index 00000000..a272c448 --- /dev/null +++ b/internal/cmd/rag/cmd_rag.go @@ -0,0 +1,54 @@ +package rag + +import ( + "github.com/host-uk/core/pkg/i18n" + "github.com/spf13/cobra" +) + +// Shared flags +var ( + qdrantHost string + qdrantPort int + ollamaHost string + ollamaPort int + model string + verbose bool +) + +var ragCmd = &cobra.Command{ + Use: "rag", + Short: i18n.T("cmd.rag.short"), + Long: i18n.T("cmd.rag.long"), +} + +func initFlags() { + // Qdrant connection flags (persistent) - defaults to localhost for local development + ragCmd.PersistentFlags().StringVar(&qdrantHost, "qdrant-host", "localhost", i18n.T("cmd.rag.flag.qdrant_host")) + ragCmd.PersistentFlags().IntVar(&qdrantPort, "qdrant-port", 6334, i18n.T("cmd.rag.flag.qdrant_port")) + + // Ollama connection flags (persistent) - defaults to localhost for local development + ragCmd.PersistentFlags().StringVar(&ollamaHost, "ollama-host", "localhost", i18n.T("cmd.rag.flag.ollama_host")) + ragCmd.PersistentFlags().IntVar(&ollamaPort, "ollama-port", 11434, 
i18n.T("cmd.rag.flag.ollama_port")) + ragCmd.PersistentFlags().StringVar(&model, "model", "nomic-embed-text", i18n.T("cmd.rag.flag.model")) + + // Verbose flag (persistent) + ragCmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, i18n.T("common.flag.verbose")) + + // Ingest command flags + ingestCmd.Flags().StringVar(&collection, "collection", "hostuk-docs", i18n.T("cmd.rag.ingest.flag.collection")) + ingestCmd.Flags().BoolVar(&recreate, "recreate", false, i18n.T("cmd.rag.ingest.flag.recreate")) + ingestCmd.Flags().IntVar(&chunkSize, "chunk-size", 500, i18n.T("cmd.rag.ingest.flag.chunk_size")) + ingestCmd.Flags().IntVar(&chunkOverlap, "chunk-overlap", 50, i18n.T("cmd.rag.ingest.flag.chunk_overlap")) + + // Query command flags + queryCmd.Flags().StringVar(&queryCollection, "collection", "hostuk-docs", i18n.T("cmd.rag.query.flag.collection")) + queryCmd.Flags().IntVar(&limit, "top", 5, i18n.T("cmd.rag.query.flag.top")) + queryCmd.Flags().Float32Var(&threshold, "threshold", 0.5, i18n.T("cmd.rag.query.flag.threshold")) + queryCmd.Flags().StringVar(&category, "category", "", i18n.T("cmd.rag.query.flag.category")) + queryCmd.Flags().StringVar(&format, "format", "text", i18n.T("cmd.rag.query.flag.format")) + + // Collections command flags + collectionsCmd.Flags().BoolVar(&listCollections, "list", false, i18n.T("cmd.rag.collections.flag.list")) + collectionsCmd.Flags().BoolVar(&showStats, "stats", false, i18n.T("cmd.rag.collections.flag.stats")) + collectionsCmd.Flags().StringVar(&deleteCollection, "delete", "", i18n.T("cmd.rag.collections.flag.delete")) +} diff --git a/internal/variants/full.go b/internal/variants/full.go index 409f8782..2a72a48c 100644 --- a/internal/variants/full.go +++ b/internal/variants/full.go @@ -20,6 +20,7 @@ // - test: Test runner with coverage // - qa: Quality assurance workflows // - monitor: Security monitoring aggregation +// - rag: RAG (Retrieval Augmented Generation) tools package variants @@ -37,6 +38,7 @@ import ( _ 
"github.com/host-uk/core/internal/cmd/php" _ "github.com/host-uk/core/internal/cmd/pkgcmd" _ "github.com/host-uk/core/internal/cmd/qa" + _ "github.com/host-uk/core/internal/cmd/rag" _ "github.com/host-uk/core/internal/cmd/sdk" _ "github.com/host-uk/core/internal/cmd/security" _ "github.com/host-uk/core/internal/cmd/setup" diff --git a/pkg/i18n/locales/en_GB.json b/pkg/i18n/locales/en_GB.json index d85f4528..347f8405 100644 --- a/pkg/i18n/locales/en_GB.json +++ b/pkg/i18n/locales/en_GB.json @@ -608,6 +608,33 @@ "no_findings": "No security findings", "error.no_repos": "No repositories to scan. Use --repo, --all, or run from a git repo", "error.not_git_repo": "Not in a git repository. Use --repo to specify one" + }, + "rag": { + "short": "RAG (Retrieval Augmented Generation) tools", + "long": "RAG tools for storing documentation in Qdrant vector database and querying with semantic search. Eliminates need to repeatedly remind Claude about project specifics.", + "flag.qdrant_host": "Qdrant server hostname", + "flag.qdrant_port": "Qdrant gRPC port", + "flag.ollama_host": "Ollama server hostname", + "flag.ollama_port": "Ollama server port", + "flag.model": "Embedding model name", + "ingest.short": "Ingest markdown files into Qdrant", + "ingest.long": "Ingest markdown files from a directory into Qdrant vector database. Chunks files, generates embeddings via Ollama, and stores for semantic search.", + "ingest.flag.collection": "Qdrant collection name", + "ingest.flag.recreate": "Delete and recreate collection", + "ingest.flag.chunk_size": "Characters per chunk", + "ingest.flag.chunk_overlap": "Overlap between chunks", + "query.short": "Query the vector database", + "query.long": "Search for similar documents using semantic similarity. 
// ChunkConfig holds chunking configuration.
type ChunkConfig struct {
	Size    int // Maximum characters (runes) per chunk
	Overlap int // Characters (runes) carried over from the previous chunk
}

// DefaultChunkConfig returns default chunking configuration:
// 500 characters per chunk with a 50 character overlap.
func DefaultChunkConfig() ChunkConfig {
	return ChunkConfig{
		Size:    500,
		Overlap: 50,
	}
}

// Chunk represents a text chunk with metadata.
type Chunk struct {
	Text    string // Chunk content, trimmed of surrounding whitespace
	Section string // Heading text of the section the chunk came from ("" if none)
	Index   int    // Zero-based position of the chunk within the document
}

// runeLen reports the length of s in characters (runes) rather than bytes.
func runeLen(s string) int {
	return len([]rune(s))
}

// ChunkMarkdown splits markdown text into chunks by sections and paragraphs.
// Preserves context with configurable overlap.
//
// The text is first split at "## " headers. A section that fits within
// cfg.Size is emitted as a single chunk. A larger section is split at blank
// lines and its paragraphs are packed greedily into chunks; when a chunk is
// closed, the last cfg.Overlap characters of it are prepended to the next
// chunk (if the closed chunk was longer than the overlap).
//
// All sizes are measured in characters (runes), never bytes, so multi-byte
// text such as accented or CJK content is sized the same way as ASCII.
//
// cfg.Size is a target, not a hard limit: a single paragraph longer than
// cfg.Size is emitted whole, and the overlap prefix is added on top of the
// paragraph that starts a new chunk.
//
// The returned chunks are numbered consecutively from 0 via Chunk.Index.
// The result is nil when text contains no non-blank content.
func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk {
	var chunks []Chunk
	emit := func(body, title string) {
		chunks = append(chunks, Chunk{
			Text:    body,
			Section: title,
			Index:   len(chunks), // chunks are only ever appended, so this is the running index
		})
	}

	for _, section := range splitBySections(text) {
		section = strings.TrimSpace(section)
		if section == "" {
			continue
		}

		// Extract section title from the first line when it is a heading.
		title := ""
		if heading, _, _ := strings.Cut(section, "\n"); strings.HasPrefix(heading, "#") {
			title = strings.TrimSpace(strings.TrimLeft(heading, "#"))
		}

		// If section is small enough, yield as-is.
		if runeLen(section) <= cfg.Size {
			emit(section, title)
			continue
		}

		// Otherwise, pack paragraphs into chunks.
		current := ""
		for _, para := range splitByParagraphs(section) {
			para = strings.TrimSpace(para)
			if para == "" {
				continue
			}

			// +2 accounts for the "\n\n" separator between paragraphs.
			if runeLen(current)+runeLen(para)+2 <= cfg.Size {
				if current == "" {
					current = para
				} else {
					current += "\n\n" + para
				}
				continue
			}

			// The paragraph does not fit: close the current chunk.
			if current != "" {
				emit(strings.TrimSpace(current), title)
			}

			// Start the next chunk, seeded with the tail of the previous one.
			tail := []rune(current)
			if cfg.Overlap > 0 && len(tail) > cfg.Overlap {
				current = string(tail[len(tail)-cfg.Overlap:]) + "\n\n" + para
			} else {
				current = para
			}
		}

		// Don't forget the last chunk.
		if rest := strings.TrimSpace(current); rest != "" {
			emit(rest, title)
		}
	}

	return chunks
}

// splitBySections splits text by ## headers while preserving the header with its content.
//
// A new section starts at every line beginning with "## ", except inside a
// fenced code block (delimited by lines starting with ``` or ~~~, optionally
// indented with spaces): there such a line is ordinary code, for example a
// shell comment, and must not cut the code block in two. A fence that is
// never closed extends to the end of the text.
//
// Each returned section ends with a newline. Text before the first header
// forms its own leading section.
func splitBySections(text string) []string {
	var (
		sections []string
		current  strings.Builder
		fence    string // marker of the open code fence, "" when outside a fence
	)

	for _, line := range strings.Split(text, "\n") {
		indented := strings.TrimLeft(line, " ")
		switch {
		case fence != "":
			// Inside a fence only the matching marker is significant.
			if strings.HasPrefix(indented, fence) {
				fence = ""
			}
		case strings.HasPrefix(indented, "```"):
			fence = "```"
		case strings.HasPrefix(indented, "~~~"):
			fence = "~~~"
		case strings.HasPrefix(line, "## "):
			// Save previous section if exists.
			if current.Len() > 0 {
				sections = append(sections, current.String())
				current.Reset()
			}
		}
		current.WriteString(line)
		current.WriteString("\n")
	}

	// Don't forget the last section.
	if current.Len() > 0 {
		sections = append(sections, current.String())
	}

	return sections
}

// splitByParagraphs splits text by double newlines.
//
// Runs of three or more newlines are first collapsed to exactly two, so any
// number of consecutive blank lines yields a single paragraph break.
func splitByParagraphs(text string) []string {
	normalized := text
	for strings.Contains(normalized, "\n\n\n") {
		normalized = strings.ReplaceAll(normalized, "\n\n\n", "\n\n")
	}
	return strings.Split(normalized, "\n\n")
}

// Category determines the document category from file path.
//
// The path is lower-cased and matched by substring against a fixed list of
// keywords. Cases are checked in order and the first match wins, so a path
// containing several keywords (e.g. "plans/architecture.md") gets the
// category of the earliest case ("task"). Paths matching nothing are
// "documentation".
func Category(path string) string {
	lower := strings.ToLower(path)

	switch {
	case strings.Contains(lower, "flux") || strings.Contains(lower, "ui/component"):
		return "ui-component"
	case strings.Contains(lower, "brand") || strings.Contains(lower, "mascot"):
		return "brand"
	case strings.Contains(lower, "brief"):
		return "product-brief"
	case strings.Contains(lower, "help") || strings.Contains(lower, "draft"):
		return "help-doc"
	case strings.Contains(lower, "task") || strings.Contains(lower, "plan"):
		return "task"
	case strings.Contains(lower, "architecture") || strings.Contains(lower, "migration"):
		return "architecture"
	default:
		return "documentation"
	}
}
// ChunkID generates a unique ID for a chunk.
//
// The ID is the lowercase hex MD5 digest (32 characters) of
// "path:index:prefix", where prefix is at most the first 100 characters of
// text. The prefix is cut on rune boundaries, so multi-byte characters are
// never split. Identical inputs always produce the same ID.
func ChunkID(path string, index int, text string) string {
	head := []rune(text)
	head = head[:min(len(head), 100)]

	sum := md5.Sum(fmt.Appendf(nil, "%s:%d:%s", path, index, string(head)))
	return fmt.Sprintf("%x", sum[:])
}

// FileExtensions returns the file extensions to process.
// Each entry is lower-case and includes the leading dot.
func FileExtensions() []string {
	return []string{".md", ".markdown", ".txt"}
}

// ShouldProcess checks if a file should be processed based on extension.
// The comparison ignores case; a path without an extension is rejected.
func ShouldProcess(path string) bool {
	return slices.Contains(FileExtensions(), strings.ToLower(filepath.Ext(path)))
}
", 50) + + cfg := ChunkConfig{Size: 200, Overlap: 20} + chunks := ChunkMarkdown(text, cfg) + + assert.Greater(t, len(chunks), 1) + for _, chunk := range chunks { + assert.NotEmpty(t, chunk.Text) + assert.Equal(t, "Large Section", chunk.Section) + } +} + +func TestChunkMarkdown_Good_ExtractsTitle(t *testing.T) { + text := `## My Section Title + +Some content here. +` + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.Len(t, chunks, 1) + assert.Equal(t, "My Section Title", chunks[0].Section) +} + +func TestCategory_Good_UIComponent(t *testing.T) { + tests := []struct { + path string + expected string + }{ + {"docs/flux/button.md", "ui-component"}, + {"ui/components/modal.md", "ui-component"}, + {"brand/vi-personality.md", "brand"}, + {"mascot/expressions.md", "brand"}, + {"product-brief.md", "product-brief"}, + {"tasks/2024-01-15-feature.md", "task"}, + {"plans/architecture.md", "task"}, + {"architecture/migration.md", "architecture"}, + {"docs/api.md", "documentation"}, + } + + for _, tc := range tests { + t.Run(tc.path, func(t *testing.T) { + assert.Equal(t, tc.expected, Category(tc.path)) + }) + } +} + +func TestChunkID_Good_Deterministic(t *testing.T) { + id1 := ChunkID("test.md", 0, "hello world") + id2 := ChunkID("test.md", 0, "hello world") + + assert.Equal(t, id1, id2) +} + +func TestChunkID_Good_DifferentForDifferentInputs(t *testing.T) { + id1 := ChunkID("test.md", 0, "hello world") + id2 := ChunkID("test.md", 1, "hello world") + id3 := ChunkID("other.md", 0, "hello world") + + assert.NotEqual(t, id1, id2) + assert.NotEqual(t, id1, id3) +} + +func TestShouldProcess_Good_MarkdownFiles(t *testing.T) { + assert.True(t, ShouldProcess("doc.md")) + assert.True(t, ShouldProcess("doc.markdown")) + assert.True(t, ShouldProcess("doc.txt")) + assert.False(t, ShouldProcess("doc.go")) + assert.False(t, ShouldProcess("doc.py")) + assert.False(t, ShouldProcess("doc")) +} + +// Helper function +func repeatString(s string, n int) string { + result := "" + for 
i := 0; i < n; i++ { + result += s + } + return result +} diff --git a/pkg/rag/ingest.go b/pkg/rag/ingest.go new file mode 100644 index 00000000..416a9354 --- /dev/null +++ b/pkg/rag/ingest.go @@ -0,0 +1,214 @@ +package rag + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" +) + +// IngestConfig holds ingestion configuration. +type IngestConfig struct { + Directory string + Collection string + Recreate bool + Verbose bool + BatchSize int + Chunk ChunkConfig +} + +// DefaultIngestConfig returns default ingestion configuration. +func DefaultIngestConfig() IngestConfig { + return IngestConfig{ + Collection: "hostuk-docs", + BatchSize: 100, + Chunk: DefaultChunkConfig(), + } +} + +// IngestStats holds statistics from ingestion. +type IngestStats struct { + Files int + Chunks int + Errors int +} + +// IngestProgress is called during ingestion to report progress. +type IngestProgress func(file string, chunks int, total int) + +// Ingest processes a directory of documents and stores them in Qdrant. 
+func Ingest(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, cfg IngestConfig, progress IngestProgress) (*IngestStats, error) { + stats := &IngestStats{} + + // Validate batch size to prevent infinite loop + if cfg.BatchSize <= 0 { + cfg.BatchSize = 100 // Safe default + } + + // Resolve directory + absDir, err := filepath.Abs(cfg.Directory) + if err != nil { + return nil, fmt.Errorf("error resolving directory: %w", err) + } + + info, err := os.Stat(absDir) + if err != nil { + return nil, fmt.Errorf("error accessing directory: %w", err) + } + if !info.IsDir() { + return nil, fmt.Errorf("not a directory: %s", absDir) + } + + // Check/create collection + exists, err := qdrant.CollectionExists(ctx, cfg.Collection) + if err != nil { + return nil, fmt.Errorf("error checking collection: %w", err) + } + + if cfg.Recreate && exists { + if err := qdrant.DeleteCollection(ctx, cfg.Collection); err != nil { + return nil, fmt.Errorf("error deleting collection: %w", err) + } + exists = false + } + + if !exists { + vectorDim := ollama.EmbedDimension() + if err := qdrant.CreateCollection(ctx, cfg.Collection, vectorDim); err != nil { + return nil, fmt.Errorf("error creating collection: %w", err) + } + } + + // Find markdown files + var files []string + err = filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() && ShouldProcess(path) { + files = append(files, path) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("error walking directory: %w", err) + } + + if len(files) == 0 { + return nil, fmt.Errorf("no markdown files found in %s", absDir) + } + + // Process files + var points []Point + for _, filePath := range files { + relPath, err := filepath.Rel(absDir, filePath) + if err != nil { + stats.Errors++ + continue + } + + content, err := os.ReadFile(filePath) + if err != nil { + stats.Errors++ + continue + } + + if len(strings.TrimSpace(string(content))) == 0 { + continue + } 
+ + // Chunk the content + category := Category(relPath) + chunks := ChunkMarkdown(string(content), cfg.Chunk) + + for _, chunk := range chunks { + // Generate embedding + embedding, err := ollama.Embed(ctx, chunk.Text) + if err != nil { + stats.Errors++ + if cfg.Verbose { + fmt.Printf(" Error embedding %s chunk %d: %v\n", relPath, chunk.Index, err) + } + continue + } + + // Create point + points = append(points, Point{ + ID: ChunkID(relPath, chunk.Index, chunk.Text), + Vector: embedding, + Payload: map[string]any{ + "text": chunk.Text, + "source": relPath, + "section": chunk.Section, + "category": category, + "chunk_index": chunk.Index, + }, + }) + stats.Chunks++ + } + + stats.Files++ + if progress != nil { + progress(relPath, stats.Chunks, len(files)) + } + } + + // Batch upsert to Qdrant + if len(points) > 0 { + for i := 0; i < len(points); i += cfg.BatchSize { + end := i + cfg.BatchSize + if end > len(points) { + end = len(points) + } + batch := points[i:end] + if err := qdrant.UpsertPoints(ctx, cfg.Collection, batch); err != nil { + return stats, fmt.Errorf("error upserting batch %d: %w", i/cfg.BatchSize+1, err) + } + } + } + + return stats, nil +} + +// IngestFile processes a single file and stores it in Qdrant. 
func IngestFile(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, collection string, filePath string, chunkCfg ChunkConfig) (int, error) {
	// Reads filePath, chunks it with chunkCfg, embeds every chunk and upserts
	// all points in a single call. Returns the number of points stored.
	// An empty or whitespace-only file stores nothing and returns (0, nil).
	// Unlike Ingest, the first embedding error aborts the whole file.
	content, err := os.ReadFile(filePath)
	if err != nil {
		return 0, fmt.Errorf("error reading file: %w", err)
	}

	if len(strings.TrimSpace(string(content))) == 0 {
		return 0, nil
	}

	// NOTE(review): the category, the point IDs and the "source" payload are
	// derived from filePath exactly as given, whereas Ingest uses the path
	// relative to the ingest directory. The same document can therefore get
	// different IDs and categories depending on which entry point ingested it.
	category := Category(filePath)
	chunks := ChunkMarkdown(string(content), chunkCfg)

	var points []Point
	for _, chunk := range chunks {
		embedding, err := ollama.Embed(ctx, chunk.Text)
		if err != nil {
			return 0, fmt.Errorf("error embedding chunk %d: %w", chunk.Index, err)
		}

		points = append(points, Point{
			ID:     ChunkID(filePath, chunk.Index, chunk.Text),
			Vector: embedding,
			Payload: map[string]any{
				"text":        chunk.Text,
				"source":      filePath,
				"section":     chunk.Section,
				"category":    category,
				"chunk_index": chunk.Index,
			},
		})
	}

	// UpsertPoints returns nil without a request when points is empty.
	if err := qdrant.UpsertPoints(ctx, collection, points); err != nil {
		return 0, fmt.Errorf("error upserting points: %w", err)
	}

	return len(points), nil
}
diff --git a/pkg/rag/ollama.go b/pkg/rag/ollama.go
new file mode 100644
index 00000000..70510425
--- /dev/null
+++ b/pkg/rag/ollama.go
@@ -0,0 +1,116 @@
package rag

import (
	"context"
	"fmt"
	"net/http"
	"net/url"

	"github.com/ollama/ollama/api"
)

// OllamaConfig holds Ollama connection configuration.
type OllamaConfig struct {
	Host  string // host name or address of the Ollama server
	Port  int    // TCP port of the Ollama HTTP API
	Model string // name of the embedding model to request
}

// DefaultOllamaConfig returns default Ollama configuration.
// Host defaults to localhost for local development.
func DefaultOllamaConfig() OllamaConfig {
	return OllamaConfig{
		Host:  "localhost",
		Port:  11434,
		Model: "nomic-embed-text",
	}
}

// OllamaClient wraps the Ollama API client for embeddings.
type OllamaClient struct {
	client *api.Client
	config OllamaConfig
}

// NewOllamaClient creates a new Ollama client.
+func NewOllamaClient(cfg OllamaConfig) (*OllamaClient, error) { + baseURL := &url.URL{ + Scheme: "http", + Host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + } + + client := api.NewClient(baseURL, http.DefaultClient) + + return &OllamaClient{ + client: client, + config: cfg, + }, nil +} + +// EmbedDimension returns the embedding dimension for the configured model. +// nomic-embed-text uses 768 dimensions. +func (o *OllamaClient) EmbedDimension() uint64 { + switch o.config.Model { + case "nomic-embed-text": + return 768 + case "mxbai-embed-large": + return 1024 + case "all-minilm": + return 384 + default: + return 768 // Default to nomic-embed-text dimension + } +} + +// Embed generates embeddings for the given text. +func (o *OllamaClient) Embed(ctx context.Context, text string) ([]float32, error) { + req := &api.EmbedRequest{ + Model: o.config.Model, + Input: text, + } + + resp, err := o.client.Embed(ctx, req) + if err != nil { + return nil, fmt.Errorf("failed to generate embedding: %w", err) + } + + if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 { + return nil, fmt.Errorf("empty embedding response") + } + + // Convert float64 to float32 for Qdrant + embedding := resp.Embeddings[0] + result := make([]float32, len(embedding)) + for i, v := range embedding { + result[i] = float32(v) + } + + return result, nil +} + +// EmbedBatch generates embeddings for multiple texts. +func (o *OllamaClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { + results := make([][]float32, len(texts)) + for i, text := range texts { + embedding, err := o.Embed(ctx, text) + if err != nil { + return nil, fmt.Errorf("failed to embed text %d: %w", i, err) + } + results[i] = embedding + } + return results, nil +} + +// VerifyModel checks if the embedding model is available. 
func (o *OllamaClient) VerifyModel(ctx context.Context) error {
	// Availability is probed by embedding the literal text "test"; any Embed
	// failure is reported as "model not available" together with a hint to
	// pull the model.
	_, err := o.Embed(ctx, "test")
	if err != nil {
		return fmt.Errorf("model %s not available: %w (run: ollama pull %s)", o.config.Model, err, o.config.Model)
	}
	return nil
}

// Model returns the configured embedding model name.
func (o *OllamaClient) Model() string {
	return o.config.Model
}
diff --git a/pkg/rag/qdrant.go b/pkg/rag/qdrant.go
new file mode 100644
index 00000000..6f359db5
--- /dev/null
+++ b/pkg/rag/qdrant.go
@@ -0,0 +1,224 @@
// Package rag provides RAG (Retrieval Augmented Generation) functionality
// for storing and querying documentation in Qdrant vector database.
package rag

import (
	"context"
	"fmt"

	"github.com/qdrant/go-client/qdrant"
)

// QdrantConfig holds Qdrant connection configuration.
type QdrantConfig struct {
	Host   string // host name or address of the Qdrant server
	Port   int    // port passed to the Qdrant client (the default config uses the gRPC port)
	APIKey string // API key passed to the Qdrant client; empty in the default config
	UseTLS bool   // passed through to the Qdrant client's UseTLS setting
}

// DefaultQdrantConfig returns default Qdrant configuration.
// Host defaults to localhost for local development.
func DefaultQdrantConfig() QdrantConfig {
	return QdrantConfig{
		Host:   "localhost",
		Port:   6334, // gRPC port
		UseTLS: false,
	}
}

// QdrantClient wraps the Qdrant Go client with convenience methods.
type QdrantClient struct {
	client *qdrant.Client
	config QdrantConfig
}

// NewQdrantClient creates a new Qdrant client.
// Construction errors are wrapped with the target address.
func NewQdrantClient(cfg QdrantConfig) (*QdrantClient, error) {
	// addr is only used to give the error message some context.
	addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)

	client, err := qdrant.NewClient(&qdrant.Config{
		Host:   cfg.Host,
		Port:   cfg.Port,
		APIKey: cfg.APIKey,
		UseTLS: cfg.UseTLS,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to Qdrant at %s: %w", addr, err)
	}

	return &QdrantClient{
		client: client,
		config: cfg,
	}, nil
}

// Close closes the Qdrant client connection.
func (q *QdrantClient) Close() error {
	return q.client.Close()
}

// HealthCheck verifies the connection to Qdrant.
+func (q *QdrantClient) HealthCheck(ctx context.Context) error { + _, err := q.client.HealthCheck(ctx) + return err +} + +// ListCollections returns all collection names. +func (q *QdrantClient) ListCollections(ctx context.Context) ([]string, error) { + resp, err := q.client.ListCollections(ctx) + if err != nil { + return nil, err + } + names := make([]string, len(resp)) + copy(names, resp) + return names, nil +} + +// CollectionExists checks if a collection exists. +func (q *QdrantClient) CollectionExists(ctx context.Context, name string) (bool, error) { + return q.client.CollectionExists(ctx, name) +} + +// CreateCollection creates a new collection with cosine distance. +func (q *QdrantClient) CreateCollection(ctx context.Context, name string, vectorSize uint64) error { + return q.client.CreateCollection(ctx, &qdrant.CreateCollection{ + CollectionName: name, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, + }), + }) +} + +// DeleteCollection deletes a collection. +func (q *QdrantClient) DeleteCollection(ctx context.Context, name string) error { + return q.client.DeleteCollection(ctx, name) +} + +// CollectionInfo returns information about a collection. +func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*qdrant.CollectionInfo, error) { + return q.client.GetCollectionInfo(ctx, name) +} + +// Point represents a vector point with payload. +type Point struct { + ID string + Vector []float32 + Payload map[string]any +} + +// UpsertPoints inserts or updates points in a collection. 
+func (q *QdrantClient) UpsertPoints(ctx context.Context, collection string, points []Point) error { + if len(points) == 0 { + return nil + } + + qdrantPoints := make([]*qdrant.PointStruct, len(points)) + for i, p := range points { + qdrantPoints[i] = &qdrant.PointStruct{ + Id: qdrant.NewID(p.ID), + Vectors: qdrant.NewVectors(p.Vector...), + Payload: qdrant.NewValueMap(p.Payload), + } + } + + _, err := q.client.Upsert(ctx, &qdrant.UpsertPoints{ + CollectionName: collection, + Points: qdrantPoints, + }) + return err +} + +// SearchResult represents a search result with score. +type SearchResult struct { + ID string + Score float32 + Payload map[string]any +} + +// Search performs a vector similarity search. +func (q *QdrantClient) Search(ctx context.Context, collection string, vector []float32, limit uint64, filter map[string]string) ([]SearchResult, error) { + query := &qdrant.QueryPoints{ + CollectionName: collection, + Query: qdrant.NewQuery(vector...), + Limit: qdrant.PtrOf(limit), + WithPayload: qdrant.NewWithPayload(true), + } + + // Add filter if provided + if len(filter) > 0 { + conditions := make([]*qdrant.Condition, 0, len(filter)) + for k, v := range filter { + conditions = append(conditions, qdrant.NewMatch(k, v)) + } + query.Filter = &qdrant.Filter{ + Must: conditions, + } + } + + resp, err := q.client.Query(ctx, query) + if err != nil { + return nil, err + } + + results := make([]SearchResult, len(resp)) + for i, p := range resp { + payload := make(map[string]any) + for k, v := range p.Payload { + payload[k] = valueToGo(v) + } + results[i] = SearchResult{ + ID: pointIDToString(p.Id), + Score: p.Score, + Payload: payload, + } + } + return results, nil +} + +// pointIDToString converts a Qdrant point ID to string. 
func pointIDToString(id *qdrant.PointId) string {
	// Numeric IDs are rendered in decimal and UUID IDs are returned as-is.
	// A nil ID or an unrecognised ID kind yields "".
	if id == nil {
		return ""
	}
	switch v := id.PointIdOptions.(type) {
	case *qdrant.PointId_Num:
		return fmt.Sprintf("%d", v.Num)
	case *qdrant.PointId_Uuid:
		return v.Uuid
	default:
		return ""
	}
}

// valueToGo converts a Qdrant value to a Go value.
// Strings, integers, doubles and booleans map to the corresponding field of
// the Qdrant value; lists become []any and structs become map[string]any,
// both converted recursively. A nil value and any other kind yield nil.
func valueToGo(v *qdrant.Value) any {
	if v == nil {
		return nil
	}
	switch val := v.Kind.(type) {
	case *qdrant.Value_StringValue:
		return val.StringValue
	case *qdrant.Value_IntegerValue:
		return val.IntegerValue
	case *qdrant.Value_DoubleValue:
		return val.DoubleValue
	case *qdrant.Value_BoolValue:
		return val.BoolValue
	case *qdrant.Value_ListValue:
		list := make([]any, len(val.ListValue.Values))
		for i, item := range val.ListValue.Values {
			list[i] = valueToGo(item)
		}
		return list
	case *qdrant.Value_StructValue:
		m := make(map[string]any)
		for k, item := range val.StructValue.Fields {
			m[k] = valueToGo(item)
		}
		return m
	default:
		return nil
	}
}
diff --git a/pkg/rag/query.go b/pkg/rag/query.go
new file mode 100644
index 00000000..20e7f143
--- /dev/null
+++ b/pkg/rag/query.go
@@ -0,0 +1,161 @@
package rag

import (
	"context"
	"fmt"
	"html"
	"strings"
)

// QueryConfig holds query configuration.
type QueryConfig struct {
	Collection string  // Qdrant collection to search
	Limit      uint64  // maximum number of results requested from Qdrant
	Threshold  float32 // Minimum similarity score (0-1)
	Category   string  // Filter by category
}

// DefaultQueryConfig returns default query configuration.
// Category is left empty, which disables category filtering.
func DefaultQueryConfig() QueryConfig {
	return QueryConfig{
		Collection: "hostuk-docs",
		Limit:      5,
		Threshold:  0.5,
	}
}

// QueryResult represents a query result with metadata.
type QueryResult struct {
	Text       string  // chunk text ("text" payload field)
	Source     string  // source path of the chunk ("source" payload field)
	Section    string  // section title ("section" payload field); may be empty
	Category   string  // document category ("category" payload field)
	ChunkIndex int     // chunk index ("chunk_index" payload field)
	Score      float32 // similarity score reported by the search
}

// Query searches for similar documents in Qdrant.
+func Query(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, query string, cfg QueryConfig) ([]QueryResult, error) { + // Generate embedding for query + embedding, err := ollama.Embed(ctx, query) + if err != nil { + return nil, fmt.Errorf("error generating query embedding: %w", err) + } + + // Build filter + var filter map[string]string + if cfg.Category != "" { + filter = map[string]string{"category": cfg.Category} + } + + // Search Qdrant + results, err := qdrant.Search(ctx, cfg.Collection, embedding, cfg.Limit, filter) + if err != nil { + return nil, fmt.Errorf("error searching: %w", err) + } + + // Convert and filter by threshold + var queryResults []QueryResult + for _, r := range results { + if r.Score < cfg.Threshold { + continue + } + + qr := QueryResult{ + Score: r.Score, + } + + // Extract payload fields + if text, ok := r.Payload["text"].(string); ok { + qr.Text = text + } + if source, ok := r.Payload["source"].(string); ok { + qr.Source = source + } + if section, ok := r.Payload["section"].(string); ok { + qr.Section = section + } + if category, ok := r.Payload["category"].(string); ok { + qr.Category = category + } + // Handle chunk_index from various types (JSON unmarshaling produces float64) + switch idx := r.Payload["chunk_index"].(type) { + case int64: + qr.ChunkIndex = int(idx) + case float64: + qr.ChunkIndex = int(idx) + case int: + qr.ChunkIndex = idx + } + + queryResults = append(queryResults, qr) + } + + return queryResults, nil +} + +// FormatResultsText formats query results as plain text. +func FormatResultsText(results []QueryResult) string { + if len(results) == 0 { + return "No results found." 
+ } + + var sb strings.Builder + for i, r := range results { + sb.WriteString(fmt.Sprintf("\n--- Result %d (score: %.2f) ---\n", i+1, r.Score)) + sb.WriteString(fmt.Sprintf("Source: %s\n", r.Source)) + if r.Section != "" { + sb.WriteString(fmt.Sprintf("Section: %s\n", r.Section)) + } + sb.WriteString(fmt.Sprintf("Category: %s\n\n", r.Category)) + sb.WriteString(r.Text) + sb.WriteString("\n") + } + return sb.String() +} + +// FormatResultsContext formats query results for LLM context injection. +func FormatResultsContext(results []QueryResult) string { + if len(results) == 0 { + return "" + } + + var sb strings.Builder + sb.WriteString("\n") + for _, r := range results { + // Escape XML special characters to prevent malformed output + fmt.Fprintf(&sb, "\n", + html.EscapeString(r.Source), + html.EscapeString(r.Section), + html.EscapeString(r.Category)) + sb.WriteString(html.EscapeString(r.Text)) + sb.WriteString("\n\n\n") + } + sb.WriteString("") + return sb.String() +} + +// FormatResultsJSON formats query results as JSON-like output. +func FormatResultsJSON(results []QueryResult) string { + if len(results) == 0 { + return "[]" + } + + var sb strings.Builder + sb.WriteString("[\n") + for i, r := range results { + sb.WriteString(" {\n") + sb.WriteString(fmt.Sprintf(" \"source\": %q,\n", r.Source)) + sb.WriteString(fmt.Sprintf(" \"section\": %q,\n", r.Section)) + sb.WriteString(fmt.Sprintf(" \"category\": %q,\n", r.Category)) + sb.WriteString(fmt.Sprintf(" \"score\": %.4f,\n", r.Score)) + sb.WriteString(fmt.Sprintf(" \"text\": %q\n", r.Text)) + if i < len(results)-1 { + sb.WriteString(" },\n") + } else { + sb.WriteString(" }\n") + } + } + sb.WriteString("]") + return sb.String() +}