package rag import ( "context" "fmt" "html" "iter" "slices" "strings" "forge.lthn.ai/core/go/pkg/log" ) // QueryConfig holds query configuration. type QueryConfig struct { Collection string Limit uint64 Threshold float32 // Minimum similarity score (0-1) Category string // Filter by category Keywords bool // When true, extract keywords from query and boost matching results } // DefaultQueryConfig returns default query configuration. func DefaultQueryConfig() QueryConfig { return QueryConfig{ Collection: "hostuk-docs", Limit: 5, Threshold: 0.5, } } // QueryResult represents a query result with metadata. type QueryResult struct { Text string Source string Section string Category string ChunkIndex int Score float32 } // Query searches for similar documents in the vector store. func Query(ctx context.Context, store VectorStore, embedder Embedder, query string, cfg QueryConfig) ([]QueryResult, error) { it, err := QuerySeq(ctx, store, embedder, query, cfg) if err != nil { return nil, err } return slices.Collect(it), nil } // QuerySeq returns an iterator that yields query results from the vector store. func QuerySeq(ctx context.Context, store VectorStore, embedder Embedder, query string, cfg QueryConfig) (iter.Seq[QueryResult], error) { // Generate embedding for query embedding, err := embedder.Embed(ctx, query) if err != nil { return nil, log.E("rag.Query", "error generating query embedding", err) } // Build filter var filter map[string]string if cfg.Category != "" { filter = map[string]string{"category": cfg.Category} } // Search vector store results, err := store.Search(ctx, cfg.Collection, embedding, cfg.Limit, filter) if err != nil { return nil, log.E("rag.Query", "error searching", err) } // Filter by threshold and convert to iterator filteredIt := func(yield func(QueryResult) bool) { var queryResults []QueryResult for _, r := range results { if r.Score < cfg.Threshold { continue } qr := QueryResult{ Score: r.Score, } // Extract payload fields if text, ok := r.Payload["text"].(string); ok { qr.Text = text } if source, ok := r.Payload["source"].(string); ok { qr.Source = source } if section, ok := r.Payload["section"].(string); ok { qr.Section = section } if category, ok := r.Payload["category"].(string); ok { qr.Category = category } // Handle chunk_index from various types (JSON unmarshaling produces float64) switch idx := r.Payload["chunk_index"].(type) { case int64: qr.ChunkIndex = int(idx) case float64: qr.ChunkIndex = int(idx) case int: qr.ChunkIndex = idx } queryResults = append(queryResults, qr) } // Apply keyword boosting when enabled if cfg.Keywords && len(queryResults) > 0 { keywords := extractKeywords(query) if len(keywords) > 0 { queryResults = KeywordFilter(queryResults, keywords) } } for _, qr := range queryResults { if !yield(qr) { return } } } return filteredIt, nil } // FormatResultsText formats query results as plain text. func FormatResultsText(results []QueryResult) string { if len(results) == 0 { return "No results found." } var sb strings.Builder for i, r := range results { sb.WriteString(fmt.Sprintf("\n--- Result %d (score: %.2f) ---\n", i+1, r.Score)) sb.WriteString(fmt.Sprintf("Source: %s\n", r.Source)) if r.Section != "" { sb.WriteString(fmt.Sprintf("Section: %s\n", r.Section)) } sb.WriteString(fmt.Sprintf("Category: %s\n\n", r.Category)) sb.WriteString(r.Text) sb.WriteString("\n") } return sb.String() } // FormatResultsContext formats query results for LLM context injection. func FormatResultsContext(results []QueryResult) string { if len(results) == 0 { return "" } var sb strings.Builder sb.WriteString("\n") for _, r := range results { // Escape XML special characters to prevent malformed output fmt.Fprintf(&sb, "\n", html.EscapeString(r.Source), html.EscapeString(r.Section), html.EscapeString(r.Category)) sb.WriteString(html.EscapeString(r.Text)) sb.WriteString("\n\n\n") } sb.WriteString("") return sb.String() } // FormatResultsJSON formats query results as JSON-like output. func FormatResultsJSON(results []QueryResult) string { if len(results) == 0 { return "[]" } var sb strings.Builder sb.WriteString("[\n") for i, r := range results { sb.WriteString(" {\n") sb.WriteString(fmt.Sprintf(" \"source\": %q,\n", r.Source)) sb.WriteString(fmt.Sprintf(" \"section\": %q,\n", r.Section)) sb.WriteString(fmt.Sprintf(" \"category\": %q,\n", r.Category)) sb.WriteString(fmt.Sprintf(" \"score\": %.4f,\n", r.Score)) sb.WriteString(fmt.Sprintf(" \"text\": %q\n", r.Text)) if i < len(results)-1 { sb.WriteString(" },\n") } else { sb.WriteString(" }\n") } } sb.WriteString("]") return sb.String() }