feat: Phase 3 enhancements — sentence splitting, collection helpers, keyword filter, benchmarks

3.1: Sentence-aware chunk splitting at ". ", "? ", "! " boundaries when
paragraphs exceed ChunkConfig.Size. Overlap now aligns to word boundaries
to avoid mid-word splits.
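The word-boundary alignment can be sketched in isolation (`wordAlignedOverlap` is an illustrative name for this note; the commit implements the rule as `overlapPrefix` in chunk.go):

```go
package main

import (
	"fmt"
	"strings"
)

// wordAlignedOverlap sketches the rule described above: take the last
// `overlap` runes of the previous chunk, then drop any partial leading
// word by starting after the first space.
func wordAlignedOverlap(prev string, overlap int) string {
	runes := []rune(prev)
	if overlap <= 0 || len(runes) <= overlap {
		return ""
	}
	tail := string(runes[len(runes)-overlap:])
	// Align to a word boundary: skip the fragment before the first space.
	if idx := strings.IndexByte(tail, ' '); idx >= 0 {
		tail = tail[idx+1:]
	}
	return tail
}

func main() {
	// A naive last-12-rune slice would be "r oscar papa" (mid-word split);
	// word alignment drops the dangling "r".
	fmt.Println(wordAlignedOverlap("alpha bravo november oscar papa", 12)) // prints "oscar papa"
}
```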

3.2: VectorStore interface gains ListCollections and CollectionInfo methods.
New collections.go with ListCollections, DeleteCollection, CollectionStats
helpers returning backend-agnostic CollectionInfo. Mock updated accordingly.

3.3: KeywordFilter re-ranks QueryResults by boosting scores for keyword
matches (case-insensitive, +10% per keyword). QueryConfig.Keywords flag
enables automatic extraction and filtering.
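The boost arithmetic can be sketched standalone (`boost` is an illustrative name for this note; the shipped code is `KeywordFilter` in keyword.go):

```go
package main

import (
	"fmt"
	"strings"
)

// boost applies the formula from this commit, score *= 1.0 + 0.1*matchCount,
// with case-insensitive substring matching per keyword.
func boost(score float32, text string, keywords []string) float32 {
	lower := strings.ToLower(text)
	matches := 0
	for _, kw := range keywords {
		if kw != "" && strings.Contains(lower, strings.ToLower(kw)) {
			matches++
		}
	}
	return score * (1.0 + 0.1*float32(matches))
}

func main() {
	// Two keyword hits give +20%: 0.8 * 1.2 = 0.96
	fmt.Printf("%.2f\n", boost(0.8, "Guide to Kubernetes containers.", []string{"kubernetes", "containers"})) // prints 0.96
}
```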

3.4: Mock-only benchmarks for chunking, query, ingest, formatting, and
keyword filtering.

Co-Authored-By: Virgil <virgil@lethean.io>
Snider committed 2026-02-20 08:01:41 +00:00
parent cf1c6191c4, commit d8fd067a8c
12 changed files with 947 additions and 36 deletions

TODO.md — 26 lines changed

@@ -43,36 +43,36 @@ All tasks are pure Go, testable with existing mocks. No external services needed
 ### 3.1 Chunk Boundary Improvements
-- [ ] **Sentence-aware splitting** — When a paragraph exceeds `ChunkConfig.Size`, split at sentence boundaries (`. `, `? `, `! `) instead of adding the whole paragraph as an oversized chunk. Keep current behaviour as fallback when no sentence boundaries exist.
-- [ ] **Overlap boundary alignment** — Current overlap slices by rune count from the end of the previous chunk. Improve by aligning overlap to word boundaries (find the nearest space before the overlap point) to avoid splitting mid-word.
-- [ ] **Tests** — (a) Sentence splitting with 3 sentences > Size, (b) overlap word boundary alignment, (c) existing tests still pass (no regression).
+- [x] **Sentence-aware splitting** — When a paragraph exceeds `ChunkConfig.Size`, split at sentence boundaries (`. `, `? `, `! `) instead of adding the whole paragraph as an oversized chunk. Keep current behaviour as fallback when no sentence boundaries exist. (cf26e88)
+- [x] **Overlap boundary alignment** — Current overlap slices by rune count from the end of the previous chunk. Improve by aligning overlap to word boundaries (find the nearest space before the overlap point) to avoid splitting mid-word. (cf26e88)
+- [x] **Tests** — (a) Sentence splitting with 3 sentences > Size, (b) overlap word boundary alignment, (c) existing tests still pass (no regression). (cf26e88)
 ### 3.2 Collection Management Helpers
-- [ ] **Create `collections.go`** — Helper functions for collection lifecycle:
+- [x] **Create `collections.go`** — Helper functions for collection lifecycle:
   - `ListCollections(ctx, store VectorStore) ([]string, error)` — wraps store method
   - `DeleteCollection(ctx, store VectorStore, name string) error` — wraps store method
-  - `CollectionStats(ctx, store VectorStore, name string) (*CollectionInfo, error)` — point count, vector size, status. Needs `CollectionInfo` struct (not Qdrant-specific).
-- [ ] **Add `ListCollections` and `DeleteCollection` to VectorStore interface** — Currently these methods exist on `QdrantClient` but NOT on the `VectorStore` interface. Add them and update mock.
-- [ ] **Tests** — Mock-based tests for all helpers, error injection.
+  - `CollectionStats(ctx, store VectorStore, name string) (*CollectionInfo, error)` — point count, vector size, status. Needs `CollectionInfo` struct (not Qdrant-specific). (cf26e88)
+- [x] **Add `ListCollections` and `DeleteCollection` to VectorStore interface** — Currently these methods exist on `QdrantClient` but NOT on the `VectorStore` interface. Add them and update mock. (cf26e88)
+- [x] **Tests** — Mock-based tests for all helpers, error injection. (cf26e88)
 ### 3.3 Keyword Pre-Filter
-- [ ] **Create `keyword.go`** — `KeywordFilter(results []QueryResult, keywords []string) []QueryResult` — re-ranks results by boosting scores for results containing query keywords. Pure string matching (case-insensitive `strings.Contains`).
+- [x] **Create `keyword.go`** — `KeywordFilter(results []QueryResult, keywords []string) []QueryResult` — re-ranks results by boosting scores for results containing query keywords. Pure string matching (case-insensitive `strings.Contains`).
   - Boost formula: `score *= 1.0 + 0.1 * matchCount` (each keyword match adds 10% boost)
-  - Re-sort by boosted score descending
-- [ ] **Add `Keywords bool` to QueryConfig** — When true, extract keywords from query text and apply KeywordFilter after vector search.
-- [ ] **Tests** — (a) No keywords (passthrough), (b) single keyword boost, (c) multiple keywords, (d) case insensitive, (e) no matches (scores unchanged).
+  - Re-sort by boosted score descending (cf26e88)
+- [x] **Add `Keywords bool` to QueryConfig** — When true, extract keywords from query text and apply KeywordFilter after vector search. (cf26e88)
+- [x] **Tests** — (a) No keywords (passthrough), (b) single keyword boost, (c) multiple keywords, (d) case insensitive, (e) no matches (scores unchanged). (cf26e88)
 ### 3.4 Benchmarks
-- [ ] **Create `benchmark_test.go`** — No build tag (mock-only):
+- [x] **Create `benchmark_test.go`** — No build tag (mock-only):
   - `BenchmarkChunk` — 10KB markdown document, default config
   - `BenchmarkChunkWithOverlap` — Same document, overlap=100
   - `BenchmarkQuery_Mock` — Query with mock embedder + mock store
   - `BenchmarkIngest_Mock` — Ingest 10 files with mock embedder + mock store
   - `BenchmarkFormatResults` — FormatResultsText/Context/JSON with 20 results
-  - `BenchmarkKeywordFilter` — 100 results, 5 keywords
+  - `BenchmarkKeywordFilter` — 100 results, 5 keywords (cf26e88)
 ## Phase 4: GPU Embeddings

benchmark_test.go (new file) — 170 lines

@@ -0,0 +1,170 @@
package rag
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
)
// generateMarkdownDoc creates a ~10KB markdown document for benchmarking.
func generateMarkdownDoc() string {
var sb strings.Builder
sb.WriteString("# Benchmark Document\n\n")
sb.WriteString("This document is generated for benchmarking the chunking pipeline.\n\n")
for i := 0; i < 20; i++ {
sb.WriteString(fmt.Sprintf("## Section %d\n\n", i+1))
for j := 0; j < 5; j++ {
sb.WriteString(fmt.Sprintf(
"Paragraph %d in section %d contains representative text for testing. "+
"It includes multiple sentences to exercise the sentence-aware splitter. "+
"The content is deliberately verbose to reach a realistic document size. "+
"Each paragraph is unique to avoid deduplication effects.\n\n",
j+1, i+1))
}
}
return sb.String()
}
// generateQueryResults creates n QueryResult entries for benchmarking.
func generateQueryResults(n int) []QueryResult {
results := make([]QueryResult, n)
for i := range results {
results[i] = QueryResult{
Text: fmt.Sprintf("This is result number %d with some representative text content for testing purposes.", i+1),
Source: fmt.Sprintf("docs/section-%d/file-%d.md", i/5, i),
Section: fmt.Sprintf("Section %d", i+1),
Category: "documentation",
ChunkIndex: i,
Score: 1.0 - float32(i)*0.01,
}
}
return results
}
// --- BenchmarkChunk ---
func BenchmarkChunk(b *testing.B) {
doc := generateMarkdownDoc()
cfg := DefaultChunkConfig()
b.ResetTimer()
for b.Loop() {
ChunkMarkdown(doc, cfg)
}
}
// --- BenchmarkChunkWithOverlap ---
func BenchmarkChunkWithOverlap(b *testing.B) {
doc := generateMarkdownDoc()
cfg := ChunkConfig{Size: 500, Overlap: 100}
b.ResetTimer()
for b.Loop() {
ChunkMarkdown(doc, cfg)
}
}
// --- BenchmarkQuery_Mock ---
func BenchmarkQuery_Mock(b *testing.B) {
store := newMockVectorStore()
store.collections["bench-col"] = 768
// Pre-populate with 50 points
for i := 0; i < 50; i++ {
store.points["bench-col"] = append(store.points["bench-col"], Point{
ID: fmt.Sprintf("p%d", i),
Vector: make([]float32, 768),
Payload: map[string]any{
"text": fmt.Sprintf("Benchmark document %d with relevant content.", i),
"source": fmt.Sprintf("doc%d.md", i),
"section": "Section",
"category": "docs",
"chunk_index": i,
},
})
}
embedder := newMockEmbedder(768)
cfg := DefaultQueryConfig()
cfg.Collection = "bench-col"
cfg.Threshold = 0.0
ctx := context.Background()
b.ResetTimer()
for b.Loop() {
_, _ = Query(ctx, store, embedder, "benchmark query text", cfg)
}
}
// --- BenchmarkIngest_Mock ---
func BenchmarkIngest_Mock(b *testing.B) {
dir := b.TempDir()
// Create 10 markdown files
for i := 0; i < 10; i++ {
content := fmt.Sprintf("## File %d\n\nThis is file number %d with some test content for benchmarking.\n", i, i)
path := filepath.Join(dir, fmt.Sprintf("doc%d.md", i))
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
b.Fatal(err)
}
}
ctx := context.Background()
b.ResetTimer()
for b.Loop() {
store := newMockVectorStore()
embedder := newMockEmbedder(768)
cfg := DefaultIngestConfig()
cfg.Directory = dir
cfg.Collection = "bench-ingest"
_, _ = Ingest(ctx, store, embedder, cfg, nil)
}
}
// --- BenchmarkFormatResults ---
func BenchmarkFormatResultsText(b *testing.B) {
results := generateQueryResults(20)
b.ResetTimer()
for b.Loop() {
FormatResultsText(results)
}
}
func BenchmarkFormatResultsContext(b *testing.B) {
results := generateQueryResults(20)
b.ResetTimer()
for b.Loop() {
FormatResultsContext(results)
}
}
func BenchmarkFormatResultsJSON(b *testing.B) {
results := generateQueryResults(20)
b.ResetTimer()
for b.Loop() {
FormatResultsJSON(results)
}
}
// --- BenchmarkKeywordFilter ---
func BenchmarkKeywordFilter(b *testing.B) {
results := generateQueryResults(100)
keywords := []string{"result", "content", "testing", "documentation", "section"}
b.ResetTimer()
for b.Loop() {
KeywordFilter(results, keywords)
}
}

chunk.go — 124 lines changed

@@ -30,7 +30,9 @@ type Chunk struct {
}
// ChunkMarkdown splits markdown text into chunks by sections and paragraphs.
// Preserves context with configurable overlap.
// Preserves context with configurable overlap. When a paragraph exceeds the
// configured Size, it is split at sentence boundaries. Overlap is aligned to
// word boundaries to avoid splitting mid-word.
func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk {
if cfg.Size <= 0 {
cfg.Size = 500
@@ -80,28 +82,40 @@ func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk {
continue
}
if len(currentChunk)+len(para)+2 <= cfg.Size {
if currentChunk != "" {
currentChunk += "\n\n" + para
} else {
currentChunk = para
// If the paragraph itself exceeds Size, split at sentence
// boundaries and treat each sentence (or group of sentences)
// as a separate sub-paragraph.
subParas := []string{para}
if len(para) > cfg.Size {
if sentences := splitBySentences(para); len(sentences) > 1 {
subParas = sentences
}
} else {
if currentChunk != "" {
chunks = append(chunks, Chunk{
Text: strings.TrimSpace(currentChunk),
Section: title,
Index: chunkIndex,
})
chunkIndex++
}
for _, sp := range subParas {
sp = strings.TrimSpace(sp)
if sp == "" {
continue
}
// Start new chunk with overlap from previous (rune-safe for UTF-8)
runes := []rune(currentChunk)
if cfg.Overlap > 0 && len(runes) > cfg.Overlap {
overlapText := string(runes[len(runes)-cfg.Overlap:])
currentChunk = overlapText + "\n\n" + para
if len(currentChunk)+len(sp)+2 <= cfg.Size {
if currentChunk != "" {
currentChunk += "\n\n" + sp
} else {
currentChunk = sp
}
} else {
currentChunk = para
if currentChunk != "" {
chunks = append(chunks, Chunk{
Text: strings.TrimSpace(currentChunk),
Section: title,
Index: chunkIndex,
})
chunkIndex++
}
// Start new chunk with overlap from previous,
// aligned to the nearest word boundary.
currentChunk = overlapPrefix(currentChunk, cfg.Overlap, sp)
}
}
}
@@ -120,6 +134,76 @@ func ChunkMarkdown(text string, cfg ChunkConfig) []Chunk {
return chunks
}
// overlapPrefix builds the start of a new chunk by taking word-boundary-aligned
// overlap text from the previous chunk and prepending it to the new paragraph.
func overlapPrefix(prevChunk string, overlap int, newPara string) string {
if overlap <= 0 {
return newPara
}
runes := []rune(prevChunk)
if len(runes) <= overlap {
return newPara
}
// Slice from the end of the previous chunk
overlapRunes := runes[len(runes)-overlap:]
// Align to the nearest word boundary: find the first space within the
// overlap slice and start after it to avoid a partial leading word.
overlapText := string(overlapRunes)
if idx := strings.IndexByte(overlapText, ' '); idx >= 0 {
overlapText = overlapText[idx+1:]
}
if overlapText == "" {
return newPara
}
return overlapText + "\n\n" + newPara
}
// splitBySentences splits text at sentence boundaries (". ", "? ", "! ").
// Returns the original text in a single-element slice when no boundaries are found.
func splitBySentences(text string) []string {
var sentences []string
remaining := text
for len(remaining) > 0 {
// Find the earliest sentence boundary
bestIdx := -1
var bestSep string
for _, sep := range []string{". ", "? ", "! "} {
idx := strings.Index(remaining, sep)
if idx >= 0 && (bestIdx < 0 || idx < bestIdx) {
bestIdx = idx
bestSep = sep
}
}
if bestIdx < 0 {
// No more boundaries — append remainder
sentences = append(sentences, remaining)
break
}
// Include the punctuation mark in the sentence, but not the trailing space
sentence := remaining[:bestIdx+len(bestSep)-1]
sentences = append(sentences, strings.TrimSpace(sentence))
remaining = remaining[bestIdx+len(bestSep):]
}
// Filter out empty entries
var filtered []string
for _, s := range sentences {
if strings.TrimSpace(s) != "" {
filtered = append(filtered, s)
}
}
return filtered
}
// splitBySections splits text by ## headers while preserving the header with its content.
func splitBySections(text string) []string {
var sections []string


@@ -1,6 +1,7 @@
package rag
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -348,3 +349,139 @@ func joinParagraphs(parts []string) string {
}
return result
}
// --- Phase 3.1: Sentence splitting and overlap alignment ---
func TestChunkMarkdown_SentenceSplitting(t *testing.T) {
t.Run("oversized paragraph split at sentence boundaries", func(t *testing.T) {
// Three sentences, each ~60 chars. Total ~180 chars exceeds Size=100.
s1 := "The quick brown fox jumps over the lazy dog on the green hill."
s2 := "A second sentence that also has a reasonable amount of words."
s3 := "Finally the third sentence wraps up this oversized paragraph."
text := "## Section\n\n" + s1 + " " + s2 + " " + s3
cfg := ChunkConfig{Size: 100, Overlap: 0}
chunks := ChunkMarkdown(text, cfg)
// Should produce more than one chunk because sentences are split
assert.Greater(t, len(chunks), 1, "oversized paragraph should be split into multiple chunks")
// Verify all original text appears across the chunks
combined := ""
for _, c := range chunks {
combined += c.Text + " "
}
assert.Contains(t, combined, "quick brown fox")
assert.Contains(t, combined, "second sentence")
assert.Contains(t, combined, "third sentence")
// Each chunk should have the correct section
for _, c := range chunks {
assert.Equal(t, "Section", c.Section)
}
})
t.Run("paragraph without sentence boundaries kept as single chunk", func(t *testing.T) {
// Long paragraph with no sentence-ending punctuation followed by space
para := repeatString("word ", 100) // ~500 chars, no ". " or "? " or "! "
text := "## S\n\n" + para
cfg := ChunkConfig{Size: 100, Overlap: 0}
chunks := ChunkMarkdown(text, cfg)
// Should still produce at least one chunk (fallback behaviour)
assert.NotEmpty(t, chunks)
})
t.Run("sentence boundaries preserve punctuation", func(t *testing.T) {
text := "## S\n\nFirst sentence. Second sentence. Third sentence."
cfg := ChunkConfig{Size: 30, Overlap: 0}
chunks := ChunkMarkdown(text, cfg)
// The first chunk should end with a period (punctuation preserved)
foundPeriod := false
for _, c := range chunks {
if strings.HasSuffix(strings.TrimSpace(c.Text), ".") {
foundPeriod = true
break
}
}
assert.True(t, foundPeriod, "at least one chunk should end with a period")
})
}
func TestChunkMarkdown_OverlapWordBoundary(t *testing.T) {
t.Run("overlap does not split mid-word", func(t *testing.T) {
// Build two paragraphs where the first is large enough to emit,
// and the overlap region lands mid-word in the naive rune slice.
para1 := "Alpha bravo charlie delta echo foxtrot golf hotel india juliet kilo lima mike november oscar papa quebec romeo sierra tango."
para2 := "Uniform victor whiskey xray yankee zulu."
text := "## S\n\n" + para1 + "\n\n" + para2
cfg := ChunkConfig{Size: 80, Overlap: 15}
chunks := ChunkMarkdown(text, cfg)
// Find a chunk that contains overlap text (not the first chunk)
for i, c := range chunks {
if i == 0 {
continue
}
// The overlap prefix should start at a word boundary:
// it should not begin with a partial word fragment.
words := strings.Fields(c.Text)
if len(words) > 0 {
// The first word should be a recognisable whole word, not a suffix
// of a longer word. We can verify there is no leading lowercase
// fragment by checking the original text contains this word.
firstWord := words[0]
assert.True(t,
strings.Contains(para1, firstWord) || strings.Contains(para2, firstWord),
"overlap should start at a word boundary, got leading word: %q", firstWord)
}
}
})
t.Run("overlap with zero value produces no overlap", func(t *testing.T) {
para1 := repeatString("Abcdef. ", 30) // ~240 chars
para2 := "Unique marker text here."
text := "## S\n\n" + para1 + "\n\n" + para2
cfg := ChunkConfig{Size: 100, Overlap: 0}
chunks := ChunkMarkdown(text, cfg)
// With zero overlap, the second chunk should not contain text from
// the end of the previous chunk
assert.NotEmpty(t, chunks)
})
}
func TestSplitBySentences(t *testing.T) {
t.Run("splits on period-space", func(t *testing.T) {
result := splitBySentences("First. Second. Third.")
assert.Len(t, result, 3)
assert.Equal(t, "First.", result[0])
assert.Equal(t, "Second.", result[1])
assert.Equal(t, "Third.", result[2])
})
t.Run("splits on question mark", func(t *testing.T) {
result := splitBySentences("What is this? It is a test.")
assert.Len(t, result, 2)
assert.Equal(t, "What is this?", result[0])
assert.Equal(t, "It is a test.", result[1])
})
t.Run("splits on exclamation mark", func(t *testing.T) {
result := splitBySentences("Wow! That is amazing.")
assert.Len(t, result, 2)
assert.Equal(t, "Wow!", result[0])
assert.Equal(t, "That is amazing.", result[1])
})
t.Run("no boundaries returns single element", func(t *testing.T) {
result := splitBySentences("just a plain string with no ending")
assert.Len(t, result, 1)
assert.Equal(t, "just a plain string with no ending", result[0])
})
t.Run("empty string returns empty", func(t *testing.T) {
result := splitBySentences("")
assert.Empty(t, result)
})
}

collections.go (new file) — 18 lines

@@ -0,0 +1,18 @@
package rag
import "context"
// ListCollections returns all collection names from the given vector store.
func ListCollections(ctx context.Context, store VectorStore) ([]string, error) {
return store.ListCollections(ctx)
}
// DeleteCollection removes a collection from the given vector store.
func DeleteCollection(ctx context.Context, store VectorStore, name string) error {
return store.DeleteCollection(ctx, name)
}
// CollectionStats returns backend-agnostic metadata about a collection.
func CollectionStats(ctx context.Context, store VectorStore, name string) (*CollectionInfo, error) {
return store.CollectionInfo(ctx, name)
}

collections_test.go (new file) — 127 lines

@@ -0,0 +1,127 @@
package rag
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- ListCollections tests ---
func TestListCollections(t *testing.T) {
t.Run("returns collection names from store", func(t *testing.T) {
store := newMockVectorStore()
store.collections["alpha"] = 768
store.collections["bravo"] = 384
names, err := ListCollections(context.Background(), store)
require.NoError(t, err)
assert.Len(t, names, 2)
assert.Contains(t, names, "alpha")
assert.Contains(t, names, "bravo")
})
t.Run("empty store returns empty list", func(t *testing.T) {
store := newMockVectorStore()
names, err := ListCollections(context.Background(), store)
require.NoError(t, err)
assert.Empty(t, names)
})
t.Run("error from store propagates", func(t *testing.T) {
store := newMockVectorStore()
store.listErr = fmt.Errorf("connection lost")
_, err := ListCollections(context.Background(), store)
assert.Error(t, err)
assert.Contains(t, err.Error(), "connection lost")
})
}
// --- DeleteCollection tests ---
func TestDeleteCollectionHelper(t *testing.T) {
t.Run("deletes collection from store", func(t *testing.T) {
store := newMockVectorStore()
store.collections["to-delete"] = 768
store.points["to-delete"] = []Point{{ID: "p1"}}
err := DeleteCollection(context.Background(), store, "to-delete")
require.NoError(t, err)
_, exists := store.collections["to-delete"]
assert.False(t, exists)
assert.Empty(t, store.points["to-delete"])
})
t.Run("error from store propagates", func(t *testing.T) {
store := newMockVectorStore()
store.deleteErr = fmt.Errorf("permission denied")
err := DeleteCollection(context.Background(), store, "any")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
})
}
// --- CollectionStats tests ---
func TestCollectionStats(t *testing.T) {
t.Run("returns info for existing collection", func(t *testing.T) {
store := newMockVectorStore()
store.collections["my-col"] = 768
store.points["my-col"] = []Point{
{ID: "p1", Vector: []float32{0.1}},
{ID: "p2", Vector: []float32{0.2}},
{ID: "p3", Vector: []float32{0.3}},
}
info, err := CollectionStats(context.Background(), store, "my-col")
require.NoError(t, err)
require.NotNil(t, info)
assert.Equal(t, "my-col", info.Name)
assert.Equal(t, uint64(3), info.PointCount)
assert.Equal(t, uint64(768), info.VectorSize)
assert.Equal(t, "green", info.Status)
})
t.Run("nonexistent collection returns error", func(t *testing.T) {
store := newMockVectorStore()
_, err := CollectionStats(context.Background(), store, "missing")
assert.Error(t, err)
assert.Contains(t, err.Error(), "not found")
})
t.Run("error from store propagates", func(t *testing.T) {
store := newMockVectorStore()
store.collections["err-col"] = 768
store.infoErr = fmt.Errorf("internal error")
_, err := CollectionStats(context.Background(), store, "err-col")
assert.Error(t, err)
assert.Contains(t, err.Error(), "internal error")
})
t.Run("empty collection has zero point count", func(t *testing.T) {
store := newMockVectorStore()
store.collections["empty-col"] = 384
info, err := CollectionStats(context.Background(), store, "empty-col")
require.NoError(t, err)
assert.Equal(t, uint64(0), info.PointCount)
assert.Equal(t, uint64(384), info.VectorSize)
})
}

keyword.go (new file) — 60 lines

@@ -0,0 +1,60 @@
package rag
import (
"sort"
"strings"
)
// KeywordFilter re-ranks query results by boosting scores for results whose
// text contains one or more of the given keywords. Matching is
// case-insensitive using strings.Contains. Each keyword match adds a 10%
// boost to the original score: score *= 1.0 + 0.1 * matchCount.
// Results are re-sorted by boosted score descending.
func KeywordFilter(results []QueryResult, keywords []string) []QueryResult {
if len(keywords) == 0 || len(results) == 0 {
return results
}
// Normalise keywords to lowercase once
lowerKeywords := make([]string, len(keywords))
for i, kw := range keywords {
lowerKeywords[i] = strings.ToLower(kw)
}
// Apply boost
boosted := make([]QueryResult, len(results))
copy(boosted, results)
for i := range boosted {
lowerText := strings.ToLower(boosted[i].Text)
matchCount := 0
for _, kw := range lowerKeywords {
if kw != "" && strings.Contains(lowerText, kw) {
matchCount++
}
}
if matchCount > 0 {
boosted[i].Score *= 1.0 + 0.1*float32(matchCount)
}
}
// Re-sort by boosted score descending
sort.Slice(boosted, func(i, j int) bool {
return boosted[i].Score > boosted[j].Score
})
return boosted
}
// extractKeywords splits query text into individual keywords for filtering.
// Words shorter than 3 characters are discarded as they tend to be noise.
func extractKeywords(query string) []string {
words := strings.Fields(strings.ToLower(query))
var keywords []string
for _, w := range words {
if len(w) >= 3 {
keywords = append(keywords, w)
}
}
return keywords
}

keyword_test.go (new file) — 216 lines

@@ -0,0 +1,216 @@
package rag
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- KeywordFilter tests ---
func TestKeywordFilter(t *testing.T) {
t.Run("no keywords returns results unchanged", func(t *testing.T) {
results := []QueryResult{
{Text: "Hello world.", Score: 0.9},
{Text: "Goodbye world.", Score: 0.8},
}
filtered := KeywordFilter(results, nil)
require.Len(t, filtered, 2)
assert.Equal(t, float32(0.9), filtered[0].Score)
assert.Equal(t, float32(0.8), filtered[1].Score)
})
t.Run("empty keywords returns results unchanged", func(t *testing.T) {
results := []QueryResult{
{Text: "Hello world.", Score: 0.9},
}
filtered := KeywordFilter(results, []string{})
require.Len(t, filtered, 1)
assert.Equal(t, float32(0.9), filtered[0].Score)
})
t.Run("single keyword boosts matching result", func(t *testing.T) {
results := []QueryResult{
{Text: "This document is about Go programming.", Score: 0.8},
{Text: "This document is about Python scripting.", Score: 0.9},
}
filtered := KeywordFilter(results, []string{"Go"})
require.Len(t, filtered, 2)
// Go result should be boosted by 10%: 0.8 * 1.1 = 0.88
// Python result unchanged: 0.9
// Python (0.9) > Go (0.88), so Python still first
assert.Equal(t, "This document is about Python scripting.", filtered[0].Text)
assert.InDelta(t, 0.9, filtered[0].Score, 0.001)
assert.Equal(t, "This document is about Go programming.", filtered[1].Text)
assert.InDelta(t, 0.88, filtered[1].Score, 0.001)
})
t.Run("single keyword can reorder results", func(t *testing.T) {
results := []QueryResult{
{Text: "General information about various topics.", Score: 0.85},
{Text: "Detailed guide to Kubernetes deployment.", Score: 0.80},
}
filtered := KeywordFilter(results, []string{"kubernetes"})
require.Len(t, filtered, 2)
// Kubernetes result boosted: 0.80 * 1.1 = 0.88 > 0.85
assert.Equal(t, "Detailed guide to Kubernetes deployment.", filtered[0].Text)
assert.InDelta(t, 0.88, filtered[0].Score, 0.001)
})
t.Run("multiple keywords compound boost", func(t *testing.T) {
results := []QueryResult{
{Text: "Go is a programming language for systems.", Score: 0.7},
{Text: "Python is used for machine learning tasks.", Score: 0.9},
{Text: "Go and Rust are systems programming languages.", Score: 0.6},
}
filtered := KeywordFilter(results, []string{"go", "systems"})
require.Len(t, filtered, 3)
// First result matches both: 0.7 * 1.2 = 0.84
// Second result matches neither: 0.9
// Third result matches both: 0.6 * 1.2 = 0.72
// Order: Python (0.9), first Go (0.84), third Go+Rust (0.72)
assert.Equal(t, "Python is used for machine learning tasks.", filtered[0].Text)
assert.InDelta(t, 0.9, filtered[0].Score, 0.001)
assert.Equal(t, "Go is a programming language for systems.", filtered[1].Text)
assert.InDelta(t, 0.84, filtered[1].Score, 0.001)
assert.Equal(t, "Go and Rust are systems programming languages.", filtered[2].Text)
assert.InDelta(t, 0.72, filtered[2].Score, 0.001)
})
t.Run("case insensitive matching", func(t *testing.T) {
results := []QueryResult{
{Text: "KUBERNETES is a container orchestration platform.", Score: 0.7},
{Text: "Docker runs containers.", Score: 0.8},
}
filtered := KeywordFilter(results, []string{"kubernetes"})
require.Len(t, filtered, 2)
// KUBERNETES matches "kubernetes" case-insensitively: 0.7 * 1.1 = 0.77
assert.InDelta(t, 0.77, filtered[1].Score, 0.001)
assert.Equal(t, "KUBERNETES is a container orchestration platform.", filtered[1].Text)
})
t.Run("no matches leaves scores unchanged", func(t *testing.T) {
results := []QueryResult{
{Text: "This is about cats.", Score: 0.9},
{Text: "This is about dogs.", Score: 0.8},
}
filtered := KeywordFilter(results, []string{"elephants"})
require.Len(t, filtered, 2)
assert.Equal(t, float32(0.9), filtered[0].Score)
assert.Equal(t, float32(0.8), filtered[1].Score)
assert.Equal(t, "This is about cats.", filtered[0].Text)
assert.Equal(t, "This is about dogs.", filtered[1].Text)
})
t.Run("empty results returns empty", func(t *testing.T) {
filtered := KeywordFilter(nil, []string{"test"})
assert.Empty(t, filtered)
})
}
// --- extractKeywords tests ---
func TestExtractKeywords(t *testing.T) {
t.Run("extracts words 3+ characters", func(t *testing.T) {
keywords := extractKeywords("how do I use Go modules")
assert.Contains(t, keywords, "how")
assert.Contains(t, keywords, "use")
assert.Contains(t, keywords, "modules")
// "do" and "I" are too short
assert.NotContains(t, keywords, "do")
assert.NotContains(t, keywords, "i")
})
t.Run("empty string returns empty", func(t *testing.T) {
keywords := extractKeywords("")
assert.Empty(t, keywords)
})
t.Run("all short words returns empty", func(t *testing.T) {
keywords := extractKeywords("I am a")
assert.Empty(t, keywords)
})
t.Run("keywords are lowercased", func(t *testing.T) {
keywords := extractKeywords("Kubernetes Deployment")
assert.Contains(t, keywords, "kubernetes")
assert.Contains(t, keywords, "deployment")
})
}
// --- Query with Keywords integration ---
func TestQuery_Keywords(t *testing.T) {
t.Run("keywords flag enables keyword boosting", func(t *testing.T) {
store := newMockVectorStore()
store.points["test-col"] = []Point{
{ID: "1", Vector: []float32{0.1}, Payload: map[string]any{
"text": "General overview of the platform.", "source": "a.md",
"section": "", "category": "docs", "chunk_index": 0,
}},
{ID: "2", Vector: []float32{0.1}, Payload: map[string]any{
"text": "Guide to deploying with Kubernetes containers.", "source": "b.md",
"section": "", "category": "docs", "chunk_index": 1,
}},
}
embedder := newMockEmbedder(768)
cfg := DefaultQueryConfig()
cfg.Collection = "test-col"
cfg.Limit = 10
cfg.Threshold = 0.0
cfg.Keywords = true
results, err := Query(context.Background(), store, embedder, "kubernetes containers", cfg)
require.NoError(t, err)
require.Len(t, results, 2)
// The second result (score 0.9 from mock) matches two keywords,
// boosted to 0.9 * 1.2 = 1.08, so it should be first.
assert.Equal(t, "Guide to deploying with Kubernetes containers.", results[0].Text)
})
t.Run("keywords false does not boost", func(t *testing.T) {
store := newMockVectorStore()
store.points["test-col"] = []Point{
{ID: "1", Vector: []float32{0.1}, Payload: map[string]any{
"text": "First result text.", "source": "a.md",
"section": "", "category": "docs", "chunk_index": 0,
}},
{ID: "2", Vector: []float32{0.1}, Payload: map[string]any{
"text": "Second result text with keywords.", "source": "b.md",
"section": "", "category": "docs", "chunk_index": 1,
}},
}
embedder := newMockEmbedder(768)
cfg := DefaultQueryConfig()
cfg.Collection = "test-col"
cfg.Limit = 10
cfg.Threshold = 0.0
cfg.Keywords = false
results, err := Query(context.Background(), store, embedder, "keywords", cfg)
require.NoError(t, err)
require.Len(t, results, 2)
// Without keywords, original order preserved (first has higher score)
assert.Equal(t, "First result text.", results[0].Text)
})
}


@@ -91,6 +91,8 @@ type mockVectorStore struct {
createCalls []createCollectionCall
existsCalls []string
deleteCalls []string
listCalls int
infoCalls []string
upsertCalls []upsertCall
searchCalls []searchCall
@@ -98,6 +100,8 @@ type mockVectorStore struct {
createErr error
existsErr error
deleteErr error
listErr error
infoErr error
upsertErr error
searchErr error
}
@@ -169,6 +173,49 @@ func (m *mockVectorStore) DeleteCollection(ctx context.Context, name string) err
return nil
}
func (m *mockVectorStore) ListCollections(ctx context.Context) ([]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.listCalls++
if m.listErr != nil {
return nil, m.listErr
}
names := make([]string, 0, len(m.collections))
for name := range m.collections {
names = append(names, name)
}
sort.Strings(names)
return names, nil
}
func (m *mockVectorStore) CollectionInfo(ctx context.Context, name string) (*CollectionInfo, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.infoCalls = append(m.infoCalls, name)
if m.infoErr != nil {
return nil, m.infoErr
}
vectorSize, exists := m.collections[name]
if !exists {
return nil, fmt.Errorf("collection %q not found", name)
}
pointCount := uint64(len(m.points[name]))
return &CollectionInfo{
Name: name,
PointCount: pointCount,
VectorSize: vectorSize,
Status: "green",
}, nil
}
func (m *mockVectorStore) UpsertPoints(ctx context.Context, collection string, points []Point) error {
m.mu.Lock()
defer m.mu.Unlock()


@@ -97,9 +97,36 @@ func (q *QdrantClient) DeleteCollection(ctx context.Context, name string) error
return q.client.DeleteCollection(ctx, name)
}
// CollectionInfo returns information about a collection.
func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*qdrant.CollectionInfo, error) {
return q.client.GetCollectionInfo(ctx, name)
// CollectionInfo returns backend-agnostic metadata about a collection.
func (q *QdrantClient) CollectionInfo(ctx context.Context, name string) (*CollectionInfo, error) {
info, err := q.client.GetCollectionInfo(ctx, name)
if err != nil {
return nil, err
}
ci := &CollectionInfo{
Name: name,
PointCount: info.GetPointsCount(),
}
// Extract vector size from the Qdrant config
if params := info.GetConfig().GetParams().GetVectorsConfig().GetParams(); params != nil {
ci.VectorSize = params.GetSize()
}
// Map Qdrant status to a simple string
switch info.Status {
case qdrant.CollectionStatus_Green:
ci.Status = "green"
case qdrant.CollectionStatus_Yellow:
ci.Status = "yellow"
case qdrant.CollectionStatus_Red:
ci.Status = "red"
default:
ci.Status = "unknown"
}
return ci, nil
}
// Point represents a vector point with payload.


@@ -15,6 +15,7 @@ type QueryConfig struct {
Limit uint64
Threshold float32 // Minimum similarity score (0-1)
Category string // Filter by category
Keywords bool // When true, extract keywords from query and boost matching results
}
// DefaultQueryConfig returns default query configuration.
@@ -93,6 +94,14 @@ func Query(ctx context.Context, store VectorStore, embedder Embedder, query stri
queryResults = append(queryResults, qr)
}
// Apply keyword boosting when enabled
if cfg.Keywords && len(queryResults) > 0 {
keywords := extractKeywords(query)
if len(keywords) > 0 {
queryResults = KeywordFilter(queryResults, keywords)
}
}
return queryResults, nil
}


@@ -15,6 +15,14 @@ type VectorStore interface {
// DeleteCollection deletes the collection with the given name.
DeleteCollection(ctx context.Context, name string) error
// ListCollections returns all collection names in the store.
ListCollections(ctx context.Context) ([]string, error)
// CollectionInfo returns metadata about a collection. Implementations
// should populate at least PointCount and VectorSize in the returned
// CollectionInfo struct.
CollectionInfo(ctx context.Context, name string) (*CollectionInfo, error)
// UpsertPoints inserts or updates points in the named collection.
UpsertPoints(ctx context.Context, collection string, points []Point) error
@@ -22,3 +30,11 @@ type VectorStore interface {
// An optional filter map restricts results by payload field values.
Search(ctx context.Context, collection string, vector []float32, limit uint64, filter map[string]string) ([]SearchResult, error)
}
// CollectionInfo holds backend-agnostic metadata about a collection.
type CollectionInfo struct {
Name string
PointCount uint64
VectorSize uint64
Status string // e.g. "green", "yellow", "red", "unknown"
}