feat: extract go-rag from go-ai as standalone RAG package

Vector search with Qdrant + Ollama embeddings, document chunking.
Zero internal go-ai dependencies. Adds CLAUDE.md/TODO.md/FINDINGS.md.

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Snider 2026-02-19 18:29:59 +00:00
commit bf047e4494
11 changed files with 1252 additions and 0 deletions

52
CLAUDE.md Normal file
View file

@ -0,0 +1,52 @@
# CLAUDE.md
## What This Is
Retrieval-Augmented Generation with vector search. Module: `forge.lthn.ai/core/go-rag`
Provides document chunking, embedding via Ollama, vector storage/search via Qdrant, and formatted context retrieval for AI prompts.
## Commands
```bash
go test ./... # Run all tests
go test -v -run TestChunk # Single test
```
## Architecture
| File | Purpose |
|------|---------|
| `chunk.go` | Document chunking — splits markdown/text into semantic chunks |
| `ingest.go` | Ingestion pipeline — reads files, chunks, embeds, stores |
| `query.go` | Query interface — search vectors, format results as text/JSON/XML |
| `qdrant.go` | Qdrant vector DB client — create collections, upsert, search |
| `ollama.go` | Ollama embedding client — generate embeddings for chunks |
| `helpers.go` | Shared utilities |
## Dependencies
- `forge.lthn.ai/core/go` — Logging (pkg/log)
- `github.com/ollama/ollama` — Embedding API client
- `github.com/qdrant/go-client` — Vector DB gRPC client
- `github.com/stretchr/testify` — Tests
## Key API
```go
// Ingest documents
rag.IngestDirectory(ctx, "/path/to/docs", "my-collection", false)
rag.IngestSingleFile(ctx, "/path/to/doc.md", "my-collection")
// Query for relevant context
results, err := rag.QueryDocs(ctx, "search query", "my-collection", 5)
context, err := rag.QueryDocsContext(ctx, "search query", "my-collection", 5)
```
## Coding Standards
- UK English
- Tests: testify assert/require
- Conventional commits
- Co-Author: `Co-Authored-By: Virgil <virgil@lethean.io>`
- Licence: EUPL-1.2

21
FINDINGS.md Normal file
View file

@ -0,0 +1,21 @@
# FINDINGS.md — go-rag Research & Discovery
## 2026-02-19: Split from go-ai (Virgil)
### Origin
Extracted from `forge.lthn.ai/core/go-ai/rag/`. Zero internal go-ai dependencies.
### What Was Extracted
- 7 Go files (~1,017 LOC excluding tests)
- 1 test file (chunk_test.go)
### Key Finding: Minimal Test Coverage
Only chunk.go has tests. The Qdrant and Ollama clients are untested — they depend on external services (Qdrant server, Ollama API) which makes unit testing harder. Consider mock interfaces.
### Consumers
- `go-ai/ai/rag.go` wraps this as `QueryRAGForTask()` facade
- `go-ai/mcp/tools_rag.go` exposes RAG as MCP tools

22
TODO.md Normal file
View file

@ -0,0 +1,22 @@
# TODO.md — go-rag Task Queue
## Phase 1: Post-Split
- [ ] **Verify tests pass standalone** — Run `go test ./...`. Confirm chunk_test.go passes.
- [ ] **Add missing tests** — Only chunk.go has tests. Need: query, ingest, qdrant client, ollama client.
- [ ] **Benchmark chunking** — No perf baselines. Add BenchmarkChunk for various document sizes.
## Phase 2: Enhancements
- [ ] **Embedding model selection** — Currently hardcoded to Ollama. Add backend interface for alternative embedding providers.
- [ ] **Collection management** — Add list/delete collection operations for Qdrant.
- [ ] **Hybrid search** — Combine vector similarity with keyword matching for better recall.
- [ ] **Chunk overlap** — Current chunking may lose context at boundaries. Add configurable overlap.
---
## Workflow
1. Virgil in core/go writes tasks here after research
2. This repo's session picks up tasks in phase order
3. Mark `[x]` when done, note commit hash

204
chunk.go Normal file
View file

@ -0,0 +1,204 @@
package rag
import (
"crypto/md5"
"fmt"
"path/filepath"
"slices"
"strings"
)
// ChunkConfig holds chunking configuration.
type ChunkConfig struct {
	Size    int // Characters per chunk (byte length, not runes)
	Overlap int // Overlap between chunks, in runes carried into the next chunk
}

// DefaultChunkConfig returns default chunking configuration.
func DefaultChunkConfig() ChunkConfig {
	return ChunkConfig{
		Size:    500,
		Overlap: 50,
	}
}

