package rag
import (
"context"
"fmt"
"html"
"iter"
"slices"
"strings"
"forge.lthn.ai/core/go/pkg/log"
)
// QueryConfig holds query configuration.
type QueryConfig struct {
Collection string
Limit uint64
Threshold float32 // Minimum similarity score (0-1)
Category string // Filter by category
Keywords bool // When true, extract keywords from query and boost matching results
}
// DefaultQueryConfig returns default query configuration.
func DefaultQueryConfig() QueryConfig {
return QueryConfig{
Collection: "hostuk-docs",
Limit: 5,
Threshold: 0.5,
}
}
// QueryResult represents a query result with metadata.
type QueryResult struct {
Text string
Source string
Section string
Category string
ChunkIndex int
Score float32
}
// Query searches for similar documents in the vector store.
func Query(ctx context.Context, store VectorStore, embedder Embedder, query string, cfg QueryConfig) ([]QueryResult, error) {
it, err := QuerySeq(ctx, store, embedder, query, cfg)
if err != nil {
return nil, err
}
return slices.Collect(it), nil
}
// QuerySeq returns an iterator that yields query results from the vector store.
func QuerySeq(ctx context.Context, store VectorStore, embedder Embedder, query string, cfg QueryConfig) (iter.Seq[QueryResult], error) {
// Generate embedding for query
embedding, err := embedder.Embed(ctx, query)
if err != nil {
return nil, log.E("rag.Query", "error generating query embedding", err)
}
// Build filter
var filter map[string]string
if cfg.Category != "" {
filter = map[string]string{"category": cfg.Category}
}
// Search vector store
results, err := store.Search(ctx, cfg.Collection, embedding, cfg.Limit, filter)
if err != nil {
return nil, log.E("rag.Query", "error searching", err)
}
// Filter by threshold and convert to iterator
filteredIt := func(yield func(QueryResult) bool) {
var queryResults []QueryResult
for _, r := range results {
if r.Score < cfg.Threshold {
continue
}
qr := QueryResult{
Score: r.Score,
}
// Extract payload fields
if text, ok := r.Payload["text"].(string); ok {
qr.Text = text
}
if source, ok := r.Payload["source"].(string); ok {
qr.Source = source
}
if section, ok := r.Payload["section"].(string); ok {
qr.Section = section
}
if category, ok := r.Payload["category"].(string); ok {
qr.Category = category
}
// Handle chunk_index from various types (JSON unmarshaling produces float64)
switch idx := r.Payload["chunk_index"].(type) {
case int64:
qr.ChunkIndex = int(idx)
case float64:
qr.ChunkIndex = int(idx)
case int:
qr.ChunkIndex = idx
}
queryResults = append(queryResults, qr)
}
// Apply keyword boosting when enabled
if cfg.Keywords && len(queryResults) > 0 {
keywords := extractKeywords(query)
if len(keywords) > 0 {
queryResults = KeywordFilter(queryResults, keywords)
}
}
for _, qr := range queryResults {
if !yield(qr) {
return
}
}
}
return filteredIt, nil
}
// FormatResultsText formats query results as plain text.
func FormatResultsText(results []QueryResult) string {
if len(results) == 0 {
return "No results found."
}
var sb strings.Builder
for i, r := range results {
sb.WriteString(fmt.Sprintf("\n--- Result %d (score: %.2f) ---\n", i+1, r.Score))
sb.WriteString(fmt.Sprintf("Source: %s\n", r.Source))
if r.Section != "" {
sb.WriteString(fmt.Sprintf("Section: %s\n", r.Section))
}
sb.WriteString(fmt.Sprintf("Category: %s\n\n", r.Category))
sb.WriteString(r.Text)
sb.WriteString("\n")
}
return sb.String()
}
// FormatResultsContext formats query results for LLM context injection.
func FormatResultsContext(results []QueryResult) string {
if len(results) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("\n")
for _, r := range results {
// Escape XML special characters to prevent malformed output
fmt.Fprintf(&sb, "\n",
html.EscapeString(r.Source),
html.EscapeString(r.Section),
html.EscapeString(r.Category))
sb.WriteString(html.EscapeString(r.Text))
sb.WriteString("\n\n\n")
}
sb.WriteString("")
return sb.String()
}
// FormatResultsJSON formats query results as JSON-like output.
func FormatResultsJSON(results []QueryResult) string {
if len(results) == 0 {
return "[]"
}
var sb strings.Builder
sb.WriteString("[\n")
for i, r := range results {
sb.WriteString(" {\n")
sb.WriteString(fmt.Sprintf(" \"source\": %q,\n", r.Source))
sb.WriteString(fmt.Sprintf(" \"section\": %q,\n", r.Section))
sb.WriteString(fmt.Sprintf(" \"category\": %q,\n", r.Category))
sb.WriteString(fmt.Sprintf(" \"score\": %.4f,\n", r.Score))
sb.WriteString(fmt.Sprintf(" \"text\": %q\n", r.Text))
if i < len(results)-1 {
sb.WriteString(" },\n")
} else {
sb.WriteString(" }\n")
}
}
sb.WriteString("]")
return sb.String()
}