commit bf047e44942d0a3832e611f7b03b542ce2b5a587
Author: Snider
Date:   Thu Feb 19 18:29:59 2026 +0000

    feat: extract go-rag from go-ai as standalone RAG package

    Vector search with Qdrant + Ollama embeddings, document chunking.
    Zero internal go-ai dependencies. Adds CLAUDE.md/TODO.md/FINDINGS.md.

    Co-Authored-By: Virgil

diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..9b1c862
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,52 @@
+# CLAUDE.md
+
+## What This Is
+
+Retrieval-Augmented Generation with vector search. Module: `forge.lthn.ai/core/go-rag`
+
+Provides document chunking, embedding via Ollama, vector storage/search via Qdrant, and formatted context retrieval for AI prompts.
+
+## Commands
+
+```bash
+go test ./...              # Run all tests
+go test -v -run TestChunk  # Single test
+```
+
+## Architecture
+
+| File | Purpose |
+|------|---------|
+| `chunk.go` | Document chunking — splits markdown/text into semantic chunks |
+| `ingest.go` | Ingestion pipeline — reads files, chunks, embeds, stores |
+| `query.go` | Query interface — search vectors, format results as text/JSON/XML |
+| `qdrant.go` | Qdrant vector DB client — create collections, upsert, search |
+| `ollama.go` | Ollama embedding client — generate embeddings for chunks |
+| `helpers.go` | Shared utilities |
+
+## Dependencies
+
+- `forge.lthn.ai/core/go` — Logging (pkg/log)
+- `github.com/ollama/ollama` — Embedding API client
+- `github.com/qdrant/go-client` — Vector DB gRPC client
+- `github.com/stretchr/testify` — Tests
+
+## Key API
+
+```go
+// Ingest documents (helpers construct default localhost Qdrant/Ollama clients)
+count, err := rag.IngestSingleFile(ctx, "/path/to/doc.md", "my-docs")
+err = rag.IngestDirectory(ctx, "/path/to/docs", "my-docs", false)
+
+// Query for relevant context
+results, err := rag.QueryDocs(ctx, "search query", "my-docs", 5)
+context := rag.FormatResultsContext(results) // or FormatResultsText, FormatResultsJSON
+```
+
+## Coding Standards
+
+- UK English
+- Tests: testify assert/require
+- Conventional commits
+- Co-Author: `Co-Authored-By: Virgil`
+- Licence: EUPL-1.2
diff --git a/FINDINGS.md b/FINDINGS.md
new file mode 100644
index 0000000..7a8e292
--- /dev/null
+++ b/FINDINGS.md
@@ -0,0 +1,21 @@
+# FINDINGS.md — go-rag Research & Discovery
+
+## 2026-02-19: Split from go-ai (Virgil)
+
+### Origin
+
+Extracted from `forge.lthn.ai/core/go-ai/rag/`. Zero internal go-ai dependencies.
+
+### What Was Extracted
+
+- 7 Go files (~1,017 LOC excluding tests)
+- 1 test file (chunk_test.go)
+
+### Key Finding: Minimal Test Coverage
+
+Only chunk.go has tests. The Qdrant and Ollama clients are untested — they depend on external services (Qdrant server, Ollama API), which makes unit testing harder. Consider mock interfaces.
+
+### Consumers
+
+- `go-ai/ai/rag.go` wraps this as a `QueryRAGForTask()` facade
+- `go-ai/mcp/tools_rag.go` exposes RAG as MCP tools
diff --git a/TODO.md b/TODO.md
new file mode 100644
index 0000000..bfeedca
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,22 @@
+# TODO.md — go-rag Task Queue
+
+## Phase 1: Post-Split
+
+- [ ] **Verify tests pass standalone** — Run `go test ./...`. Confirm chunk_test.go passes.
+- [ ] **Add missing tests** — Only chunk.go has tests. Need: query, ingest, qdrant client, ollama client.
+- [ ] **Benchmark chunking** — No perf baselines. Add BenchmarkChunk for various document sizes.
+
+## Phase 2: Enhancements
+
+- [ ] **Embedding model selection** — Currently hardcoded to Ollama. Add backend interface for alternative embedding providers.
+- [ ] **Collection management** — Add list/delete collection operations for Qdrant.
+- [ ] **Hybrid search** — Combine vector similarity with keyword matching for better recall. +- [ ] **Chunk overlap** — Current chunking may lose context at boundaries. Add configurable overlap. + +--- + +## Workflow + +1. Virgil in core/go writes tasks here after research +2. This repo's session picks up tasks in phase order +3. Mark `[x]` when done, note commit hash diff --git a/chunk.go b/chunk.go new file mode 100644 index 0000000..fbcc3c9 --- /dev/null +++ b/chunk.go @@ -0,0 +1,204 @@ +package rag + +import ( + "crypto/md5" + "fmt" + "path/filepath" + "slices" + "strings" +) + +// ChunkConfig holds chunking configuration. +type ChunkConfig struct { + Size int // Characters per chunk + Overlap int // Overlap between chunks +} + +// DefaultChunkConfig returns default chunking configuration. +func DefaultChunkConfig() ChunkConfig { + return ChunkConfig{ + Size: 500, + Overlap: 50, + } +} + +// Chunk represents a text chunk with metadata. +type Chunk struct { + Text string + Section string + Index int +} + +// ChunkMarkdown splits markdown text into chunks by sections and paragraphs. +// Preserves context with configurable overlap. +func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk { + if cfg.Size <= 0 { + cfg.Size = 500 + } + if cfg.Overlap < 0 || cfg.Overlap >= cfg.Size { + cfg.Overlap = 0 + } + + var chunks []Chunk + + // Split by ## headers + sections := splitBySections(text) + + chunkIndex := 0 + for _, section := range sections { + section = strings.TrimSpace(section) + if section == "" { + continue + } + + // Extract section title + lines := strings.SplitN(section, "\n", 2) + title := "" + if strings.HasPrefix(lines[0], "#") { + title = strings.TrimLeft(lines[0], "#") + title = strings.TrimSpace(title) + } + + // If section is small enough, yield as-is + if len(section) <= cfg.Size { + chunks = append(chunks, Chunk{ + Text: section, + Section: title, + Index: chunkIndex, + }) + chunkIndex++ + continue + } + + // Otherwise, chunk by paragraphs + paragraphs := splitByParagraphs(section) + currentChunk := "" + + for _, para := range paragraphs { + para = strings.TrimSpace(para) + if para == "" { + continue + } + + if len(currentChunk)+len(para)+2 <= cfg.Size { + if currentChunk != "" { + currentChunk += "\n\n" + para + } else { + currentChunk = para + } + } else { + if currentChunk != "" { + chunks = append(chunks, Chunk{ + Text: strings.TrimSpace(currentChunk), + Section: title, + Index: chunkIndex, + }) + chunkIndex++ + } + // Start new chunk with overlap from previous (rune-safe for UTF-8) + runes := []rune(currentChunk) + if cfg.Overlap > 0 && len(runes) > cfg.Overlap { + overlapText := string(runes[len(runes)-cfg.Overlap:]) + currentChunk = overlapText + "\n\n" + para + } else { + currentChunk = para + } + } + } + + // Don't forget the last chunk + if strings.TrimSpace(currentChunk) != "" { + chunks = append(chunks, Chunk{ + Text: strings.TrimSpace(currentChunk), + Section: title, + Index: chunkIndex, + }) + chunkIndex++ + } + } + + return chunks +} + +// splitBySections splits text by ## headers while preserving the header with its content. 
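+// Text before the first "## " header (including any "# " title) accumulates
+// into its own leading section, so nothing is dropped.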
+func splitBySections(text string) []string { + var sections []string + lines := strings.Split(text, "\n") + + var currentSection strings.Builder + for _, line := range lines { + // Check if this line is a ## header + if strings.HasPrefix(line, "## ") { + // Save previous section if exists + if currentSection.Len() > 0 { + sections = append(sections, currentSection.String()) + currentSection.Reset() + } + } + currentSection.WriteString(line) + currentSection.WriteString("\n") + } + + // Don't forget the last section + if currentSection.Len() > 0 { + sections = append(sections, currentSection.String()) + } + + return sections +} + +// splitByParagraphs splits text by double newlines. +func splitByParagraphs(text string) []string { + // Replace multiple newlines with a marker, then split + normalized := text + for strings.Contains(normalized, "\n\n\n") { + normalized = strings.ReplaceAll(normalized, "\n\n\n", "\n\n") + } + return strings.Split(normalized, "\n\n") +} + +// Category determines the document category from file path. +func Category(path string) string { + lower := strings.ToLower(path) + + switch { + case strings.Contains(lower, "flux") || strings.Contains(lower, "ui/component"): + return "ui-component" + case strings.Contains(lower, "brand") || strings.Contains(lower, "mascot"): + return "brand" + case strings.Contains(lower, "brief"): + return "product-brief" + case strings.Contains(lower, "help") || strings.Contains(lower, "draft"): + return "help-doc" + case strings.Contains(lower, "task") || strings.Contains(lower, "plan"): + return "task" + case strings.Contains(lower, "architecture") || strings.Contains(lower, "migration"): + return "architecture" + default: + return "documentation" + } +} + +// ChunkID generates a unique ID for a chunk. +func ChunkID(path string, index int, text string) string { + // Use first 100 runes of text for uniqueness (rune-safe for UTF-8) + runes := []rune(text) + if len(runes) > 100 { + runes = runes[:100] + } + textPart := string(runes) + data := fmt.Sprintf("%s:%d:%s", path, index, textPart) + hash := md5.Sum([]byte(data)) + return fmt.Sprintf("%x", hash) +} + +// FileExtensions returns the file extensions to process. +func FileExtensions() []string { + return []string{".md", ".markdown", ".txt"} +} + +// ShouldProcess checks if a file should be processed based on extension. +func ShouldProcess(path string) bool { + ext := strings.ToLower(filepath.Ext(path)) + return slices.Contains(FileExtensions(), ext) +} diff --git a/chunk_test.go b/chunk_test.go new file mode 100644 index 0000000..87fd5c0 --- /dev/null +++ b/chunk_test.go @@ -0,0 +1,120 @@ +package rag + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestChunkMarkdown_Good_SmallSection(t *testing.T) { + text := `# Title + +This is a small section that fits in one chunk. +` + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.Len(t, chunks, 1) + assert.Contains(t, chunks[0].Text, "small section") +} + +func TestChunkMarkdown_Good_MultipleSections(t *testing.T) { + text := `# Main Title + +Introduction paragraph. + +## Section One + +Content for section one. + +## Section Two + +Content for section two. +` + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.GreaterOrEqual(t, len(chunks), 2) +} + +func TestChunkMarkdown_Good_LargeSection(t *testing.T) { + // Create a section larger than chunk size + text := `## Large Section + +` + repeatString("This is a test paragraph with some content. 
", 50) + + cfg := ChunkConfig{Size: 200, Overlap: 20} + chunks := ChunkMarkdown(text, cfg) + + assert.Greater(t, len(chunks), 1) + for _, chunk := range chunks { + assert.NotEmpty(t, chunk.Text) + assert.Equal(t, "Large Section", chunk.Section) + } +} + +func TestChunkMarkdown_Good_ExtractsTitle(t *testing.T) { + text := `## My Section Title + +Some content here. +` + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.Len(t, chunks, 1) + assert.Equal(t, "My Section Title", chunks[0].Section) +} + +func TestCategory_Good_UIComponent(t *testing.T) { + tests := []struct { + path string + expected string + }{ + {"docs/flux/button.md", "ui-component"}, + {"ui/components/modal.md", "ui-component"}, + {"brand/vi-personality.md", "brand"}, + {"mascot/expressions.md", "brand"}, + {"product-brief.md", "product-brief"}, + {"tasks/2024-01-15-feature.md", "task"}, + {"plans/architecture.md", "task"}, + {"architecture/migration.md", "architecture"}, + {"docs/api.md", "documentation"}, + } + + for _, tc := range tests { + t.Run(tc.path, func(t *testing.T) { + assert.Equal(t, tc.expected, Category(tc.path)) + }) + } +} + +func TestChunkID_Good_Deterministic(t *testing.T) { + id1 := ChunkID("test.md", 0, "hello world") + id2 := ChunkID("test.md", 0, "hello world") + + assert.Equal(t, id1, id2) +} + +func TestChunkID_Good_DifferentForDifferentInputs(t *testing.T) { + id1 := ChunkID("test.md", 0, "hello world") + id2 := ChunkID("test.md", 1, "hello world") + id3 := ChunkID("other.md", 0, "hello world") + + assert.NotEqual(t, id1, id2) + assert.NotEqual(t, id1, id3) +} + +func TestShouldProcess_Good_MarkdownFiles(t *testing.T) { + assert.True(t, ShouldProcess("doc.md")) + assert.True(t, ShouldProcess("doc.markdown")) + assert.True(t, ShouldProcess("doc.txt")) + assert.False(t, ShouldProcess("doc.go")) + assert.False(t, ShouldProcess("doc.py")) + assert.False(t, ShouldProcess("doc")) +} + +// Helper function +func repeatString(s string, n int) string { + result := "" + for i := 0; i < n; i++ { + result += s + } + return result +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7306181 --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module forge.lthn.ai/core/go-rag + +go 1.25.5 + +require ( + forge.lthn.ai/core/go v0.0.0 + github.com/ollama/ollama v0.16.1 + github.com/qdrant/go-client v1.16.2 + github.com/stretchr/testify v1.11.1 +) + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251111163417-95abcf5c77ba // indirect + google.golang.org/grpc v1.78.0 // indirect + google.golang.org/protobuf v1.36.11 // indirect +) + +replace forge.lthn.ai/core/go => ../core diff --git a/helpers.go b/helpers.go new file mode 100644 index 0000000..8d6b81f --- /dev/null +++ b/helpers.go @@ -0,0 +1,89 @@ +package rag + +import ( + "context" + "fmt" +) + +// QueryDocs queries the RAG database with default clients. 
+func QueryDocs(ctx context.Context, question, collectionName string, topK int) ([]QueryResult, error) { + qdrantClient, err := NewQdrantClient(DefaultQdrantConfig()) + if err != nil { + return nil, err + } + defer func() { _ = qdrantClient.Close() }() + + ollamaClient, err := NewOllamaClient(DefaultOllamaConfig()) + if err != nil { + return nil, err + } + + cfg := DefaultQueryConfig() + cfg.Collection = collectionName + cfg.Limit = uint64(topK) + + return Query(ctx, qdrantClient, ollamaClient, question, cfg) +} + +// QueryDocsContext queries the RAG database and returns context-formatted results. +func QueryDocsContext(ctx context.Context, question, collectionName string, topK int) (string, error) { + results, err := QueryDocs(ctx, question, collectionName, topK) + if err != nil { + return "", err + } + return FormatResultsContext(results), nil +} + +// IngestDirectory ingests all documents in a directory with default clients. +func IngestDirectory(ctx context.Context, directory, collectionName string, recreateCollection bool) error { + qdrantClient, err := NewQdrantClient(DefaultQdrantConfig()) + if err != nil { + return err + } + defer func() { _ = qdrantClient.Close() }() + + if err := qdrantClient.HealthCheck(ctx); err != nil { + return fmt.Errorf("qdrant health check failed: %w", err) + } + + ollamaClient, err := NewOllamaClient(DefaultOllamaConfig()) + if err != nil { + return err + } + + if err := ollamaClient.VerifyModel(ctx); err != nil { + return err + } + + cfg := DefaultIngestConfig() + cfg.Directory = directory + cfg.Collection = collectionName + cfg.Recreate = recreateCollection + + _, err = Ingest(ctx, qdrantClient, ollamaClient, cfg, nil) + return err +} + +// IngestSingleFile ingests a single file with default clients. +func IngestSingleFile(ctx context.Context, filePath, collectionName string) (int, error) { + qdrantClient, err := NewQdrantClient(DefaultQdrantConfig()) + if err != nil { + return 0, err + } + defer func() { _ = qdrantClient.Close() }() + + if err := qdrantClient.HealthCheck(ctx); err != nil { + return 0, fmt.Errorf("qdrant health check failed: %w", err) + } + + ollamaClient, err := NewOllamaClient(DefaultOllamaConfig()) + if err != nil { + return 0, err + } + + if err := ollamaClient.VerifyModel(ctx); err != nil { + return 0, err + } + + return IngestFile(ctx, qdrantClient, ollamaClient, collectionName, filePath, DefaultChunkConfig()) +} diff --git a/ingest.go b/ingest.go new file mode 100644 index 0000000..cd4ff06 --- /dev/null +++ b/ingest.go @@ -0,0 +1,216 @@ +package rag + +import ( + "context" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + "forge.lthn.ai/core/go/pkg/log" +) + +// IngestConfig holds ingestion configuration. +type IngestConfig struct { + Directory string + Collection string + Recreate bool + Verbose bool + BatchSize int + Chunk ChunkConfig +} + +// DefaultIngestConfig returns default ingestion configuration. +func DefaultIngestConfig() IngestConfig { + return IngestConfig{ + Collection: "hostuk-docs", + BatchSize: 100, + Chunk: DefaultChunkConfig(), + } +} + +// IngestStats holds statistics from ingestion. +type IngestStats struct { + Files int + Chunks int + Errors int +} + +// IngestProgress is called during ingestion to report progress. +type IngestProgress func(file string, chunks int, total int) + +// Ingest processes a directory of documents and stores them in Qdrant. 
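+// The pipeline walks cfg.Directory for supported extensions (ShouldProcess),
+// chunks each file with ChunkMarkdown, embeds every chunk via Ollama, and
+// upserts the resulting points to Qdrant in batches of cfg.BatchSize.
+// Files that fail to read or embed are counted in stats.Errors rather than
+// aborting the run.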
+func Ingest(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, cfg IngestConfig, progress IngestProgress) (*IngestStats, error) { + stats := &IngestStats{} + + // Validate batch size to prevent infinite loop + if cfg.BatchSize <= 0 { + cfg.BatchSize = 100 // Safe default + } + + // Resolve directory + absDir, err := filepath.Abs(cfg.Directory) + if err != nil { + return nil, log.E("rag.Ingest", "error resolving directory", err) + } + + info, err := os.Stat(absDir) + if err != nil { + return nil, log.E("rag.Ingest", "error accessing directory", err) + } + if !info.IsDir() { + return nil, log.E("rag.Ingest", fmt.Sprintf("not a directory: %s", absDir), nil) + } + + // Check/create collection + exists, err := qdrant.CollectionExists(ctx, cfg.Collection) + if err != nil { + return nil, log.E("rag.Ingest", "error checking collection", err) + } + + if cfg.Recreate && exists { + if err := qdrant.DeleteCollection(ctx, cfg.Collection); err != nil { + return nil, log.E("rag.Ingest", "error deleting collection", err) + } + exists = false + } + + if !exists { + vectorDim := ollama.EmbedDimension() + if err := qdrant.CreateCollection(ctx, cfg.Collection, vectorDim); err != nil { + return nil, log.E("rag.Ingest", "error creating collection", err) + } + } + + // Find markdown files + var files []string + err = filepath.WalkDir(absDir, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() && ShouldProcess(path) { + files = append(files, path) + } + return nil + }) + if err != nil { + return nil, log.E("rag.Ingest", "error walking directory", err) + } + + if len(files) == 0 { + return nil, log.E("rag.Ingest", fmt.Sprintf("no markdown files found in %s", absDir), nil) + } + + // Process files + var points []Point + for _, filePath := range files { + relPath, err := filepath.Rel(absDir, filePath) + if err != nil { + stats.Errors++ + continue + } + + content, err := os.ReadFile(filePath) + if err != nil { + stats.Errors++ + continue + } + + if len(strings.TrimSpace(string(content))) == 0 { + continue + } + + // Chunk the content + category := Category(relPath) + chunks := ChunkMarkdown(string(content), cfg.Chunk) + + for _, chunk := range chunks { + // Generate embedding + embedding, err := ollama.Embed(ctx, chunk.Text) + if err != nil { + stats.Errors++ + if cfg.Verbose { + fmt.Printf(" Error embedding %s chunk %d: %v\n", relPath, chunk.Index, err) + } + continue + } + + // Create point + points = append(points, Point{ + ID: ChunkID(relPath, chunk.Index, chunk.Text), + Vector: embedding, + Payload: map[string]any{ + "text": chunk.Text, + "source": relPath, + "section": chunk.Section, + "category": category, + "chunk_index": chunk.Index, + }, + }) + stats.Chunks++ + } + + stats.Files++ + if progress != nil { + progress(relPath, stats.Chunks, len(files)) + } + } + + // Batch upsert to Qdrant + if len(points) > 0 { + for i := 0; i < len(points); i += cfg.BatchSize { + end := i + cfg.BatchSize + if end > len(points) { + end = len(points) + } + batch := points[i:end] + if err := qdrant.UpsertPoints(ctx, cfg.Collection, batch); err != nil { + return stats, log.E("rag.Ingest", fmt.Sprintf("error upserting batch %d", i/cfg.BatchSize+1), err) + } + } + } + + return stats, nil +} + +// IngestFile processes a single file and stores it in Qdrant. 
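+// Unlike Ingest, it assumes the collection already exists, stops at the first
+// embedding error, and upserts all chunks in a single call. It returns the
+// number of points stored.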
+func IngestFile(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, collection string, filePath string, chunkCfg ChunkConfig) (int, error) { + content, err := os.ReadFile(filePath) + if err != nil { + return 0, log.E("rag.IngestFile", "error reading file", err) + } + + if len(strings.TrimSpace(string(content))) == 0 { + return 0, nil + } + + category := Category(filePath) + chunks := ChunkMarkdown(string(content), chunkCfg) + + var points []Point + for _, chunk := range chunks { + embedding, err := ollama.Embed(ctx, chunk.Text) + if err != nil { + return 0, log.E("rag.IngestFile", fmt.Sprintf("error embedding chunk %d", chunk.Index), err) + } + + points = append(points, Point{ + ID: ChunkID(filePath, chunk.Index, chunk.Text), + Vector: embedding, + Payload: map[string]any{ + "text": chunk.Text, + "source": filePath, + "section": chunk.Section, + "category": category, + "chunk_index": chunk.Index, + }, + }) + } + + if err := qdrant.UpsertPoints(ctx, collection, points); err != nil { + return 0, log.E("rag.IngestFile", "error upserting points", err) + } + + return len(points), nil +} diff --git a/ollama.go b/ollama.go new file mode 100644 index 0000000..891c830 --- /dev/null +++ b/ollama.go @@ -0,0 +1,120 @@ +package rag + +import ( + "context" + "fmt" + "net/http" + "net/url" + "time" + + "forge.lthn.ai/core/go/pkg/log" + "github.com/ollama/ollama/api" +) + +// OllamaConfig holds Ollama connection configuration. +type OllamaConfig struct { + Host string + Port int + Model string +} + +// DefaultOllamaConfig returns default Ollama configuration. +// Host defaults to localhost for local development. +func DefaultOllamaConfig() OllamaConfig { + return OllamaConfig{ + Host: "localhost", + Port: 11434, + Model: "nomic-embed-text", + } +} + +// OllamaClient wraps the Ollama API client for embeddings. +type OllamaClient struct { + client *api.Client + config OllamaConfig +} + +// NewOllamaClient creates a new Ollama client. +func NewOllamaClient(cfg OllamaConfig) (*OllamaClient, error) { + baseURL := &url.URL{ + Scheme: "http", + Host: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + } + + client := api.NewClient(baseURL, &http.Client{ + Timeout: 30 * time.Second, + }) + + return &OllamaClient{ + client: client, + config: cfg, + }, nil +} + +// EmbedDimension returns the embedding dimension for the configured model. +// nomic-embed-text uses 768 dimensions. +func (o *OllamaClient) EmbedDimension() uint64 { + switch o.config.Model { + case "nomic-embed-text": + return 768 + case "mxbai-embed-large": + return 1024 + case "all-minilm": + return 384 + default: + return 768 // Default to nomic-embed-text dimension + } +} + +// Embed generates embeddings for the given text. +func (o *OllamaClient) Embed(ctx context.Context, text string) ([]float32, error) { + req := &api.EmbedRequest{ + Model: o.config.Model, + Input: text, + } + + resp, err := o.client.Embed(ctx, req) + if err != nil { + return nil, log.E("rag.Ollama.Embed", "failed to generate embedding", err) + } + + if len(resp.Embeddings) == 0 || len(resp.Embeddings[0]) == 0 { + return nil, log.E("rag.Ollama.Embed", "empty embedding response", nil) + } + + // Convert float64 to float32 for Qdrant + embedding := resp.Embeddings[0] + result := make([]float32, len(embedding)) + for i, v := range embedding { + result[i] = float32(v) + } + + return result, nil +} + +// EmbedBatch generates embeddings for multiple texts. 
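+// Embeddings are generated sequentially, one request per text; the first
+// failure aborts the whole batch.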
+func (o *OllamaClient) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { + results := make([][]float32, len(texts)) + for i, text := range texts { + embedding, err := o.Embed(ctx, text) + if err != nil { + return nil, log.E("rag.Ollama.EmbedBatch", fmt.Sprintf("failed to embed text %d", i), err) + } + results[i] = embedding + } + return results, nil +} + +// VerifyModel checks if the embedding model is available. +func (o *OllamaClient) VerifyModel(ctx context.Context) error { + _, err := o.Embed(ctx, "test") + if err != nil { + return log.E("rag.Ollama.VerifyModel", fmt.Sprintf("model %s not available (run: ollama pull %s)", o.config.Model, o.config.Model), err) + } + return nil +} + +// Model returns the configured embedding model name. +func (o *OllamaClient) Model() string { + return o.config.Model +} diff --git a/qdrant.go b/qdrant.go new file mode 100644 index 0000000..14a540e --- /dev/null +++ b/qdrant.go @@ -0,0 +1,225 @@ +// Package rag provides RAG (Retrieval Augmented Generation) functionality +// for storing and querying documentation in Qdrant vector database. +package rag + +import ( + "context" + "fmt" + + "forge.lthn.ai/core/go/pkg/log" + "github.com/qdrant/go-client/qdrant" +) + +// QdrantConfig holds Qdrant connection configuration. +type QdrantConfig struct { + Host string + Port int + APIKey string + UseTLS bool +} + +// DefaultQdrantConfig returns default Qdrant configuration. +// Host defaults to localhost for local development. +func DefaultQdrantConfig() QdrantConfig { + return QdrantConfig{ + Host: "localhost", + Port: 6334, // gRPC port + UseTLS: false, + } +} + +// QdrantClient wraps the Qdrant Go client with convenience methods. +type QdrantClient struct { + client *qdrant.Client + config QdrantConfig +} + +// NewQdrantClient creates a new Qdrant client. +func NewQdrantClient(cfg QdrantConfig) (*QdrantClient, error) { + addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + + client, err := qdrant.NewClient(&qdrant.Config{ + Host: cfg.Host, + Port: cfg.Port, + APIKey: cfg.APIKey, + UseTLS: cfg.UseTLS, + }) + if err != nil { + return nil, log.E("rag.Qdrant", fmt.Sprintf("failed to connect to Qdrant at %s", addr), err) + } + + return &QdrantClient{ + client: client, + config: cfg, + }, nil +} + +// Close closes the Qdrant client connection. +func (q *QdrantClient) Close() error { + return q.client.Close() +} + +// HealthCheck verifies the connection to Qdrant. +func (q *QdrantClient) HealthCheck(ctx context.Context) error { + _, err := q.client.HealthCheck(ctx) + return err +} + +// ListCollections returns all collection names. +func (q *QdrantClient) ListCollections(ctx context.Context) ([]string, error) { + resp, err := q.client.ListCollections(ctx) + if err != nil { + return nil, err + } + names := make([]string, len(resp)) + copy(names, resp) + return names, nil +} + +// CollectionExists checks if a collection exists. +func (q *QdrantClient) CollectionExists(ctx context.Context, name string) (bool, error) { + return q.client.CollectionExists(ctx, name) +} + +// CreateCollection creates a new collection with cosine distance. +func (q *QdrantClient) CreateCollection(ctx context.Context, name string, vectorSize uint64) error { + return q.client.CreateCollection(ctx, &qdrant.CreateCollection{ + CollectionName: name, + VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ + Size: vectorSize, + Distance: qdrant.Distance_Cosine, + }), + }) +} + +// DeleteCollection deletes a collection. 
+func (q *QdrantClient) DeleteCollection(ctx context.Context, name string) error { + return q.client.DeleteCollection(ctx, name) +} + +// CollectionInfo returns information about a collection. +func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*qdrant.CollectionInfo, error) { + return q.client.GetCollectionInfo(ctx, name) +} + +// Point represents a vector point with payload. +type Point struct { + ID string + Vector []float32 + Payload map[string]any +} + +// UpsertPoints inserts or updates points in a collection. +func (q *QdrantClient) UpsertPoints(ctx context.Context, collection string, points []Point) error { + if len(points) == 0 { + return nil + } + + qdrantPoints := make([]*qdrant.PointStruct, len(points)) + for i, p := range points { + qdrantPoints[i] = &qdrant.PointStruct{ + Id: qdrant.NewID(p.ID), + Vectors: qdrant.NewVectors(p.Vector...), + Payload: qdrant.NewValueMap(p.Payload), + } + } + + _, err := q.client.Upsert(ctx, &qdrant.UpsertPoints{ + CollectionName: collection, + Points: qdrantPoints, + }) + return err +} + +// SearchResult represents a search result with score. +type SearchResult struct { + ID string + Score float32 + Payload map[string]any +} + +// Search performs a vector similarity search. +func (q *QdrantClient) Search(ctx context.Context, collection string, vector []float32, limit uint64, filter map[string]string) ([]SearchResult, error) { + query := &qdrant.QueryPoints{ + CollectionName: collection, + Query: qdrant.NewQuery(vector...), + Limit: qdrant.PtrOf(limit), + WithPayload: qdrant.NewWithPayload(true), + } + + // Add filter if provided + if len(filter) > 0 { + conditions := make([]*qdrant.Condition, 0, len(filter)) + for k, v := range filter { + conditions = append(conditions, qdrant.NewMatch(k, v)) + } + query.Filter = &qdrant.Filter{ + Must: conditions, + } + } + + resp, err := q.client.Query(ctx, query) + if err != nil { + return nil, err + } + + results := make([]SearchResult, len(resp)) + for i, p := range resp { + payload := make(map[string]any) + for k, v := range p.Payload { + payload[k] = valueToGo(v) + } + results[i] = SearchResult{ + ID: pointIDToString(p.Id), + Score: p.Score, + Payload: payload, + } + } + return results, nil +} + +// pointIDToString converts a Qdrant point ID to string. +func pointIDToString(id *qdrant.PointId) string { + if id == nil { + return "" + } + switch v := id.PointIdOptions.(type) { + case *qdrant.PointId_Num: + return fmt.Sprintf("%d", v.Num) + case *qdrant.PointId_Uuid: + return v.Uuid + default: + return "" + } +} + +// valueToGo converts a Qdrant value to a Go value. 
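+// Lists and structs are converted recursively; nil or unrecognised kinds
+// yield nil.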
+func valueToGo(v *qdrant.Value) any { + if v == nil { + return nil + } + switch val := v.Kind.(type) { + case *qdrant.Value_StringValue: + return val.StringValue + case *qdrant.Value_IntegerValue: + return val.IntegerValue + case *qdrant.Value_DoubleValue: + return val.DoubleValue + case *qdrant.Value_BoolValue: + return val.BoolValue + case *qdrant.Value_ListValue: + list := make([]any, len(val.ListValue.Values)) + for i, item := range val.ListValue.Values { + list[i] = valueToGo(item) + } + return list + case *qdrant.Value_StructValue: + m := make(map[string]any) + for k, item := range val.StructValue.Fields { + m[k] = valueToGo(item) + } + return m + default: + return nil + } +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..2605868 --- /dev/null +++ b/query.go @@ -0,0 +1,163 @@ +package rag + +import ( + "context" + "fmt" + "html" + "strings" + + "forge.lthn.ai/core/go/pkg/log" +) + +// QueryConfig holds query configuration. +type QueryConfig struct { + Collection string + Limit uint64 + Threshold float32 // Minimum similarity score (0-1) + Category string // Filter by category +} + +// DefaultQueryConfig returns default query configuration. +func DefaultQueryConfig() QueryConfig { + return QueryConfig{ + Collection: "hostuk-docs", + Limit: 5, + Threshold: 0.5, + } +} + +// QueryResult represents a query result with metadata. +type QueryResult struct { + Text string + Source string + Section string + Category string + ChunkIndex int + Score float32 +} + +// Query searches for similar documents in Qdrant. +func Query(ctx context.Context, qdrant *QdrantClient, ollama *OllamaClient, query string, cfg QueryConfig) ([]QueryResult, error) { + // Generate embedding for query + embedding, err := ollama.Embed(ctx, query) + if err != nil { + return nil, log.E("rag.Query", "error generating query embedding", err) + } + + // Build filter + var filter map[string]string + if cfg.Category != "" { + filter = map[string]string{"category": cfg.Category} + } + + // Search Qdrant + results, err := qdrant.Search(ctx, cfg.Collection, embedding, cfg.Limit, filter) + if err != nil { + return nil, log.E("rag.Query", "error searching", err) + } + + // Convert and filter by threshold + var queryResults []QueryResult + for _, r := range results { + if r.Score < cfg.Threshold { + continue + } + + qr := QueryResult{ + Score: r.Score, + } + + // Extract payload fields + if text, ok := r.Payload["text"].(string); ok { + qr.Text = text + } + if source, ok := r.Payload["source"].(string); ok { + qr.Source = source + } + if section, ok := r.Payload["section"].(string); ok { + qr.Section = section + } + if category, ok := r.Payload["category"].(string); ok { + qr.Category = category + } + // Handle chunk_index from various types (JSON unmarshaling produces float64) + switch idx := r.Payload["chunk_index"].(type) { + case int64: + qr.ChunkIndex = int(idx) + case float64: + qr.ChunkIndex = int(idx) + case int: + qr.ChunkIndex = idx + } + + queryResults = append(queryResults, qr) + } + + return queryResults, nil +} + +// FormatResultsText formats query results as plain text. +func FormatResultsText(results []QueryResult) string { + if len(results) == 0 { + return "No results found." 
+	}
+
+	var sb strings.Builder
+	for i, r := range results {
+		sb.WriteString(fmt.Sprintf("\n--- Result %d (score: %.2f) ---\n", i+1, r.Score))
+		sb.WriteString(fmt.Sprintf("Source: %s\n", r.Source))
+		if r.Section != "" {
+			sb.WriteString(fmt.Sprintf("Section: %s\n", r.Section))
+		}
+		sb.WriteString(fmt.Sprintf("Category: %s\n\n", r.Category))
+		sb.WriteString(r.Text)
+		sb.WriteString("\n")
+	}
+	return sb.String()
+}
+
+// FormatResultsContext formats query results for LLM context injection,
+// wrapping each result in simple XML-style tags.
+func FormatResultsContext(results []QueryResult) string {
+	if len(results) == 0 {
+		return ""
+	}
+
+	var sb strings.Builder
+	sb.WriteString("<context>\n")
+	for _, r := range results {
+		// Escape XML special characters to prevent malformed output
+		fmt.Fprintf(&sb, "<document source=\"%s\" section=\"%s\" category=\"%s\">\n",
+			html.EscapeString(r.Source),
+			html.EscapeString(r.Section),
+			html.EscapeString(r.Category))
+		sb.WriteString(html.EscapeString(r.Text))
+		sb.WriteString("\n</document>\n\n")
+	}
+	sb.WriteString("</context>")
+	return sb.String()
+}
+
+// FormatResultsJSON formats query results as JSON-like output.
+func FormatResultsJSON(results []QueryResult) string {
+	if len(results) == 0 {
+		return "[]"
+	}
+
+	var sb strings.Builder
+	sb.WriteString("[\n")
+	for i, r := range results {
+		sb.WriteString("  {\n")
+		sb.WriteString(fmt.Sprintf("    \"source\": %q,\n", r.Source))
+		sb.WriteString(fmt.Sprintf("    \"section\": %q,\n", r.Section))
+		sb.WriteString(fmt.Sprintf("    \"category\": %q,\n", r.Category))
+		sb.WriteString(fmt.Sprintf("    \"score\": %.4f,\n", r.Score))
+		sb.WriteString(fmt.Sprintf("    \"text\": %q\n", r.Text))
+		if i < len(results)-1 {
+			sb.WriteString("  },\n")
+		} else {
+			sb.WriteString("  }\n")
+		}
+	}
+	sb.WriteString("]")
+	return sb.String()
+}
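A minimal sketch of end-to-end usage of the helpers in this commit (not part of the diff itself): the directory, collection name and question below are placeholders, and it assumes a local Qdrant (gRPC on :6334) and Ollama (:11434) with the `nomic-embed-text` model pulled, per the defaults in `qdrant.go` and `ollama.go`.

```go
package main

import (
	"context"
	"fmt"
	"log"

	rag "forge.lthn.ai/core/go-rag"
)

func main() {
	ctx := context.Background()

	// Ingest every .md/.markdown/.txt file under ./docs into the
	// "example-docs" collection (false = keep an existing collection).
	if err := rag.IngestDirectory(ctx, "./docs", "example-docs", false); err != nil {
		log.Fatal(err)
	}

	// Fetch the top 5 chunks for a question and format them for prompt injection.
	results, err := rag.QueryDocs(ctx, "how does chunk overlap work?", "example-docs", 5)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(rag.FormatResultsContext(results))
}
```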