// Chunk represents a text chunk with metadata.
type Chunk struct {
	Text    string // Chunk body text
	Section string // Title of the enclosing "## " section, empty if none
	Index   int    // Sequential position of this chunk within the document
}
// ChunkMarkdown splits markdown text into chunks by sections and paragraphs.
// Preserves context with configurable overlap.
//
// Sections are delimited by "## " headers. A section that fits within
// cfg.Size becomes a single chunk; larger sections are packed paragraph by
// paragraph, carrying cfg.Overlap runes of the previous chunk into the next.
// A single paragraph longer than cfg.Size is emitted as one oversized chunk
// (no hard splitting). Chunk.Index is monotonically increasing across the
// whole document, not per section.
func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk {
	// Normalise invalid configuration rather than erroring.
	if cfg.Size <= 0 {
		cfg.Size = 500
	}
	if cfg.Overlap < 0 || cfg.Overlap >= cfg.Size {
		cfg.Overlap = 0
	}
	var chunks []Chunk
	// Split by ## headers
	sections := splitBySections(text)
	chunkIndex := 0
	for _, section := range sections {
		section = strings.TrimSpace(section)
		if section == "" {
			continue
		}
		// Extract section title from the first line if it is a header.
		lines := strings.SplitN(section, "\n", 2)
		title := ""
		if strings.HasPrefix(lines[0], "#") {
			title = strings.TrimLeft(lines[0], "#")
			title = strings.TrimSpace(title)
		}
		// If section is small enough, yield as-is
		if len(section) <= cfg.Size {
			chunks = append(chunks, Chunk{
				Text:    section,
				Section: title,
				Index:   chunkIndex,
			})
			chunkIndex++
			continue
		}
		// Otherwise, chunk by paragraphs
		paragraphs := splitByParagraphs(section)
		currentChunk := ""
		for _, para := range paragraphs {
			para = strings.TrimSpace(para)
			if para == "" {
				continue
			}
			// +2 accounts for the "\n\n" separator joining paragraphs.
			if len(currentChunk)+len(para)+2 <= cfg.Size {
				if currentChunk != "" {
					currentChunk += "\n\n" + para
				} else {
					currentChunk = para
				}
			} else {
				// Flush the full chunk before starting a new one.
				if currentChunk != "" {
					chunks = append(chunks, Chunk{
						Text:    strings.TrimSpace(currentChunk),
						Section: title,
						Index:   chunkIndex,
					})
					chunkIndex++
				}
				// Start new chunk with overlap from previous (rune-safe for UTF-8)
				runes := []rune(currentChunk)
				if cfg.Overlap > 0 && len(runes) > cfg.Overlap {
					overlapText := string(runes[len(runes)-cfg.Overlap:])
					currentChunk = overlapText + "\n\n" + para
				} else {
					currentChunk = para
				}
			}
		}
		// Don't forget the last chunk
		if strings.TrimSpace(currentChunk) != "" {
			chunks = append(chunks, Chunk{
				Text:    strings.TrimSpace(currentChunk),
				Section: title,
				Index:   chunkIndex,
			})
			chunkIndex++
		}
	}
	return chunks
}
// splitBySections splits text into sections, starting a new section at every
// line that begins with "## ". The header line stays with its section body,
// and every line (including the last) gets a trailing newline.
func splitBySections(text string) []string {
	var (
		out []string
		cur strings.Builder
	)
	flush := func() {
		if cur.Len() > 0 {
			out = append(out, cur.String())
			cur.Reset()
		}
	}
	for _, ln := range strings.Split(text, "\n") {
		if strings.HasPrefix(ln, "## ") {
			// A new "## " header closes the section in progress.
			flush()
		}
		cur.WriteString(ln)
		cur.WriteByte('\n')
	}
	flush()
	return out
}
// splitByParagraphs splits text on blank lines. Runs of three or more
// consecutive newlines are first collapsed to exactly two, so any blank gap
// counts as a single paragraph break.
func splitByParagraphs(text string) []string {
	collapsed := text
	for {
		next := strings.ReplaceAll(collapsed, "\n\n\n", "\n\n")
		if next == collapsed {
			break
		}
		collapsed = next
	}
	return strings.Split(collapsed, "\n\n")
}
// Category determines the document category from a file path by matching
// keywords in the lowercased path. Rules are evaluated in order and the
// first match wins; paths matching nothing fall back to "documentation".
func Category(path string) string {
	p := strings.ToLower(path)
	// Order matters: e.g. "plans/architecture.md" is a task, not architecture.
	rules := []struct {
		keywords []string
		label    string
	}{
		{[]string{"flux", "ui/component"}, "ui-component"},
		{[]string{"brand", "mascot"}, "brand"},
		{[]string{"brief"}, "product-brief"},
		{[]string{"help", "draft"}, "help-doc"},
		{[]string{"task", "plan"}, "task"},
		{[]string{"architecture", "migration"}, "architecture"},
	}
	for _, rule := range rules {
		for _, kw := range rule.keywords {
			if strings.Contains(p, kw) {
				return rule.label
			}
		}
	}
	return "documentation"
}
// ChunkID generates a deterministic hex ID for a chunk from the MD5 of
// "path:index:prefix", where prefix is at most the first 100 runes of the
// chunk text. MD5 is used purely as a stable identifier, not for security.
func ChunkID(path string, index int, text string) string {
	const maxRunes = 100
	// Round-trip through []rune so truncation never splits a UTF-8 sequence.
	r := []rune(text)
	if len(r) > maxRunes {
		r = r[:maxRunes]
	}
	sum := md5.Sum([]byte(fmt.Sprintf("%s:%d:%s", path, index, string(r))))
	return fmt.Sprintf("%x", sum)
}
// FileExtensions returns the document file extensions this package ingests,
// lowercase with leading dot.
func FileExtensions() []string {
	return []string{".md", ".markdown", ".txt"}
}

// ShouldProcess reports whether path has one of the supported document
// extensions; the check is case-insensitive.
func ShouldProcess(path string) bool {
	return slices.Contains(FileExtensions(), strings.ToLower(filepath.Ext(path)))
}

120
chunk_test.go Normal file
View file

