go-rag/keyword.go
Snider d188b05ad8
Some checks failed
Security Scan / security (push) Successful in 10s
Test / test (push) Failing after 39s
feat: modernise to Go 1.26 iterators and stdlib helpers
Add ChunkMarkdownSeq, QuerySeq, KeywordFilterSeq, ListCollectionsSeq
iterators for streaming. Use slices.SortFunc, slices.Contains,
slices.Collect in mock/query/keyword. Range-over-int in benchmarks.

Co-Authored-By: Gemini <noreply@google.com>
Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-23 05:36:39 +00:00

87 lines
2.1 KiB
Go

package rag
import (
"iter"
"slices"
"strings"
)
// KeywordFilter re-ranks query results by boosting scores for results whose
// text contains one or more of the given keywords. Matching is
// case-insensitive using strings.Contains. Each keyword match adds a 10%
// boost to the original score: score *= 1.0 + 0.1 * matchCount.
// Results are re-sorted by boosted score descending.
func KeywordFilter(results []QueryResult, keywords []string) []QueryResult {
if len(keywords) == 0 || len(results) == 0 {
return results
}
// Normalise keywords to lowercase once
lowerKeywords := slices.Collect(func(yield func(string) bool) {
for _, kw := range keywords {
if !yield(strings.ToLower(kw)) {
return
}
}
})
// Apply boost
boosted := make([]QueryResult, len(results))
copy(boosted, results)
for i := range boosted {
lowerText := strings.ToLower(boosted[i].Text)
matchCount := 0
for _, kw := range lowerKeywords {
if kw != "" && strings.Contains(lowerText, kw) {
matchCount++
}
}
if matchCount > 0 {
boosted[i].Score *= 1.0 + 0.1*float32(matchCount)
}
}
// Re-sort by boosted score descending
slices.SortFunc(boosted, func(a, b QueryResult) int {
if a.Score > b.Score {
return -1
} else if a.Score < b.Score {
return 1
}
return 0
})
return boosted
}
// KeywordFilterSeq is an iterator version of KeywordFilter.
func KeywordFilterSeq(results []QueryResult, keywords []string) iter.Seq[QueryResult] {
return func(yield func(QueryResult) bool) {
filtered := KeywordFilter(results, keywords)
for _, r := range filtered {
if !yield(r) {
return
}
}
}
}
// extractKeywords splits query text into individual keywords for filtering.
// Words shorter than 3 characters are discarded as they tend to be noise.
func extractKeywords(query string) []string {
return slices.Collect(extractKeywordsSeq(query))
}
// extractKeywordsSeq returns an iterator that yields keywords from a query.
func extractKeywordsSeq(query string) iter.Seq[string] {
return func(yield func(string) bool) {
for w := range strings.FieldsSeq(strings.ToLower(query)) {
if len(w) >= 3 {
if !yield(w) {
return
}
}
}
}
}