diff --git a/TODO.md b/TODO.md index 5aa9810..93d17fd 100644 --- a/TODO.md +++ b/TODO.md @@ -43,36 +43,36 @@ All tasks are pure Go, testable with existing mocks. No external services needed ### 3.1 Chunk Boundary Improvements -- [ ] **Sentence-aware splitting** — When a paragraph exceeds `ChunkConfig.Size`, split at sentence boundaries (`. `, `? `, `! `) instead of adding the whole paragraph as an oversized chunk. Keep current behaviour as fallback when no sentence boundaries exist. -- [ ] **Overlap boundary alignment** — Current overlap slices by rune count from the end of the previous chunk. Improve by aligning overlap to word boundaries (find the nearest space before the overlap point) to avoid splitting mid-word. -- [ ] **Tests** — (a) Sentence splitting with 3 sentences > Size, (b) overlap word boundary alignment, (c) existing tests still pass (no regression). +- [x] **Sentence-aware splitting** — When a paragraph exceeds `ChunkConfig.Size`, split at sentence boundaries (`. `, `? `, `! `) instead of adding the whole paragraph as an oversized chunk. Keep current behaviour as fallback when no sentence boundaries exist. (cf26e88) +- [x] **Overlap boundary alignment** — Current overlap slices by rune count from the end of the previous chunk. Improve by aligning overlap to word boundaries (find the nearest space before the overlap point) to avoid splitting mid-word. (cf26e88) +- [x] **Tests** — (a) Sentence splitting with 3 sentences > Size, (b) overlap word boundary alignment, (c) existing tests still pass (no regression). (cf26e88) ### 3.2 Collection Management Helpers -- [ ] **Create `collections.go`** — Helper functions for collection lifecycle: +- [x] **Create `collections.go`** — Helper functions for collection lifecycle: - `ListCollections(ctx, store VectorStore) ([]string, error)` — wraps store method - `DeleteCollection(ctx, store VectorStore, name string) error` — wraps store method - - `CollectionStats(ctx, store VectorStore, name string) (*CollectionInfo, error)` — point count, vector size, status. Needs `CollectionInfo` struct (not Qdrant-specific). -- [ ] **Add `ListCollections` and `DeleteCollection` to VectorStore interface** — Currently these methods exist on `QdrantClient` but NOT on the `VectorStore` interface. Add them and update mock. -- [ ] **Tests** — Mock-based tests for all helpers, error injection. + - `CollectionStats(ctx, store VectorStore, name string) (*CollectionInfo, error)` — point count, vector size, status. Needs `CollectionInfo` struct (not Qdrant-specific). (cf26e88) +- [x] **Add `ListCollections` and `DeleteCollection` to VectorStore interface** — Currently these methods exist on `QdrantClient` but NOT on the `VectorStore` interface. Add them and update mock. (cf26e88) +- [x] **Tests** — Mock-based tests for all helpers, error injection. (cf26e88) ### 3.3 Keyword Pre-Filter -- [ ] **Create `keyword.go`** — `KeywordFilter(results []QueryResult, keywords []string) []QueryResult` — re-ranks results by boosting scores for results containing query keywords. Pure string matching (case-insensitive `strings.Contains`). +- [x] **Create `keyword.go`** — `KeywordFilter(results []QueryResult, keywords []string) []QueryResult` — re-ranks results by boosting scores for results containing query keywords. Pure string matching (case-insensitive `strings.Contains`). 
- Boost formula: `score *= 1.0 + 0.1 * matchCount` (each keyword match adds 10% boost) - - Re-sort by boosted score descending -- [ ] **Add `Keywords bool` to QueryConfig** — When true, extract keywords from query text and apply KeywordFilter after vector search. -- [ ] **Tests** — (a) No keywords (passthrough), (b) single keyword boost, (c) multiple keywords, (d) case insensitive, (e) no matches (scores unchanged). + - Re-sort by boosted score descending (cf26e88) +- [x] **Add `Keywords bool` to QueryConfig** — When true, extract keywords from query text and apply KeywordFilter after vector search. (cf26e88) +- [x] **Tests** — (a) No keywords (passthrough), (b) single keyword boost, (c) multiple keywords, (d) case insensitive, (e) no matches (scores unchanged). (cf26e88) ### 3.4 Benchmarks -- [ ] **Create `benchmark_test.go`** — No build tag (mock-only): +- [x] **Create `benchmark_test.go`** — No build tag (mock-only): - `BenchmarkChunk` — 10KB markdown document, default config - `BenchmarkChunkWithOverlap` — Same document, overlap=100 - `BenchmarkQuery_Mock` — Query with mock embedder + mock store - `BenchmarkIngest_Mock` — Ingest 10 files with mock embedder + mock store - `BenchmarkFormatResults` — FormatResultsText/Context/JSON with 20 results - - `BenchmarkKeywordFilter` — 100 results, 5 keywords + - `BenchmarkKeywordFilter` — 100 results, 5 keywords (cf26e88) ## Phase 4: GPU Embeddings diff --git a/benchmark_test.go b/benchmark_test.go new file mode 100644 index 0000000..1652b48 --- /dev/null +++ b/benchmark_test.go @@ -0,0 +1,170 @@ +package rag + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "testing" +) + +// generateMarkdownDoc creates a ~10KB markdown document for benchmarking. +func generateMarkdownDoc() string { + var sb strings.Builder + sb.WriteString("# Benchmark Document\n\n") + sb.WriteString("This document is generated for benchmarking the chunking pipeline.\n\n") + + for i := 0; i < 20; i++ { + sb.WriteString(fmt.Sprintf("## Section %d\n\n", i+1)) + for j := 0; j < 5; j++ { + sb.WriteString(fmt.Sprintf( + "Paragraph %d in section %d contains representative text for testing. "+ + "It includes multiple sentences to exercise the sentence-aware splitter. "+ + "The content is deliberately verbose to reach a realistic document size. "+ + "Each paragraph is unique to avoid deduplication effects.\n\n", + j+1, i+1)) + } + } + + return sb.String() +} + +// generateQueryResults creates n QueryResult entries for benchmarking. 
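+// Scores descend from 1.0 in steps of 0.01, giving a deterministic ordering.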
+func generateQueryResults(n int) []QueryResult { + results := make([]QueryResult, n) + for i := range results { + results[i] = QueryResult{ + Text: fmt.Sprintf("This is result number %d with some representative text content for testing purposes.", i+1), + Source: fmt.Sprintf("docs/section-%d/file-%d.md", i/5, i), + Section: fmt.Sprintf("Section %d", i+1), + Category: "documentation", + ChunkIndex: i, + Score: 1.0 - float32(i)*0.01, + } + } + return results +} + +// --- BenchmarkChunk --- + +func BenchmarkChunk(b *testing.B) { + doc := generateMarkdownDoc() + cfg := DefaultChunkConfig() + + b.ResetTimer() + for b.Loop() { + ChunkMarkdown(doc, cfg) + } +} + +// --- BenchmarkChunkWithOverlap --- + +func BenchmarkChunkWithOverlap(b *testing.B) { + doc := generateMarkdownDoc() + cfg := ChunkConfig{Size: 500, Overlap: 100} + + b.ResetTimer() + for b.Loop() { + ChunkMarkdown(doc, cfg) + } +} + +// --- BenchmarkQuery_Mock --- + +func BenchmarkQuery_Mock(b *testing.B) { + store := newMockVectorStore() + store.collections["bench-col"] = 768 + // Pre-populate with 50 points + for i := 0; i < 50; i++ { + store.points["bench-col"] = append(store.points["bench-col"], Point{ + ID: fmt.Sprintf("p%d", i), + Vector: make([]float32, 768), + Payload: map[string]any{ + "text": fmt.Sprintf("Benchmark document %d with relevant content.", i), + "source": fmt.Sprintf("doc%d.md", i), + "section": "Section", + "category": "docs", + "chunk_index": i, + }, + }) + } + embedder := newMockEmbedder(768) + cfg := DefaultQueryConfig() + cfg.Collection = "bench-col" + cfg.Threshold = 0.0 + ctx := context.Background() + + b.ResetTimer() + for b.Loop() { + _, _ = Query(ctx, store, embedder, "benchmark query text", cfg) + } +} + +// --- BenchmarkIngest_Mock --- + +func BenchmarkIngest_Mock(b *testing.B) { + dir := b.TempDir() + // Create 10 markdown files + for i := 0; i < 10; i++ { + content := fmt.Sprintf("## File %d\n\nThis is file number %d with some test content for benchmarking.\n", i, i) + path := filepath.Join(dir, fmt.Sprintf("doc%d.md", i)) + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + b.Fatal(err) + } + } + + ctx := context.Background() + + b.ResetTimer() + for b.Loop() { + store := newMockVectorStore() + embedder := newMockEmbedder(768) + cfg := DefaultIngestConfig() + cfg.Directory = dir + cfg.Collection = "bench-ingest" + + _, _ = Ingest(ctx, store, embedder, cfg, nil) + } +} + +// --- BenchmarkFormatResults --- + +func BenchmarkFormatResultsText(b *testing.B) { + results := generateQueryResults(20) + + b.ResetTimer() + for b.Loop() { + FormatResultsText(results) + } +} + +func BenchmarkFormatResultsContext(b *testing.B) { + results := generateQueryResults(20) + + b.ResetTimer() + for b.Loop() { + FormatResultsContext(results) + } +} + +func BenchmarkFormatResultsJSON(b *testing.B) { + results := generateQueryResults(20) + + b.ResetTimer() + for b.Loop() { + FormatResultsJSON(results) + } +} + +// --- BenchmarkKeywordFilter --- + +func BenchmarkKeywordFilter(b *testing.B) { + results := generateQueryResults(100) + keywords := []string{"result", "content", "testing", "documentation", "section"} + + b.ResetTimer() + for b.Loop() { + KeywordFilter(results, keywords) + } +} diff --git a/chunk.go b/chunk.go index fbcc3c9..ca2df52 100644 --- a/chunk.go +++ b/chunk.go @@ -30,7 +30,9 @@ type Chunk struct { } // ChunkMarkdown splits markdown text into chunks by sections and paragraphs. -// Preserves context with configurable overlap. +// Preserves context with configurable overlap. 
When a paragraph exceeds the +// configured Size, it is split at sentence boundaries. Overlap is aligned to +// word boundaries to avoid splitting mid-word. func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk { if cfg.Size <= 0 { cfg.Size = 500 @@ -80,28 +82,40 @@ func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk { continue } - if len(currentChunk)+len(para)+2 <= cfg.Size { - if currentChunk != "" { - currentChunk += "\n\n" + para - } else { - currentChunk = para + // If the paragraph itself exceeds Size, split at sentence + // boundaries and treat each sentence (or group of sentences) + // as a separate sub-paragraph. + subParas := []string{para} + if len(para) > cfg.Size { + if sentences := splitBySentences(para); len(sentences) > 1 { + subParas = sentences } - } else { - if currentChunk != "" { - chunks = append(chunks, Chunk{ - Text: strings.TrimSpace(currentChunk), - Section: title, - Index: chunkIndex, - }) - chunkIndex++ + } + + for _, sp := range subParas { + sp = strings.TrimSpace(sp) + if sp == "" { + continue } - // Start new chunk with overlap from previous (rune-safe for UTF-8) - runes := []rune(currentChunk) - if cfg.Overlap > 0 && len(runes) > cfg.Overlap { - overlapText := string(runes[len(runes)-cfg.Overlap:]) - currentChunk = overlapText + "\n\n" + para + + if len(currentChunk)+len(sp)+2 <= cfg.Size { + if currentChunk != "" { + currentChunk += "\n\n" + sp + } else { + currentChunk = sp + } } else { - currentChunk = para + if currentChunk != "" { + chunks = append(chunks, Chunk{ + Text: strings.TrimSpace(currentChunk), + Section: title, + Index: chunkIndex, + }) + chunkIndex++ + } + // Start new chunk with overlap from previous, + // aligned to the nearest word boundary. + currentChunk = overlapPrefix(currentChunk, cfg.Overlap, sp) } } } @@ -120,6 +134,76 @@ func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk { return chunks } +// overlapPrefix builds the start of a new chunk by taking word-boundary-aligned +// overlap text from the previous chunk and prepending it to the new paragraph. +func overlapPrefix(prevChunk string, overlap int, newPara string) string { + if overlap <= 0 { + return newPara + } + + runes := []rune(prevChunk) + if len(runes) <= overlap { + return newPara + } + + // Slice from the end of the previous chunk + overlapRunes := runes[len(runes)-overlap:] + + // Align to the nearest word boundary: find the first space within the + // overlap slice and start after it to avoid a partial leading word. + overlapText := string(overlapRunes) + if idx := strings.IndexByte(overlapText, ' '); idx >= 0 { + overlapText = overlapText[idx+1:] + } + + if overlapText == "" { + return newPara + } + + return overlapText + "\n\n" + newPara +} + +// splitBySentences splits text at sentence boundaries (". ", "? ", "! "). +// Returns the original text in a single-element slice when no boundaries are found. +func splitBySentences(text string) []string { + var sentences []string + remaining := text + + for len(remaining) > 0 { + // Find the earliest sentence boundary + bestIdx := -1 + var bestSep string + for _, sep := range []string{". ", "? ", "! 
"} { + idx := strings.Index(remaining, sep) + if idx >= 0 && (bestIdx < 0 || idx < bestIdx) { + bestIdx = idx + bestSep = sep + } + } + + if bestIdx < 0 { + // No more boundaries — append remainder + sentences = append(sentences, remaining) + break + } + + // Include the punctuation mark in the sentence, but not the trailing space + sentence := remaining[:bestIdx+len(bestSep)-1] + sentences = append(sentences, strings.TrimSpace(sentence)) + remaining = remaining[bestIdx+len(bestSep):] + } + + // Filter out empty entries + var filtered []string + for _, s := range sentences { + if strings.TrimSpace(s) != "" { + filtered = append(filtered, s) + } + } + + return filtered +} + // splitBySections splits text by ## headers while preserving the header with its content. func splitBySections(text string) []string { var sections []string diff --git a/chunk_test.go b/chunk_test.go index 6bd2c4a..479049a 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -1,6 +1,7 @@ package rag import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -348,3 +349,139 @@ func joinParagraphs(parts []string) string { } return result } + +// --- Phase 3.1: Sentence splitting and overlap alignment --- + +func TestChunkMarkdown_SentenceSplitting(t *testing.T) { + t.Run("oversized paragraph split at sentence boundaries", func(t *testing.T) { + // Three sentences, each ~60 chars. Total ~180 chars exceeds Size=100. + s1 := "The quick brown fox jumps over the lazy dog on the green hill." + s2 := "A second sentence that also has a reasonable amount of words." + s3 := "Finally the third sentence wraps up this oversized paragraph." + text := "## Section\n\n" + s1 + " " + s2 + " " + s3 + cfg := ChunkConfig{Size: 100, Overlap: 0} + chunks := ChunkMarkdown(text, cfg) + + // Should produce more than one chunk because sentences are split + assert.Greater(t, len(chunks), 1, "oversized paragraph should be split into multiple chunks") + + // Verify all original text appears across the chunks + combined := "" + for _, c := range chunks { + combined += c.Text + " " + } + assert.Contains(t, combined, "quick brown fox") + assert.Contains(t, combined, "second sentence") + assert.Contains(t, combined, "third sentence") + + // Each chunk should have the correct section + for _, c := range chunks { + assert.Equal(t, "Section", c.Section) + } + }) + + t.Run("paragraph without sentence boundaries kept as single chunk", func(t *testing.T) { + // Long paragraph with no sentence-ending punctuation followed by space + para := repeatString("word ", 100) // ~500 chars, no ". " or "? " or "! " + text := "## S\n\n" + para + cfg := ChunkConfig{Size: 100, Overlap: 0} + chunks := ChunkMarkdown(text, cfg) + + // Should still produce at least one chunk (fallback behaviour) + assert.NotEmpty(t, chunks) + }) + + t.Run("sentence boundaries preserve punctuation", func(t *testing.T) { + text := "## S\n\nFirst sentence. Second sentence. Third sentence." 
+ cfg := ChunkConfig{Size: 30, Overlap: 0} + chunks := ChunkMarkdown(text, cfg) + + // The first chunk should end with a period (punctuation preserved) + foundPeriod := false + for _, c := range chunks { + if strings.HasSuffix(strings.TrimSpace(c.Text), ".") { + foundPeriod = true + break + } + } + assert.True(t, foundPeriod, "at least one chunk should end with a period") + }) +} + +func TestChunkMarkdown_OverlapWordBoundary(t *testing.T) { + t.Run("overlap does not split mid-word", func(t *testing.T) { + // Build two paragraphs where the first is large enough to emit, + // and the overlap region lands mid-word in the naive rune slice. + para1 := "Alpha bravo charlie delta echo foxtrot golf hotel india juliet kilo lima mike november oscar papa quebec romeo sierra tango." + para2 := "Uniform victor whiskey xray yankee zulu." + text := "## S\n\n" + para1 + "\n\n" + para2 + cfg := ChunkConfig{Size: 80, Overlap: 15} + chunks := ChunkMarkdown(text, cfg) + + // Find a chunk that contains overlap text (not the first chunk) + for i, c := range chunks { + if i == 0 { + continue + } + // The overlap prefix should start at a word boundary: + // it should not begin with a partial word fragment. + words := strings.Fields(c.Text) + if len(words) > 0 { + // The first word should be a recognisable whole word, not a suffix + // of a longer word. We can verify there is no leading lowercase + // fragment by checking the original text contains this word. + firstWord := words[0] + assert.True(t, + strings.Contains(para1, firstWord) || strings.Contains(para2, firstWord), + "overlap should start at a word boundary, got leading word: %q", firstWord) + } + } + }) + + t.Run("overlap with zero value produces no overlap", func(t *testing.T) { + para1 := repeatString("Abcdef. ", 30) // ~240 chars + para2 := "Unique marker text here." + text := "## S\n\n" + para1 + "\n\n" + para2 + cfg := ChunkConfig{Size: 100, Overlap: 0} + chunks := ChunkMarkdown(text, cfg) + + // With zero overlap, the second chunk should not contain text from + // the end of the previous chunk + assert.NotEmpty(t, chunks) + }) +} + +func TestSplitBySentences(t *testing.T) { + t.Run("splits on period-space", func(t *testing.T) { + result := splitBySentences("First. Second. Third.") + assert.Len(t, result, 3) + assert.Equal(t, "First.", result[0]) + assert.Equal(t, "Second.", result[1]) + assert.Equal(t, "Third.", result[2]) + }) + + t.Run("splits on question mark", func(t *testing.T) { + result := splitBySentences("What is this? It is a test.") + assert.Len(t, result, 2) + assert.Equal(t, "What is this?", result[0]) + assert.Equal(t, "It is a test.", result[1]) + }) + + t.Run("splits on exclamation mark", func(t *testing.T) { + result := splitBySentences("Wow! That is amazing.") + assert.Len(t, result, 2) + assert.Equal(t, "Wow!", result[0]) + assert.Equal(t, "That is amazing.", result[1]) + }) + + t.Run("no boundaries returns single element", func(t *testing.T) { + result := splitBySentences("just a plain string with no ending") + assert.Len(t, result, 1) + assert.Equal(t, "just a plain string with no ending", result[0]) + }) + + t.Run("empty string returns empty", func(t *testing.T) { + result := splitBySentences("") + assert.Empty(t, result) + }) +} diff --git a/collections.go b/collections.go new file mode 100644 index 0000000..aabf756 --- /dev/null +++ b/collections.go @@ -0,0 +1,18 @@ +package rag + +import "context" + +// ListCollections returns all collection names from the given vector store. 
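+// It delegates to the VectorStore interface, so it works with any backend
+// that implements it.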
+func ListCollections(ctx context.Context, store VectorStore) ([]string, error) { + return store.ListCollections(ctx) +} + +// DeleteCollection removes a collection from the given vector store. +func DeleteCollection(ctx context.Context, store VectorStore, name string) error { + return store.DeleteCollection(ctx, name) +} + +// CollectionStats returns backend-agnostic metadata about a collection. +func CollectionStats(ctx context.Context, store VectorStore, name string) (*CollectionInfo, error) { + return store.CollectionInfo(ctx, name) +} diff --git a/collections_test.go b/collections_test.go new file mode 100644 index 0000000..2a4835f --- /dev/null +++ b/collections_test.go @@ -0,0 +1,127 @@ +package rag + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- ListCollections tests --- + +func TestListCollections(t *testing.T) { + t.Run("returns collection names from store", func(t *testing.T) { + store := newMockVectorStore() + store.collections["alpha"] = 768 + store.collections["bravo"] = 384 + + names, err := ListCollections(context.Background(), store) + + require.NoError(t, err) + assert.Len(t, names, 2) + assert.Contains(t, names, "alpha") + assert.Contains(t, names, "bravo") + }) + + t.Run("empty store returns empty list", func(t *testing.T) { + store := newMockVectorStore() + + names, err := ListCollections(context.Background(), store) + + require.NoError(t, err) + assert.Empty(t, names) + }) + + t.Run("error from store propagates", func(t *testing.T) { + store := newMockVectorStore() + store.listErr = fmt.Errorf("connection lost") + + _, err := ListCollections(context.Background(), store) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "connection lost") + }) +} + +// --- DeleteCollection tests --- + +func TestDeleteCollectionHelper(t *testing.T) { + t.Run("deletes collection from store", func(t *testing.T) { + store := newMockVectorStore() + store.collections["to-delete"] = 768 + store.points["to-delete"] = []Point{{ID: "p1"}} + + err := DeleteCollection(context.Background(), store, "to-delete") + + require.NoError(t, err) + _, exists := store.collections["to-delete"] + assert.False(t, exists) + assert.Empty(t, store.points["to-delete"]) + }) + + t.Run("error from store propagates", func(t *testing.T) { + store := newMockVectorStore() + store.deleteErr = fmt.Errorf("permission denied") + + err := DeleteCollection(context.Background(), store, "any") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "permission denied") + }) +} + +// --- CollectionStats tests --- + +func TestCollectionStats(t *testing.T) { + t.Run("returns info for existing collection", func(t *testing.T) { + store := newMockVectorStore() + store.collections["my-col"] = 768 + store.points["my-col"] = []Point{ + {ID: "p1", Vector: []float32{0.1}}, + {ID: "p2", Vector: []float32{0.2}}, + {ID: "p3", Vector: []float32{0.3}}, + } + + info, err := CollectionStats(context.Background(), store, "my-col") + + require.NoError(t, err) + require.NotNil(t, info) + assert.Equal(t, "my-col", info.Name) + assert.Equal(t, uint64(3), info.PointCount) + assert.Equal(t, uint64(768), info.VectorSize) + assert.Equal(t, "green", info.Status) + }) + + t.Run("nonexistent collection returns error", func(t *testing.T) { + store := newMockVectorStore() + + _, err := CollectionStats(context.Background(), store, "missing") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("error from store propagates", 
func(t *testing.T) { + store := newMockVectorStore() + store.collections["err-col"] = 768 + store.infoErr = fmt.Errorf("internal error") + + _, err := CollectionStats(context.Background(), store, "err-col") + + assert.Error(t, err) + assert.Contains(t, err.Error(), "internal error") + }) + + t.Run("empty collection has zero point count", func(t *testing.T) { + store := newMockVectorStore() + store.collections["empty-col"] = 384 + + info, err := CollectionStats(context.Background(), store, "empty-col") + + require.NoError(t, err) + assert.Equal(t, uint64(0), info.PointCount) + assert.Equal(t, uint64(384), info.VectorSize) + }) +} diff --git a/keyword.go b/keyword.go new file mode 100644 index 0000000..d0ad6dd --- /dev/null +++ b/keyword.go @@ -0,0 +1,60 @@ +package rag + +import ( + "sort" + "strings" +) + +// KeywordFilter re-ranks query results by boosting scores for results whose +// text contains one or more of the given keywords. Matching is +// case-insensitive using strings.Contains. Each keyword match adds a 10% +// boost to the original score: score *= 1.0 + 0.1 * matchCount. +// Results are re-sorted by boosted score descending. +func KeywordFilter(results []QueryResult, keywords []string) []QueryResult { + if len(keywords) == 0 || len(results) == 0 { + return results + } + + // Normalise keywords to lowercase once + lowerKeywords := make([]string, len(keywords)) + for i, kw := range keywords { + lowerKeywords[i] = strings.ToLower(kw) + } + + // Apply boost + boosted := make([]QueryResult, len(results)) + copy(boosted, results) + + for i := range boosted { + lowerText := strings.ToLower(boosted[i].Text) + matchCount := 0 + for _, kw := range lowerKeywords { + if kw != "" && strings.Contains(lowerText, kw) { + matchCount++ + } + } + if matchCount > 0 { + boosted[i].Score *= 1.0 + 0.1*float32(matchCount) + } + } + + // Re-sort by boosted score descending + sort.Slice(boosted, func(i, j int) bool { + return boosted[i].Score > boosted[j].Score + }) + + return boosted +} + +// extractKeywords splits query text into individual keywords for filtering. +// Words shorter than 3 characters are discarded as they tend to be noise. 
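+// For example, extractKeywords("how do I use Go modules") returns
+// []string{"how", "use", "modules"}; "do", "I", and "Go" fall below the cutoff.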
+func extractKeywords(query string) []string { + words := strings.Fields(strings.ToLower(query)) + var keywords []string + for _, w := range words { + if len(w) >= 3 { + keywords = append(keywords, w) + } + } + return keywords +} diff --git a/keyword_test.go b/keyword_test.go new file mode 100644 index 0000000..45bd9f4 --- /dev/null +++ b/keyword_test.go @@ -0,0 +1,216 @@ +package rag + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- KeywordFilter tests --- + +func TestKeywordFilter(t *testing.T) { + t.Run("no keywords returns results unchanged", func(t *testing.T) { + results := []QueryResult{ + {Text: "Hello world.", Score: 0.9}, + {Text: "Goodbye world.", Score: 0.8}, + } + + filtered := KeywordFilter(results, nil) + + require.Len(t, filtered, 2) + assert.Equal(t, float32(0.9), filtered[0].Score) + assert.Equal(t, float32(0.8), filtered[1].Score) + }) + + t.Run("empty keywords returns results unchanged", func(t *testing.T) { + results := []QueryResult{ + {Text: "Hello world.", Score: 0.9}, + } + + filtered := KeywordFilter(results, []string{}) + + require.Len(t, filtered, 1) + assert.Equal(t, float32(0.9), filtered[0].Score) + }) + + t.Run("single keyword boosts matching result", func(t *testing.T) { + results := []QueryResult{ + {Text: "This document is about Go programming.", Score: 0.8}, + {Text: "This document is about Python scripting.", Score: 0.9}, + } + + filtered := KeywordFilter(results, []string{"Go"}) + + require.Len(t, filtered, 2) + // Go result should be boosted by 10%: 0.8 * 1.1 = 0.88 + // Python result unchanged: 0.9 + // Python (0.9) > Go (0.88), so Python still first + assert.Equal(t, "This document is about Python scripting.", filtered[0].Text) + assert.InDelta(t, 0.9, filtered[0].Score, 0.001) + assert.Equal(t, "This document is about Go programming.", filtered[1].Text) + assert.InDelta(t, 0.88, filtered[1].Score, 0.001) + }) + + t.Run("single keyword can reorder results", func(t *testing.T) { + results := []QueryResult{ + {Text: "General information about various topics.", Score: 0.85}, + {Text: "Detailed guide to Kubernetes deployment.", Score: 0.80}, + } + + filtered := KeywordFilter(results, []string{"kubernetes"}) + + require.Len(t, filtered, 2) + // Kubernetes result boosted: 0.80 * 1.1 = 0.88 > 0.85 + assert.Equal(t, "Detailed guide to Kubernetes deployment.", filtered[0].Text) + assert.InDelta(t, 0.88, filtered[0].Score, 0.001) + }) + + t.Run("multiple keywords compound boost", func(t *testing.T) { + results := []QueryResult{ + {Text: "Go is a programming language for systems.", Score: 0.7}, + {Text: "Python is used for machine learning tasks.", Score: 0.9}, + {Text: "Go and Rust are systems programming languages.", Score: 0.6}, + } + + filtered := KeywordFilter(results, []string{"go", "systems"}) + + require.Len(t, filtered, 3) + // First result matches both: 0.7 * 1.2 = 0.84 + // Second result matches neither: 0.9 + // Third result matches both: 0.6 * 1.2 = 0.72 + // Order: Python (0.9), first Go (0.84), third Go+Rust (0.72) + assert.Equal(t, "Python is used for machine learning tasks.", filtered[0].Text) + assert.InDelta(t, 0.9, filtered[0].Score, 0.001) + assert.Equal(t, "Go is a programming language for systems.", filtered[1].Text) + assert.InDelta(t, 0.84, filtered[1].Score, 0.001) + assert.Equal(t, "Go and Rust are systems programming languages.", filtered[2].Text) + assert.InDelta(t, 0.72, filtered[2].Score, 0.001) + }) + + t.Run("case insensitive matching", func(t 
*testing.T) { + results := []QueryResult{ + {Text: "KUBERNETES is a container orchestration platform.", Score: 0.7}, + {Text: "Docker runs containers.", Score: 0.8}, + } + + filtered := KeywordFilter(results, []string{"kubernetes"}) + + require.Len(t, filtered, 2) + // KUBERNETES matches "kubernetes" case-insensitively: 0.7 * 1.1 = 0.77 + assert.InDelta(t, 0.77, filtered[1].Score, 0.001) + assert.Equal(t, "KUBERNETES is a container orchestration platform.", filtered[1].Text) + }) + + t.Run("no matches leaves scores unchanged", func(t *testing.T) { + results := []QueryResult{ + {Text: "This is about cats.", Score: 0.9}, + {Text: "This is about dogs.", Score: 0.8}, + } + + filtered := KeywordFilter(results, []string{"elephants"}) + + require.Len(t, filtered, 2) + assert.Equal(t, float32(0.9), filtered[0].Score) + assert.Equal(t, float32(0.8), filtered[1].Score) + assert.Equal(t, "This is about cats.", filtered[0].Text) + assert.Equal(t, "This is about dogs.", filtered[1].Text) + }) + + t.Run("empty results returns empty", func(t *testing.T) { + filtered := KeywordFilter(nil, []string{"test"}) + assert.Empty(t, filtered) + }) +} + +// --- extractKeywords tests --- + +func TestExtractKeywords(t *testing.T) { + t.Run("extracts words 3+ characters", func(t *testing.T) { + keywords := extractKeywords("how do I use Go modules") + assert.Contains(t, keywords, "how") + assert.Contains(t, keywords, "use") + assert.Contains(t, keywords, "modules") + // "do" and "I" are too short + assert.NotContains(t, keywords, "do") + assert.NotContains(t, keywords, "i") + }) + + t.Run("empty string returns empty", func(t *testing.T) { + keywords := extractKeywords("") + assert.Empty(t, keywords) + }) + + t.Run("all short words returns empty", func(t *testing.T) { + keywords := extractKeywords("I am a") + assert.Empty(t, keywords) + }) + + t.Run("keywords are lowercased", func(t *testing.T) { + keywords := extractKeywords("Kubernetes Deployment") + assert.Contains(t, keywords, "kubernetes") + assert.Contains(t, keywords, "deployment") + }) +} + +// --- Query with Keywords integration --- + +func TestQuery_Keywords(t *testing.T) { + t.Run("keywords flag enables keyword boosting", func(t *testing.T) { + store := newMockVectorStore() + store.points["test-col"] = []Point{ + {ID: "1", Vector: []float32{0.1}, Payload: map[string]any{ + "text": "General overview of the platform.", "source": "a.md", + "section": "", "category": "docs", "chunk_index": 0, + }}, + {ID: "2", Vector: []float32{0.1}, Payload: map[string]any{ + "text": "Guide to deploying with Kubernetes containers.", "source": "b.md", + "section": "", "category": "docs", "chunk_index": 1, + }}, + } + embedder := newMockEmbedder(768) + + cfg := DefaultQueryConfig() + cfg.Collection = "test-col" + cfg.Limit = 10 + cfg.Threshold = 0.0 + cfg.Keywords = true + + results, err := Query(context.Background(), store, embedder, "kubernetes containers", cfg) + + require.NoError(t, err) + require.Len(t, results, 2) + // The second result (score 0.9 from mock) matches two keywords, + // boosted to 0.9 * 1.2 = 1.08, so it should be first. 
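+		// The first result matches neither keyword and keeps its mock score.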
+ assert.Equal(t, "Guide to deploying with Kubernetes containers.", results[0].Text) + }) + + t.Run("keywords false does not boost", func(t *testing.T) { + store := newMockVectorStore() + store.points["test-col"] = []Point{ + {ID: "1", Vector: []float32{0.1}, Payload: map[string]any{ + "text": "First result text.", "source": "a.md", + "section": "", "category": "docs", "chunk_index": 0, + }}, + {ID: "2", Vector: []float32{0.1}, Payload: map[string]any{ + "text": "Second result text with keywords.", "source": "b.md", + "section": "", "category": "docs", "chunk_index": 1, + }}, + } + embedder := newMockEmbedder(768) + + cfg := DefaultQueryConfig() + cfg.Collection = "test-col" + cfg.Limit = 10 + cfg.Threshold = 0.0 + cfg.Keywords = false + + results, err := Query(context.Background(), store, embedder, "keywords", cfg) + + require.NoError(t, err) + require.Len(t, results, 2) + // Without keywords, original order preserved (first has higher score) + assert.Equal(t, "First result text.", results[0].Text) + }) +} diff --git a/mock_test.go b/mock_test.go index 11418f0..8f025b2 100644 --- a/mock_test.go +++ b/mock_test.go @@ -91,6 +91,8 @@ type mockVectorStore struct { createCalls []createCollectionCall existsCalls []string deleteCalls []string + listCalls int + infoCalls []string upsertCalls []upsertCall searchCalls []searchCall @@ -98,6 +100,8 @@ type mockVectorStore struct { createErr error existsErr error deleteErr error + listErr error + infoErr error upsertErr error searchErr error } @@ -169,6 +173,49 @@ func (m *mockVectorStore) DeleteCollection(ctx context.Context, name string) err return nil } +func (m *mockVectorStore) ListCollections(ctx context.Context) ([]string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.listCalls++ + + if m.listErr != nil { + return nil, m.listErr + } + + names := make([]string, 0, len(m.collections)) + for name := range m.collections { + names = append(names, name) + } + sort.Strings(names) + return names, nil +} + +func (m *mockVectorStore) CollectionInfo(ctx context.Context, name string) (*CollectionInfo, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.infoCalls = append(m.infoCalls, name) + + if m.infoErr != nil { + return nil, m.infoErr + } + + vectorSize, exists := m.collections[name] + if !exists { + return nil, fmt.Errorf("collection %q not found", name) + } + + pointCount := uint64(len(m.points[name])) + + return &CollectionInfo{ + Name: name, + PointCount: pointCount, + VectorSize: vectorSize, + Status: "green", + }, nil +} + func (m *mockVectorStore) UpsertPoints(ctx context.Context, collection string, points []Point) error { m.mu.Lock() defer m.mu.Unlock() diff --git a/qdrant.go b/qdrant.go index 14a540e..53e55a2 100644 --- a/qdrant.go +++ b/qdrant.go @@ -97,9 +97,36 @@ func (q *QdrantClient) DeleteCollection(ctx context.Context, name string) error return q.client.DeleteCollection(ctx, name) } -// CollectionInfo returns information about a collection. -func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*qdrant.CollectionInfo, error) { - return q.client.GetCollectionInfo(ctx, name) +// CollectionInfo returns backend-agnostic metadata about a collection. 
+func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*CollectionInfo, error) { + info, err := q.client.GetCollectionInfo(ctx, name) + if err != nil { + return nil, err + } + + ci := &CollectionInfo{ + Name: name, + PointCount: info.GetPointsCount(), + } + + // Extract vector size from the Qdrant config + if params := info.GetConfig().GetParams().GetVectorsConfig().GetParams(); params != nil { + ci.VectorSize = params.GetSize() + } + + // Map Qdrant status to a simple string + switch info.Status { + case qdrant.CollectionStatus_Green: + ci.Status = "green" + case qdrant.CollectionStatus_Yellow: + ci.Status = "yellow" + case qdrant.CollectionStatus_Red: + ci.Status = "red" + default: + ci.Status = "unknown" + } + + return ci, nil } // Point represents a vector point with payload. diff --git a/query.go b/query.go index 33fa67c..7a7ff5d 100644 --- a/query.go +++ b/query.go @@ -15,6 +15,7 @@ type QueryConfig struct { Limit uint64 Threshold float32 // Minimum similarity score (0-1) Category string // Filter by category + Keywords bool // When true, extract keywords from query and boost matching results } // DefaultQueryConfig returns default query configuration. @@ -93,6 +94,14 @@ func Query(ctx context.Context, store VectorStore, embedder Embedder, query stri queryResults = append(queryResults, qr) } + // Apply keyword boosting when enabled + if cfg.Keywords && len(queryResults) > 0 { + keywords := extractKeywords(query) + if len(keywords) > 0 { + queryResults = KeywordFilter(queryResults, keywords) + } + } + return queryResults, nil } diff --git a/vectorstore.go b/vectorstore.go index 6944b9e..c1c2305 100644 --- a/vectorstore.go +++ b/vectorstore.go @@ -15,6 +15,14 @@ type VectorStore interface { // DeleteCollection deletes the collection with the given name. DeleteCollection(ctx context.Context, name string) error + // ListCollections returns all collection names in the store. + ListCollections(ctx context.Context) ([]string, error) + + // CollectionInfo returns metadata about a collection. Implementations + // should populate at least PointCount and VectorSize in the returned + // CollectionInfo struct. + CollectionInfo(ctx context.Context, name string) (*CollectionInfo, error) + // UpsertPoints inserts or updates points in the named collection. UpsertPoints(ctx context.Context, collection string, points []Point) error @@ -22,3 +30,11 @@ type VectorStore interface { // An optional filter map restricts results by payload field values. Search(ctx context.Context, collection string, vector []float32, limit uint64, filter map[string]string) ([]SearchResult, error) } + +// CollectionInfo holds backend-agnostic metadata about a collection. +type CollectionInfo struct { + Name string + PointCount uint64 + VectorSize uint64 + Status string // e.g. "green", "yellow", "red", "unknown" +}
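
Taken together, the Phase 3 helpers compose as in the sketch below. This is a
minimal usage sketch, not part of the patch: it assumes a VectorStore and an
Embedder have already been constructed (their constructors are outside this
diff), and the example.com/rag import path is hypothetical.

	package main

	import (
		"context"
		"fmt"
		"log"

		rag "example.com/rag" // hypothetical module path, not shown in this diff
	)

	// demo walks the new Phase 3 surface: collection helpers plus
	// keyword-boosted querying.
	func demo(ctx context.Context, store rag.VectorStore, embedder rag.Embedder) {
		// Collection lifecycle helpers from collections.go.
		names, err := rag.ListCollections(ctx, store)
		if err != nil {
			log.Fatal(err)
		}
		for _, name := range names {
			info, err := rag.CollectionStats(ctx, store, name)
			if err != nil {
				log.Fatal(err)
			}
			fmt.Printf("%s: %d points, dim=%d, status=%s\n",
				info.Name, info.PointCount, info.VectorSize, info.Status)
		}

		// Query with keyword boosting enabled via QueryConfig.Keywords.
		cfg := rag.DefaultQueryConfig()
		cfg.Collection = "docs"
		cfg.Keywords = true
		results, err := rag.Query(ctx, store, embedder, "kubernetes deployment guide", cfg)
		if err != nil {
			log.Fatal(err)
		}
		for _, r := range results {
			fmt.Printf("%.3f  %s\n", r.Score, r.Source)
		}
	}

	func main() {
		// Wiring up a concrete store and embedder is outside this sketch.
	}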