@ -0,0 +1,120 @@
package rag
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestChunkMarkdown_Good_SmallSection verifies that a section smaller than
// the default chunk size is returned as exactly one chunk containing the
// original text.
func TestChunkMarkdown_Good_SmallSection(t *testing.T) {
	text := `# Title
This is a small section that fits in one chunk.
`
	chunks := ChunkMarkdown(text, DefaultChunkConfig())
	assert.Len(t, chunks, 1)
	assert.Contains(t, chunks[0].Text, "small section")
}
// TestChunkMarkdown_Good_MultipleSections verifies that "## " headers split
// the document into at least one chunk per section.
func TestChunkMarkdown_Good_MultipleSections(t *testing.T) {
	text := `# Main Title
Introduction paragraph.
## Section One
Content for section one.
## Section Two
Content for section two.
`
	chunks := ChunkMarkdown(text, DefaultChunkConfig())
	assert.GreaterOrEqual(t, len(chunks), 2)
}
// TestChunkMarkdown_Good_LargeSection verifies that a section exceeding
// cfg.Size is split into multiple non-empty chunks, each tagged with the
// section title.
func TestChunkMarkdown_Good_LargeSection(t *testing.T) {
	// Create a section larger than chunk size
	text := `## Large Section
` + repeatString("This is a test paragraph with some content. ", 50)
	cfg := ChunkConfig{Size: 200, Overlap: 20}
	chunks := ChunkMarkdown(text, cfg)
	assert.Greater(t, len(chunks), 1)
	for _, chunk := range chunks {
		assert.NotEmpty(t, chunk.Text)
		assert.Equal(t, "Large Section", chunk.Section)
	}
}
// TestChunkMarkdown_Good_ExtractsTitle verifies that the "#"-stripped,
// trimmed header line is recorded as the chunk's Section.
func TestChunkMarkdown_Good_ExtractsTitle(t *testing.T) {
	text := `## My Section Title
Some content here.
`
	chunks := ChunkMarkdown(text, DefaultChunkConfig())
	assert.Len(t, chunks, 1)
	assert.Equal(t, "My Section Title", chunks[0].Section)
}
// TestCategory_Good_UIComponent is a table test for path → category mapping.
// Note "plans/architecture.md" expects "task": the task/plan rule is checked
// before the architecture rule, so rule order is part of the contract.
func TestCategory_Good_UIComponent(t *testing.T) {
	tests := []struct {
		path     string
		expected string
	}{
		{"docs/flux/button.md", "ui-component"},
		{"ui/components/modal.md", "ui-component"},
		{"brand/vi-personality.md", "brand"},
		{"mascot/expressions.md", "brand"},
		{"product-brief.md", "product-brief"},
		{"tasks/2024-01-15-feature.md", "task"},
		{"plans/architecture.md", "task"},
		{"architecture/migration.md", "architecture"},
		{"docs/api.md", "documentation"},
	}
	for _, tc := range tests {
		t.Run(tc.path, func(t *testing.T) {
			assert.Equal(t, tc.expected, Category(tc.path))
		})
	}
}
// TestChunkID_Good_Deterministic verifies that identical inputs always
// produce the same ID.
func TestChunkID_Good_Deterministic(t *testing.T) {
	id1 := ChunkID("test.md", 0, "hello world")
	id2 := ChunkID("test.md", 0, "hello world")
	assert.Equal(t, id1, id2)
}

// TestChunkID_Good_DifferentForDifferentInputs verifies that the chunk
// index and the path each contribute to the ID.
func TestChunkID_Good_DifferentForDifferentInputs(t *testing.T) {
	id1 := ChunkID("test.md", 0, "hello world")
	id2 := ChunkID("test.md", 1, "hello world")
	id3 := ChunkID("other.md", 0, "hello world")
	assert.NotEqual(t, id1, id2)
	assert.NotEqual(t, id1, id3)
}
// TestShouldProcess_Good_MarkdownFiles confirms extension-based filtering:
// markdown and plain-text files are processed, everything else is skipped.
func TestShouldProcess_Good_MarkdownFiles(t *testing.T) {
	accepted := []string{"doc.md", "doc.markdown", "doc.txt"}
	rejected := []string{"doc.go", "doc.py", "doc"}
	for _, p := range accepted {
		assert.True(t, ShouldProcess(p))
	}
	for _, p := range rejected {
		assert.False(t, ShouldProcess(p))
	}
}
// repeatString returns s concatenated n times; n <= 0 yields "".
// Uses a pre-sized byte buffer: the original `result += s` loop copied the
// accumulated string on every iteration (quadratic in n).
func repeatString(s string, n int) string {
	if n <= 0 {
		return ""
	}
	b := make([]byte, 0, n*len(s))
	for i := 0; i < n; i++ {
		b = append(b, s...)
	}
	return string(b)
}

20
go.mod Normal file
View file

@ -0,0 +1,20 @@
module forge.lthn.ai/core/go-rag
go 1.25.5
require (
forge.lthn.ai/core/go v0.0.0
github.com/ollama/ollama v0.16.1
github.com/qdrant/go-client v1.16.2
github.com/stretchr/testify v1.11.1
)
require (
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect
google.golang.org/grpc v1.78.0 // indirect
google.golang.org/protobuf v1.36.11 // indirect
)
replace forge.lthn.ai/core/go => ../core

89
helpers.go Normal file
View file

@ -0,0 +1,89 @@
package rag
import (
"context"
"fmt"
)
// QueryDocs queries the RAG database with default Qdrant and Ollama clients,
// returning up to topK results from collectionName. The Qdrant connection is
// closed before returning. A non-positive topK keeps the default limit from
// DefaultQueryConfig instead of wrapping into a huge uint64.
func QueryDocs(ctx context.Context, question, collectionName string, topK int) ([]QueryResult, error) {
	qdrantClient, err := NewQdrantClient(DefaultQdrantConfig())
	if err != nil {
		return nil, err
	}
	defer func() { _ = qdrantClient.Close() }()
	ollamaClient, err := NewOllamaClient(DefaultOllamaConfig())
	if err != nil {
		return nil, err
	}
	cfg := DefaultQueryConfig()
	cfg.Collection = collectionName
	// Guard the int→uint64 conversion: a negative topK would otherwise
	// become an enormous limit sent to Qdrant.
	if topK > 0 {
		cfg.Limit = uint64(topK)
	}
	return Query(ctx, qdrantClient, ollamaClient, question, cfg)
}
// QueryDocsContext runs QueryDocs and renders the hits in the
// <retrieved_context> format used for LLM prompt injection.
func QueryDocsContext(ctx context.Context, question, collectionName string, topK int) (string, error) {
	hits, err := QueryDocs(ctx, question, collectionName, topK)
	if err != nil {
		return "", err
	}
	return FormatResultsContext(hits), nil
}
// IngestDirectory ingests every supported document under directory into
// collectionName using default Qdrant and Ollama clients. When
// recreateCollection is true, an existing collection is dropped and rebuilt.
// Both backends are health-checked before any work starts.
func IngestDirectory(ctx context.Context, directory, collectionName string, recreateCollection bool) error {
	vectorDB, err := NewQdrantClient(DefaultQdrantConfig())
	if err != nil {
		return err
	}
	defer func() { _ = vectorDB.Close() }()

	if err = vectorDB.HealthCheck(ctx); err != nil {
		return fmt.Errorf("qdrant health check failed: %w", err)
	}

	embedder, err := NewOllamaClient(DefaultOllamaConfig())
	if err != nil {
		return err
	}
	if err = embedder.VerifyModel(ctx); err != nil {
		return err
	}

	cfg := DefaultIngestConfig()
	cfg.Directory = directory
	cfg.Collection = collectionName
	cfg.Recreate = recreateCollection

	_, err = Ingest(ctx, vectorDB, embedder, cfg, nil)
	return err
}
// IngestSingleFile ingests one file into collectionName using default
// clients, returning the number of chunks stored. Both backends are
// health-checked before the file is read.
func IngestSingleFile(ctx context.Context, filePath, collectionName string) (int, error) {
	vectorDB, err := NewQdrantClient(DefaultQdrantConfig())
	if err != nil {
		return 0, err
	}
	defer func() { _ = vectorDB.Close() }()

	if err = vectorDB.HealthCheck(ctx); err != nil {
		return 0, fmt.Errorf("qdrant health check failed: %w", err)
	}

	embedder, err := NewOllamaClient(DefaultOllamaConfig())
	if err != nil {
		return 0, err
	}
	if err = embedder.VerifyModel(ctx); err != nil {
		return 0, err
	}

	return IngestFile(ctx, vectorDB, embedder, collectionName, filePath, DefaultChunkConfig())
}

216
ingest.go Normal file
View file

@ -0,0 +1,216 @@
package rag
import (
"context"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"forge.lthn.ai/core/go/pkg/log"
)
// IngestConfig holds ingestion configuration.
type IngestConfig struct {
	Directory  string      // Root directory to walk for documents
	Collection string      // Target Qdrant collection name
	Recreate   bool        // Drop and recreate the collection before ingesting
	Verbose    bool        // Print per-chunk embedding errors to stdout
	BatchSize  int         // Points per Qdrant upsert batch (<=0 falls back to 100)
	Chunk      ChunkConfig // Chunking parameters
}

// DefaultIngestConfig returns default ingestion configuration.
func DefaultIngestConfig() IngestConfig {
	return IngestConfig{
		Collection: "hostuk-docs",
		BatchSize:  100,
		Chunk:      DefaultChunkConfig(),
	}
}

// IngestStats holds statistics from ingestion.
type IngestStats struct {
	Files  int // Files successfully read and chunked
	Chunks int // Chunks successfully embedded
	Errors int // Read/embed failures that were skipped
}

// IngestProgress is called during ingestion to report progress.
// It receives the file just processed, the cumulative chunk count so far,
// and the total number of files discovered.
type IngestProgress func(file string, chunks int, total int)
// Ingest processes a directory of documents and stores them in Qdrant.
//
// Pipeline: resolve and validate cfg.Directory → ensure the collection
// exists (recreating it when cfg.Recreate is set) → walk the tree for files
// passing ShouldProcess → chunk, embed, and collect points → upsert to
// Qdrant in cfg.BatchSize batches. Per-file and per-chunk failures increment
// stats.Errors and are skipped rather than aborting the run; a failed batch
// upsert returns the partial stats alongside the error.
func Ingest(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, cfg IngestConfig, progress IngestProgress) (*IngestStats, error) {
	stats := &IngestStats{}
	// Validate batch size to prevent infinite loop
	if cfg.BatchSize <= 0 {
		cfg.BatchSize = 100 // Safe default
	}
	// Resolve directory
	absDir, err := filepath.Abs(cfg.Directory)
	if err != nil {
		return nil, log.E("rag.Ingest", "error resolving directory", err)
	}
	info, err := os.Stat(absDir)
	if err != nil {
		return nil, log.E("rag.Ingest", "error accessing directory", err)
	}
	if !info.IsDir() {
		return nil, log.E("rag.Ingest", fmt.Sprintf("not a directory: %s", absDir), nil)
	}
	// Check/create collection
	exists, err := qdrant.CollectionExists(ctx, cfg.Collection)
	if err != nil {
		return nil, log.E("rag.Ingest", "error checking collection", err)
	}
	if cfg.Recreate && exists {
		if err := qdrant.DeleteCollection(ctx, cfg.Collection); err != nil {
			return nil, log.E("rag.Ingest", "error deleting collection", err)
		}
		exists = false
	}
	if !exists {
		// Vector dimension must match the embedding model's output size.
		vectorDim := ollama.EmbedDimension()
		if err := qdrant.CreateCollection(ctx, cfg.Collection, vectorDim); err != nil {
			return nil, log.E("rag.Ingest", "error creating collection", err)
		}
	}
	// Find markdown files
	var files []string
	err = filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if !d.IsDir() && ShouldProcess(path) {
			files = append(files, path)
		}
		return nil
	})
	if err != nil {
		return nil, log.E("rag.Ingest", "error walking directory", err)
	}
	if len(files) == 0 {
		return nil, log.E("rag.Ingest", fmt.Sprintf("no markdown files found in %s", absDir), nil)
	}
	// Process files
	var points []Point
	for _, filePath := range files {
		// Paths are stored relative to the ingest root in the payload.
		relPath, err := filepath.Rel(absDir, filePath)
		if err != nil {
			stats.Errors++
			continue
		}
		content, err := os.ReadFile(filePath)
		if err != nil {
			stats.Errors++
			continue
		}
		if len(strings.TrimSpace(string(content))) == 0 {
			continue
		}
		// Chunk the content
		category := Category(relPath)
		chunks := ChunkMarkdown(string(content), cfg.Chunk)
		for _, chunk := range chunks {
			// Generate embedding
			embedding, err := ollama.Embed(ctx, chunk.Text)
			if err != nil {
				stats.Errors++
				if cfg.Verbose {
					fmt.Printf("  Error embedding %s chunk %d: %v\n", relPath, chunk.Index, err)
				}
				continue
			}
			// Create point
			points = append(points, Point{
				ID:     ChunkID(relPath, chunk.Index, chunk.Text),
				Vector: embedding,
				Payload: map[string]any{
					"text":        chunk.Text,
					"source":      relPath,
					"section":     chunk.Section,
					"category":    category,
					"chunk_index": chunk.Index,
				},
			})
			stats.Chunks++
		}
		stats.Files++
		if progress != nil {
			progress(relPath, stats.Chunks, len(files))
		}
	}
	// Batch upsert to Qdrant
	if len(points) > 0 {
		for i := 0; i < len(points); i += cfg.BatchSize {
			end := i + cfg.BatchSize
			if end > len(points) {
				end = len(points)
			}
			batch := points[i:end]
			if err := qdrant.UpsertPoints(ctx, cfg.Collection, batch); err != nil {
				return stats, log.E("rag.Ingest", fmt.Sprintf("error upserting batch %d", i/cfg.BatchSize+1), err)
			}
		}
	}
	return stats, nil
}
// IngestFile chunks, embeds, and upserts a single file into collection,
// returning the number of points written. An effectively empty file is a
// no-op. Unlike Ingest, any embedding failure aborts the whole file.
func IngestFile(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, collection string, filePath string, chunkCfg ChunkConfig) (int, error) {
	raw, err := os.ReadFile(filePath)
	if err != nil {
		return 0, log.E("rag.IngestFile", "error reading file", err)
	}
	text := string(raw)
	if strings.TrimSpace(text) == "" {
		return 0, nil
	}

	category := Category(filePath)
	chunks := ChunkMarkdown(text, chunkCfg)
	points := make([]Point, 0, len(chunks))
	for _, c := range chunks {
		vector, err := ollama.Embed(ctx, c.Text)
		if err != nil {
			return 0, log.E("rag.IngestFile", fmt.Sprintf("error embedding chunk %d", c.Index), err)
		}
		points = append(points, Point{
			ID:     ChunkID(filePath, c.Index, c.Text),
			Vector: vector,
			Payload: map[string]any{
				"text":        c.Text,
				"source":      filePath,
				"section":     c.Section,
				"category":    category,
				"chunk_index": c.Index,
			},
		})
	}

	if err := qdrant.UpsertPoints(ctx, collection, points); err != nil {
		return 0, log.E("rag.IngestFile", "error upserting points", err)
	}
	return len(points), nil
}

120
ollama.go Normal file
View file

@ -0,0 +1,120 @@
package rag
import (
"context"
"fmt"
"net/http"
"net/url"
"time"
"forge.lthn.ai/core/go/pkg/log"
"github.com/ollama/ollama/api"
)
// OllamaConfig holds Ollama connection configuration.
type OllamaConfig struct {
	Host  string // Server hostname, no scheme
	Port  int    // HTTP port (Ollama default is 11434)
	Model string // Embedding model name, e.g. "nomic-embed-text"
}

// DefaultOllamaConfig returns default Ollama configuration.
// Host defaults to localhost for local development.
func DefaultOllamaConfig() OllamaConfig {
	return OllamaConfig{
		Host:  "localhost",
		Port:  11434,
		Model: "nomic-embed-text",
	}
}

// OllamaClient wraps the Ollama API client for embeddings.
type OllamaClient struct {
	client *api.Client
	config OllamaConfig
}
// NewOllamaClient creates an Ollama client speaking HTTP to cfg.Host:cfg.Port
// with a 30-second request timeout. The constructor never contacts the
// server; call VerifyModel to confirm the model is actually available.
func NewOllamaClient(cfg OllamaConfig) (*OllamaClient, error) {
	endpoint := &url.URL{
		Scheme: "http",
		Host:   fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
	}
	httpClient := &http.Client{Timeout: 30 * time.Second}
	return &OllamaClient{
		client: api.NewClient(endpoint, httpClient),
		config: cfg,
	}, nil
}
// EmbedDimension returns the embedding vector dimension for the configured
// model. Unrecognised models fall back to 768, the nomic-embed-text size.
func (o *OllamaClient) EmbedDimension() uint64 {
	if o.config.Model == "mxbai-embed-large" {
		return 1024
	}
	if o.config.Model == "all-minilm" {
		return 384
	}
	// "nomic-embed-text" and any unknown model.
	return 768
}
// Embed generates an embedding vector for text using the configured model.
// Transport failures and empty responses are both reported as errors.
func (o *OllamaClient) Embed(ctx context.Context, text string) ([]float32, error) {
	resp, err := o.client.Embed(ctx, &api.EmbedRequest{
		Model: o.config.Model,
		Input: text,
	})
	if err != nil {
		return nil, log.E("rag.Ollama.Embed", "failed to generate embedding", err)
	}
	if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 {
		return nil, log.E("rag.Ollama.Embed", "empty embedding response", nil)
	}
	// Qdrant stores float32 vectors; narrow each component.
	src := resp.Embeddings[0]
	vec := make([]float32, len(src))
	for i := range src {
		vec[i] = float32(src[i])
	}
	return vec, nil
}
// EmbedBatch embeds each text sequentially, returning vectors in input
// order. The first failure aborts the batch.
func (o *OllamaClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) {
	vectors := make([][]float32, len(texts))
	for i := range texts {
		v, err := o.Embed(ctx, texts[i])
		if err != nil {
			return nil, log.E("rag.Ollama.EmbedBatch", fmt.Sprintf("failed to embed text %d", i), err)
		}
		vectors[i] = v
	}
	return vectors, nil
}
// VerifyModel checks if the embedding model is available by issuing a
// throwaway embedding request; the error message tells the operator which
// model to pull.
func (o *OllamaClient) VerifyModel(ctx context.Context) error {
	_, err := o.Embed(ctx, "test")
	if err != nil {
		return log.E("rag.Ollama.VerifyModel", fmt.Sprintf("model %s not available (run: ollama pull %s)", o.config.Model, o.config.Model), err)
	}
	return nil
}

// Model returns the configured embedding model name.
func (o *OllamaClient) Model() string {
	return o.config.Model
}

225
qdrant.go Normal file
View file

@ -0,0 +1,225 @@
// Package rag provides RAG (Retrieval Augmented Generation) functionality
// for storing and querying documentation in Qdrant vector database.
package rag
import (
"context"
"fmt"
"forge.lthn.ai/core/go/pkg/log"
"github.com/qdrant/go-client/qdrant"
)
// QdrantConfig holds Qdrant connection configuration.
type QdrantConfig struct {
	Host   string // Qdrant hostname
	Port   int    // gRPC port (6334), not the REST port (6333)
	APIKey string // Optional API key; empty for unauthenticated local servers
	UseTLS bool   // Whether to connect over TLS
}

// DefaultQdrantConfig returns default Qdrant configuration.
// Host defaults to localhost for local development.
func DefaultQdrantConfig() QdrantConfig {
	return QdrantConfig{
		Host:   "localhost",
		Port:   6334, // gRPC port
		UseTLS: false,
	}
}

// QdrantClient wraps the Qdrant Go client with convenience methods.
type QdrantClient struct {
	client *qdrant.Client
	config QdrantConfig
}
// NewQdrantClient connects to Qdrant over gRPC at cfg.Host:cfg.Port.
// Callers must Close the returned client when finished with it.
func NewQdrantClient(cfg QdrantConfig) (*QdrantClient, error) {
	c, err := qdrant.NewClient(&qdrant.Config{
		Host:   cfg.Host,
		Port:   cfg.Port,
		APIKey: cfg.APIKey,
		UseTLS: cfg.UseTLS,
	})
	if err != nil {
		target := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port)
		return nil, log.E("rag.Qdrant", fmt.Sprintf("failed to connect to Qdrant at %s", target), err)
	}
	return &QdrantClient{client: c, config: cfg}, nil
}
// Close closes the Qdrant client connection.
func (q *QdrantClient) Close() error {
	return q.client.Close()
}

// HealthCheck verifies the connection to Qdrant by issuing a health probe.
func (q *QdrantClient) HealthCheck(ctx context.Context) error {
	_, err := q.client.HealthCheck(ctx)
	return err
}

// ListCollections returns all collection names. The result is a defensive
// copy, so callers may mutate it freely.
func (q *QdrantClient) ListCollections(ctx context.Context) ([]string, error) {
	resp, err := q.client.ListCollections(ctx)
	if err != nil {
		return nil, err
	}
	names := make([]string, len(resp))
	copy(names, resp)
	return names, nil
}

// CollectionExists checks if a collection exists.
func (q *QdrantClient) CollectionExists(ctx context.Context, name string) (bool, error) {
	return q.client.CollectionExists(ctx, name)
}
// CreateCollection creates a new collection with cosine distance and a
// single unnamed vector of the given size. vectorSize must match the
// embedding model's output dimension (see OllamaClient.EmbedDimension).
func (q *QdrantClient) CreateCollection(ctx context.Context, name string, vectorSize uint64) error {
	return q.client.CreateCollection(ctx, &qdrant.CreateCollection{
		CollectionName: name,
		VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
			Size:     vectorSize,
			Distance: qdrant.Distance_Cosine,
		}),
	})
}

// DeleteCollection deletes a collection and all of its points.
func (q *QdrantClient) DeleteCollection(ctx context.Context, name string) error {
	return q.client.DeleteCollection(ctx, name)
}

// CollectionInfo returns information about a collection.
func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*qdrant.CollectionInfo, error) {
	return q.client.GetCollectionInfo(ctx, name)
}
// Point represents a vector point with payload.
type Point struct {
	ID      string
	Vector  []float32
	Payload map[string]any
}

// UpsertPoints inserts or updates points in a collection. A nil or empty
// batch is a no-op.
func (q *QdrantClient) UpsertPoints(ctx context.Context, collection string, points []Point) error {
	if len(points) == 0 {
		return nil
	}
	batch := make([]*qdrant.PointStruct, 0, len(points))
	for _, p := range points {
		batch = append(batch, &qdrant.PointStruct{
			Id:      qdrant.NewID(p.ID),
			Vectors: qdrant.NewVectors(p.Vector...),
			Payload: qdrant.NewValueMap(p.Payload),
		})
	}
	_, err := q.client.Upsert(ctx, &qdrant.UpsertPoints{
		CollectionName: collection,
		Points:         batch,
	})
	return err
}
// SearchResult represents a search result with score.
type SearchResult struct {
	ID      string         // Point ID rendered as a string (numeric or UUID)
	Score   float32        // Similarity score from Qdrant
	Payload map[string]any // Payload converted to plain Go values
}

// Search performs a vector similarity search against a collection.
// filter, when non-empty, is applied as exact-match "must" conditions on
// payload fields (all keys must match). Results are returned with payloads
// converted from protobuf values to native Go types.
func (q *QdrantClient) Search(ctx context.Context, collection string, vector []float32, limit uint64, filter map[string]string) ([]SearchResult, error) {
	query := &qdrant.QueryPoints{
		CollectionName: collection,
		Query:          qdrant.NewQuery(vector...),
		Limit:          qdrant.PtrOf(limit),
		WithPayload:    qdrant.NewWithPayload(true),
	}
	// Add filter if provided
	if len(filter) > 0 {
		conditions := make([]*qdrant.Condition, 0, len(filter))
		for k, v := range filter {
			conditions = append(conditions, qdrant.NewMatch(k, v))
		}
		query.Filter = &qdrant.Filter{
			Must: conditions,
		}
	}
	resp, err := q.client.Query(ctx, query)
	if err != nil {
		return nil, err
	}
	results := make([]SearchResult, len(resp))
	for i, p := range resp {
		// Convert protobuf payload values into plain Go values.
		payload := make(map[string]any)
		for k, v := range p.Payload {
			payload[k] = valueToGo(v)
		}
		results[i] = SearchResult{
			ID:      pointIDToString(p.Id),
			Score:   p.Score,
			Payload: payload,
		}
	}
	return results, nil
}
// pointIDToString renders a Qdrant point ID (numeric or UUID variant) as a
// string; nil IDs and unknown variants yield "".
func pointIDToString(id *qdrant.PointId) string {
	if id == nil {
		return ""
	}
	if n, ok := id.PointIdOptions.(*qdrant.PointId_Num); ok {
		return fmt.Sprintf("%d", n.Num)
	}
	if u, ok := id.PointIdOptions.(*qdrant.PointId_Uuid); ok {
		return u.Uuid
	}
	return ""
}
// valueToGo converts a Qdrant protobuf value to a plain Go value,
// recursing into lists and structs. Nil values and unknown kinds map to nil.
func valueToGo(v *qdrant.Value) any {
	if v == nil {
		return nil
	}
	switch val := v.Kind.(type) {
	case *qdrant.Value_StringValue:
		return val.StringValue
	case *qdrant.Value_IntegerValue:
		return val.IntegerValue
	case *qdrant.Value_DoubleValue:
		return val.DoubleValue
	case *qdrant.Value_BoolValue:
		return val.BoolValue
	case *qdrant.Value_ListValue:
		// Recursively convert each list element.
		list := make([]any, len(val.ListValue.Values))
		for i, item := range val.ListValue.Values {
			list[i] = valueToGo(item)
		}
		return list
	case *qdrant.Value_StructValue:
		// Recursively convert each struct field.
		m := make(map[string]any)
		for k, item := range val.StructValue.Fields {
			m[k] = valueToGo(item)
		}
		return m
	default:
		return nil
	}
}

163
query.go Normal file
View file

@ -0,0 +1,163 @@
package rag
import (
	"context"
	"encoding/json"
	"fmt"
	"html"
	"strings"

	"forge.lthn.ai/core/go/pkg/log"
)
// QueryConfig holds query configuration.
type QueryConfig struct {
	Collection string  // Qdrant collection to search
	Limit      uint64  // Maximum number of hits requested from Qdrant
	Threshold  float32 // Minimum similarity score (0-1); lower-scoring hits are dropped
	Category   string  // Filter by category; empty means no filter
}

// DefaultQueryConfig returns default query configuration.
func DefaultQueryConfig() QueryConfig {
	return QueryConfig{
		Collection: "hostuk-docs",
		Limit:      5,
		Threshold:  0.5,
	}
}

// QueryResult represents a query result with metadata.
type QueryResult struct {
	Text       string  // Chunk text
	Source     string  // Source document path
	Section    string  // Section title the chunk came from
	Category   string  // Document category (see Category)
	ChunkIndex int     // Chunk position within its document
	Score      float32 // Similarity score
}
// Query searches for similar documents in Qdrant.
//
// The query string is embedded via Ollama, searched against cfg.Collection
// (optionally filtered by cfg.Category), and hits below cfg.Threshold are
// discarded. Payload fields are extracted defensively: missing or
// wrongly-typed fields are left at their zero value rather than erroring.
func Query(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, query string, cfg QueryConfig) ([]QueryResult, error) {
	// Generate embedding for query
	embedding, err := ollama.Embed(ctx, query)
	if err != nil {
		return nil, log.E("rag.Query", "error generating query embedding", err)
	}
	// Build filter
	var filter map[string]string
	if cfg.Category != "" {
		filter = map[string]string{"category": cfg.Category}
	}
	// Search Qdrant
	results, err := qdrant.Search(ctx, cfg.Collection, embedding, cfg.Limit, filter)
	if err != nil {
		return nil, log.E("rag.Query", "error searching", err)
	}
	// Convert and filter by threshold
	var queryResults []QueryResult
	for _, r := range results {
		if r.Score < cfg.Threshold {
			continue
		}
		qr := QueryResult{
			Score: r.Score,
		}
		// Extract payload fields
		if text, ok := r.Payload["text"].(string); ok {
			qr.Text = text
		}
		if source, ok := r.Payload["source"].(string); ok {
			qr.Source = source
		}
		if section, ok := r.Payload["section"].(string); ok {
			qr.Section = section
		}
		if category, ok := r.Payload["category"].(string); ok {
			qr.Category = category
		}
		// Handle chunk_index from various types (JSON unmarshaling produces float64)
		switch idx := r.Payload["chunk_index"].(type) {
		case int64:
			qr.ChunkIndex = int(idx)
		case float64:
			qr.ChunkIndex = int(idx)
		case int:
			qr.ChunkIndex = idx
		}
		queryResults = append(queryResults, qr)
	}
	return queryResults, nil
}
// FormatResultsText renders query results as human-readable plain text:
// one numbered banner per hit followed by its source, optional section,
// category, and the chunk text. Returns "No results found." for an empty set.
func FormatResultsText(results []QueryResult) string {
	if len(results) == 0 {
		return "No results found."
	}
	var b strings.Builder
	for i, r := range results {
		fmt.Fprintf(&b, "\n--- Result %d (score: %.2f) ---\n", i+1, r.Score)
		fmt.Fprintf(&b, "Source: %s\n", r.Source)
		if r.Section != "" {
			fmt.Fprintf(&b, "Section: %s\n", r.Section)
		}
		fmt.Fprintf(&b, "Category: %s\n\n", r.Category)
		b.WriteString(r.Text)
		b.WriteString("\n")
	}
	return b.String()
}
// FormatResultsContext renders results as a <retrieved_context> block for
// LLM prompt injection. All dynamic text is HTML-escaped so a retrieved
// document cannot break out of its <document> element. Empty input yields "".
func FormatResultsContext(results []QueryResult) string {
	if len(results) == 0 {
		return ""
	}
	var b strings.Builder
	b.WriteString("<retrieved_context>\n")
	for _, r := range results {
		src := html.EscapeString(r.Source)
		sec := html.EscapeString(r.Section)
		cat := html.EscapeString(r.Category)
		fmt.Fprintf(&b, "<document source=\"%s\" section=\"%s\" category=\"%s\">\n", src, sec, cat)
		b.WriteString(html.EscapeString(r.Text))
		b.WriteString("\n</document>\n\n")
	}
	b.WriteString("</retrieved_context>")
	return b.String()
}
// FormatResultsJSON formats query results as a JSON array string with
// source, section, category, score, and text fields per result.
// String fields are encoded with encoding/json rather than fmt's %q: %q
// emits Go escape sequences (\xNN, \a, \v) for control characters, which
// are not valid JSON and would corrupt the output for such input.
func FormatResultsJSON(results []QueryResult) string {
	if len(results) == 0 {
		return "[]"
	}
	// jsonString renders s as a JSON string literal.
	jsonString := func(s string) string {
		b, err := json.Marshal(s)
		if err != nil {
			// Marshalling a plain string cannot fail in practice.
			return `""`
		}
		return string(b)
	}
	var sb strings.Builder
	sb.WriteString("[\n")
	for i, r := range results {
		sb.WriteString("  {\n")
		fmt.Fprintf(&sb, "    \"source\": %s,\n", jsonString(r.Source))
		fmt.Fprintf(&sb, "    \"section\": %s,\n", jsonString(r.Section))
		fmt.Fprintf(&sb, "    \"category\": %s,\n", jsonString(r.Category))
		fmt.Fprintf(&sb, "    \"score\": %.4f,\n", r.Score)
		fmt.Fprintf(&sb, "    \"text\": %s\n", jsonString(r.Text))
		if i < len(results)-1 {
			sb.WriteString("  },\n")
		} else {
			sb.WriteString("  }\n")
		}
	}
	sb.WriteString("]")
	return sb.String()
}