test: add Phase 1 pure-function unit tests (18.4% -> 38.8% coverage)

Add comprehensive unit tests for all pure functions that require no
external services (Qdrant, Ollama). Coverage of testable functions
now at 100% for FormatResults*, Default*Config, EmbedDimension,
valueToGo, Model, ChunkID, and ChunkMarkdown edge cases.

New test files:
- query_test.go: FormatResultsText, FormatResultsContext, FormatResultsJSON,
  DefaultQueryConfig (18 tests)
- ollama_test.go: DefaultOllamaConfig, EmbedDimension, Model (8 tests)
- qdrant_test.go: DefaultQdrantConfig, pointIDToString, valueToGo,
  Point/SearchResult structs (24 tests)

Extended chunk_test.go with edge cases:
- Empty input, whitespace-only, single newline
- Headers with no body content
- Unicode/emoji text with rune-safe overlap verification
- Very long single paragraph
- Config boundary conditions (zero/negative size, overlap >= size)
- Sequential chunk indexing
- ChunkID rune truncation with multibyte characters
- DefaultChunkConfig, DefaultIngestConfig

Co-Authored-By: Charon <developers@lethean.io>
This commit is contained in:
Claude 2026-02-20 00:02:52 +00:00
parent 034b9da45f
commit acb987a01d
No known key found for this signature in database
GPG key ID: AF404715446AEB41
4 changed files with 874 additions and 1 deletions

View file

@ -110,7 +110,225 @@ func TestShouldProcess_Good_MarkdownFiles(t *testing.T) {
assert.False(t, ShouldProcess("doc"))
}
// Helper function
// --- Additional chunk edge cases ---
// TestChunkMarkdown_Edge_EmptyInput asserts that ChunkMarkdown, called with
// the default chunk config, yields no chunks for inputs that carry no visible
// content.
func TestChunkMarkdown_Edge_EmptyInput(t *testing.T) {
	cases := []struct {
		name  string
		input string
	}{
		{name: "empty string returns no chunks", input: ""},
		{name: "whitespace only returns no chunks", input: " \n\n \t \n "},
		{name: "single newline returns no chunks", input: "\n"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := ChunkMarkdown(tc.input, DefaultChunkConfig())
			assert.Empty(t, got)
		})
	}
}
// TestChunkMarkdown_Edge_OnlyHeadersNoContent exercises ChunkMarkdown on
// documents that consist of headers only, with no body paragraphs.
func TestChunkMarkdown_Edge_OnlyHeadersNoContent(t *testing.T) {
	t.Run("single header with no body", func(t *testing.T) {
		text := "## Just a Header\n"
		chunks := ChunkMarkdown(text, DefaultChunkConfig())
		// assert.Len is non-fatal; bail out on a mismatch so that indexing
		// chunks[0] below cannot panic on an empty result.
		if !assert.Len(t, chunks, 1) {
			return
		}
		assert.Equal(t, "Just a Header", chunks[0].Section)
		assert.Contains(t, chunks[0].Text, "Just a Header")
	})

	t.Run("multiple headers with no body content", func(t *testing.T) {
		text := "## Header One\n\n## Header Two\n\n## Header Three\n"
		chunks := ChunkMarkdown(text, DefaultChunkConfig())
		// Each header becomes its own section
		assert.GreaterOrEqual(t, len(chunks), 2, "should produce at least two chunks for separate sections")
	})

	t.Run("header hierarchy with minimal content", func(t *testing.T) {
		text := "# Top Level\n\n## Sub Section\n\n### Sub Sub\n"
		chunks := ChunkMarkdown(text, DefaultChunkConfig())
		assert.NotEmpty(t, chunks, "should produce at least one chunk")
	})
}
// TestChunkMarkdown_Edge_UnicodeAndEmoji exercises ChunkMarkdown with
// accented Latin, CJK, Cyrillic, Arabic and emoji text, and checks that
// chunks produced with overlap still contain valid text.
func TestChunkMarkdown_Edge_UnicodeAndEmoji(t *testing.T) {
	t.Run("unicode text chunked correctly", func(t *testing.T) {
		text := "## Unicode Section\n\nThis section has unicode: \u00e9\u00e0\u00fc\u00f1\u00f6\u00e4\u00df \u4e16\u754c \u041f\u0440\u0438\u0432\u0435\u0442 \u0645\u0631\u062d\u0628\u0627\n"
		chunks := ChunkMarkdown(text, DefaultChunkConfig())
		// assert.Len is non-fatal; stop before chunks[0] can panic.
		if !assert.Len(t, chunks, 1) {
			return
		}
		assert.Contains(t, chunks[0].Text, "\u00e9\u00e0\u00fc")
		assert.Contains(t, chunks[0].Text, "\u4e16\u754c")
		assert.Equal(t, "Unicode Section", chunks[0].Section)
	})

	t.Run("emoji text chunked correctly", func(t *testing.T) {
		text := "## Emoji Section\n\nHello world! \U0001f600\U0001f680\U0001f30d\U0001f4da\n\nMore text with \u2764\ufe0f and \U0001f525 emojis.\n"
		chunks := ChunkMarkdown(text, DefaultChunkConfig())
		// assert.NotEmpty is non-fatal; stop before chunks[0] can panic.
		if !assert.NotEmpty(t, chunks) {
			return
		}
		assert.Contains(t, chunks[0].Text, "\U0001f600")
		assert.Equal(t, "Emoji Section", chunks[0].Section)
	})

	t.Run("rune-safe overlap with multibyte characters", func(t *testing.T) {
		// Create text with multibyte characters that exceeds chunk size
		// Each CJK character is 3 bytes in UTF-8 but 1 rune
		para1 := "\u6d4b\u8bd5" + repeatString("\u4e16\u754c", 100) // ~200+ runes of CJK
		para2 := "\u8fd4\u56de" + repeatString("\u4f60\u597d", 100) // ~200+ runes of CJK
		text := "## CJK\n\n" + para1 + "\n\n" + para2 + "\n"
		cfg := ChunkConfig{Size: 150, Overlap: 30}
		chunks := ChunkMarkdown(text, cfg)

		// Should not panic or produce corrupt text
		assert.NotEmpty(t, chunks)
		for _, chunk := range chunks {
			assert.NotEmpty(t, chunk.Text)
			// A string cut in the middle of a multibyte sequence does not
			// survive conversion to []rune and back unchanged, because the
			// invalid bytes are replaced with U+FFFD.
			runes := []rune(chunk.Text)
			assert.Equal(t, chunk.Text, string(runes), "text should survive rune round-trip without corruption")
		}
	})
}
// TestChunkMarkdown_Edge_VeryLongSingleParagraph exercises ChunkMarkdown with
// long inputs: one unbroken paragraph, and many paragraphs under one header.
func TestChunkMarkdown_Edge_VeryLongSingleParagraph(t *testing.T) {
	// The subtest name states only what is asserted below (a non-empty
	// result); it does not claim that the unbroken paragraph is split.
	t.Run("long paragraph without headers produces at least one chunk", func(t *testing.T) {
		// Create a very long single paragraph (no section headers)
		longText := repeatString("This is a very long sentence that should be split across multiple chunks. ", 100)
		cfg := ChunkConfig{Size: 200, Overlap: 20}
		chunks := ChunkMarkdown(longText, cfg)

		// There are no double newlines, so the whole input is one paragraph.
		// The chunker should still produce at least one chunk.
		assert.NotEmpty(t, chunks)
		for _, chunk := range chunks {
			assert.NotEmpty(t, chunk.Text)
		}
	})

	t.Run("long paragraph with line breaks produces chunks", func(t *testing.T) {
		// Create long text with paragraph breaks so chunking can split
		var parts []string
		for i := 0; i < 50; i++ {
			parts = append(parts, "This is paragraph number that contains some meaningful text for testing purposes.")
		}
		longText := "## Long Content\n\n" + joinParagraphs(parts)
		cfg := ChunkConfig{Size: 300, Overlap: 30}
		chunks := ChunkMarkdown(longText, cfg)

		assert.Greater(t, len(chunks), 1, "long text should produce multiple chunks")
		for _, chunk := range chunks {
			assert.NotEmpty(t, chunk.Text)
			assert.Equal(t, "Long Content", chunk.Section)
		}
	})
}
func TestChunkMarkdown_Edge_ConfigBoundaries(t *testing.T) {
t.Run("zero chunk size uses default 500", func(t *testing.T) {
text := "## Section\n\nSome content.\n"
cfg := ChunkConfig{Size: 0, Overlap: 0}
chunks := ChunkMarkdown(text, cfg)
assert.NotEmpty(t, chunks, "should still produce chunks with zero size (uses default)")
})
t.Run("negative chunk size uses default 500", func(t *testing.T) {
text := "## Section\n\nSome content.\n"
cfg := ChunkConfig{Size: -1, Overlap: 0}
chunks := ChunkMarkdown(text, cfg)
assert.NotEmpty(t, chunks)
})
t.Run("overlap equal to size resets to zero", func(t *testing.T) {
// When overlap >= size, it resets to 0
text := "## S\n\n" + repeatString("Word. ", 200)
cfg := ChunkConfig{Size: 100, Overlap: 100}
chunks := ChunkMarkdown(text, cfg)
assert.NotEmpty(t, chunks)
})
t.Run("negative overlap resets to zero", func(t *testing.T) {
text := "## S\n\n" + repeatString("Word. ", 200)
cfg := ChunkConfig{Size: 100, Overlap: -5}
chunks := ChunkMarkdown(text, cfg)
assert.NotEmpty(t, chunks)
})
}
// TestChunkMarkdown_Edge_ChunkIndexing asserts that every chunk's Index field
// equals its position in the slice returned by ChunkMarkdown.
func TestChunkMarkdown_Edge_ChunkIndexing(t *testing.T) {
	t.Run("chunk indices are sequential starting from zero", func(t *testing.T) {
		text := "## Section One\n\nContent one.\n\n## Section Two\n\nContent two.\n\n## Section Three\n\nContent three.\n"
		chunks := ChunkMarkdown(text, DefaultChunkConfig())
		// Without this check the loop below would pass vacuously if no
		// chunks were returned at all.
		assert.NotEmpty(t, chunks, "expected chunks to verify indexing against")
		for i, chunk := range chunks {
			assert.Equal(t, i, chunk.Index, "chunk index should be sequential")
		}
	})
}
// TestChunkID_Edge_LongText asserts that ChunkID returns equal IDs for two
// texts whose first 100 runes match, for both single-byte and multibyte input.
func TestChunkID_Edge_LongText(t *testing.T) {
	const (
		source = "test.md"
		index  = 0
	)

	t.Run("long text is truncated to first 100 runes for ID", func(t *testing.T) {
		// Both inputs start with 100 'a' characters and differ only afterwards.
		uniform := repeatString("a", 500)
		idUniform := ChunkID(source, index, uniform)

		divergentTail := repeatString("a", 100) + repeatString("b", 400)
		idDivergent := ChunkID(source, index, divergentTail)

		assert.Equal(t, idUniform, idDivergent, "IDs should match when first 100 runes are identical")
	})

	t.Run("unicode text uses rune count not byte count", func(t *testing.T) {
		// 100 CJK characters: 100 runes, but 300 bytes in UTF-8.
		exact := repeatString("\u4e16", 100)
		idExact := ChunkID(source, index, exact)

		// The same 100 characters followed by 50 different ones.
		extended := repeatString("\u4e16", 100) + repeatString("\u754c", 50)
		idExtended := ChunkID(source, index, extended)

		assert.Equal(t, idExact, idExtended, "IDs should match when first 100 runes are identical (CJK)")
	})
}
// TestDefaultChunkConfig pins the Size and Overlap values returned by
// DefaultChunkConfig.
func TestDefaultChunkConfig(t *testing.T) {
	t.Run("returns expected default values", func(t *testing.T) {
		got := DefaultChunkConfig()

		const (
			wantSize    = 500
			wantOverlap = 50
		)
		assert.Equal(t, wantSize, got.Size, "default chunk size should be 500")
		assert.Equal(t, wantOverlap, got.Overlap, "default chunk overlap should be 50")
	})
}
// TestDefaultIngestConfig pins the field values returned by
// DefaultIngestConfig, including the embedded chunk configuration.
func TestDefaultIngestConfig(t *testing.T) {
	t.Run("returns expected default values", func(t *testing.T) {
		got := DefaultIngestConfig()

		// Scalar fields.
		assert.Equal(t, "hostuk-docs", got.Collection, "default collection should be hostuk-docs")
		assert.Equal(t, 100, got.BatchSize, "default batch size should be 100")

		// Flags and path default to their zero values.
		assert.False(t, got.Recreate, "recreate should be false by default")
		assert.False(t, got.Verbose, "verbose should be false by default")
		assert.Empty(t, got.Directory, "directory should be empty by default")

		// The nested chunk settings must agree with DefaultChunkConfig.
		chunkDefaults := DefaultChunkConfig()
		assert.Equal(t, chunkDefaults.Size, got.Chunk.Size)
		assert.Equal(t, chunkDefaults.Overlap, got.Chunk.Overlap)
	})
}
// Helper: repeat a string n times
func repeatString(s string, n int) string {
result := ""
for i := 0; i < n; i++ {
@ -118,3 +336,15 @@ func repeatString(s string, n int) string {
}
return result
}
// joinParagraphs concatenates parts into a single string, inserting a blank
// line ("\n\n") between consecutive entries. It returns "" for an empty or nil
// slice. The result is equivalent to strings.Join(parts, "\n\n").
func joinParagraphs(parts []string) string {
	if len(parts) == 0 {
		return ""
	}

	const sep = "\n\n"

	// Compute the final length up front so the buffer is allocated once,
	// instead of reallocating on every += as the string grows.
	size := len(sep) * (len(parts) - 1)
	for _, p := range parts {
		size += len(p)
	}

	buf := make([]byte, 0, size)
	for i, p := range parts {
		if i > 0 {
			buf = append(buf, sep...)
		}
		buf = append(buf, p...)
	}
	return string(buf)
}

88
ollama_test.go Normal file
View file

@ -0,0 +1,88 @@
package rag
import (
"testing"
"github.com/stretchr/testify/assert"
)
// --- DefaultOllamaConfig tests ---
// TestDefaultOllamaConfig pins the Host, Port and Model values returned by
// DefaultOllamaConfig.
func TestDefaultOllamaConfig(t *testing.T) {
	t.Run("returns expected default values", func(t *testing.T) {
		got := DefaultOllamaConfig()

		const (
			wantHost  = "localhost"
			wantPort  = 11434
			wantModel = "nomic-embed-text"
		)
		assert.Equal(t, wantHost, got.Host, "default host should be localhost")
		assert.Equal(t, wantPort, got.Port, "default port should be 11434")
		assert.Equal(t, wantModel, got.Model, "default model should be nomic-embed-text")
	})
}
// --- EmbedDimension tests ---
// TestEmbedDimension checks the vector dimension reported by
// OllamaClient.EmbedDimension for several configured model names, including
// unknown and empty names.
func TestEmbedDimension(t *testing.T) {
	type dimCase struct {
		name     string
		model    string
		expected uint64
	}

	cases := []dimCase{
		{name: "nomic-embed-text returns 768", model: "nomic-embed-text", expected: 768},
		{name: "mxbai-embed-large returns 1024", model: "mxbai-embed-large", expected: 1024},
		{name: "all-minilm returns 384", model: "all-minilm", expected: 384},
		{name: "unknown model defaults to 768", model: "some-unknown-model", expected: 768},
		{name: "empty model name defaults to 768", model: "", expected: 768},
	}

	for i := range cases {
		c := cases[i]
		t.Run(c.name, func(t *testing.T) {
			// Build the client struct by hand so no live server is needed.
			// Only the config field is populated; the test relies on
			// EmbedDimension not touching any other field.
			oc := &OllamaClient{config: OllamaConfig{Model: c.model}}
			assert.Equal(t, c.expected, oc.EmbedDimension())
		})
	}
}
// --- Model tests ---
// TestModel asserts that OllamaClient.Model returns the model name stored in
// the client's config.
func TestModel(t *testing.T) {
	cases := []struct {
		name  string
		model string
	}{
		{name: "returns the configured model name", model: "nomic-embed-text"},
		{name: "returns custom model name", model: "custom-model"},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			oc := &OllamaClient{config: OllamaConfig{Model: c.model}}
			assert.Equal(t, c.model, oc.Model())
		})
	}
}

281
qdrant_test.go Normal file
View file

@ -0,0 +1,281 @@
package rag
import (
"testing"
"github.com/qdrant/go-client/qdrant"
"github.com/stretchr/testify/assert"
)
// --- DefaultQdrantConfig tests ---
// TestDefaultQdrantConfig pins the Host, Port, UseTLS and APIKey values
// returned by DefaultQdrantConfig.
func TestDefaultQdrantConfig(t *testing.T) {
	t.Run("returns expected default values", func(t *testing.T) {
		got := DefaultQdrantConfig()

		// Connection endpoint.
		assert.Equal(t, "localhost", got.Host, "default host should be localhost")
		assert.Equal(t, 6334, got.Port, "default gRPC port should be 6334")

		// Security settings are off/unset unless configured.
		assert.False(t, got.UseTLS, "TLS should be disabled by default")
		assert.Empty(t, got.APIKey, "API key should be empty by default")
	})
}
// --- pointIDToString tests ---
func TestPointIDToString(t *testing.T) {
t.Run("nil ID returns empty string", func(t *testing.T) {
result := pointIDToString(nil)
assert.Equal(t, "", result)
})
t.Run("numeric ID returns string representation", func(t *testing.T) {
id := qdrant.NewIDNum(42)
result := pointIDToString(id)
assert.Equal(t, "42", result)
})
t.Run("numeric ID zero", func(t *testing.T) {
id := qdrant.NewIDNum(0)
result := pointIDToString(id)
assert.Equal(t, "0", result)
})
t.Run("large numeric ID", func(t *testing.T) {
id := qdrant.NewIDNum(18446744073709551615) // max uint64
result := pointIDToString(id)
assert.Equal(t, "18446744073709551615", result)
})
t.Run("UUID ID returns UUID string", func(t *testing.T) {
uuid := "550e8400-e29b-41d4-a716-446655440000"
id := qdrant.NewIDUUID(uuid)
result := pointIDToString(id)
assert.Equal(t, uuid, result)
})
t.Run("empty UUID returns empty string", func(t *testing.T) {
id := qdrant.NewIDUUID("")
result := pointIDToString(id)
assert.Equal(t, "", result)
})
t.Run("string ID via NewID returns the UUID", func(t *testing.T) {
// NewID creates a UUID-type PointId
uuid := "abc-123-def"
id := qdrant.NewID(uuid)
result := pointIDToString(id)
assert.Equal(t, uuid, result)
})
}
// --- valueToGo tests ---
// TestValueToGo checks the conversion performed by valueToGo from
// *qdrant.Value into plain Go values: nil, scalars, null, lists, structs and
// nested combinations of those.
func TestValueToGo(t *testing.T) {
	t.Run("nil value returns nil", func(t *testing.T) {
		result := valueToGo(nil)
		assert.Nil(t, result)
	})

	t.Run("string value", func(t *testing.T) {
		v := qdrant.NewValueString("hello world")
		result := valueToGo(v)
		assert.Equal(t, "hello world", result)
	})

	t.Run("empty string value", func(t *testing.T) {
		v := qdrant.NewValueString("")
		result := valueToGo(v)
		assert.Equal(t, "", result)
	})

	t.Run("integer value", func(t *testing.T) {
		v := qdrant.NewValueInt(42)
		result := valueToGo(v)
		assert.Equal(t, int64(42), result)
	})

	t.Run("negative integer value", func(t *testing.T) {
		v := qdrant.NewValueInt(-100)
		result := valueToGo(v)
		assert.Equal(t, int64(-100), result)
	})

	t.Run("zero integer value", func(t *testing.T) {
		v := qdrant.NewValueInt(0)
		result := valueToGo(v)
		assert.Equal(t, int64(0), result)
	})

	t.Run("double value", func(t *testing.T) {
		v := qdrant.NewValueDouble(3.14)
		result := valueToGo(v)
		assert.Equal(t, float64(3.14), result)
	})

	t.Run("negative double value", func(t *testing.T) {
		v := qdrant.NewValueDouble(-2.718)
		result := valueToGo(v)
		assert.Equal(t, float64(-2.718), result)
	})

	t.Run("bool true value", func(t *testing.T) {
		v := qdrant.NewValueBool(true)
		result := valueToGo(v)
		assert.Equal(t, true, result)
	})

	t.Run("bool false value", func(t *testing.T) {
		v := qdrant.NewValueBool(false)
		result := valueToGo(v)
		assert.Equal(t, false, result)
	})

	t.Run("null value returns nil", func(t *testing.T) {
		v := qdrant.NewValueNull()
		result := valueToGo(v)
		// A null Value is expected to convert to a nil Go value.
		assert.Nil(t, result)
	})

	t.Run("list value with mixed types", func(t *testing.T) {
		v := qdrant.NewValueFromList(
			qdrant.NewValueString("alpha"),
			qdrant.NewValueInt(99),
			qdrant.NewValueBool(true),
		)
		result := valueToGo(v)
		list, ok := result.([]any)
		// The assert helpers are non-fatal. Stop as soon as the shape is
		// wrong, otherwise list[0] below would panic on a nil or short slice.
		if !assert.True(t, ok, "result should be a []any") {
			return
		}
		if !assert.Len(t, list, 3) {
			return
		}
		assert.Equal(t, "alpha", list[0])
		assert.Equal(t, int64(99), list[1])
		assert.Equal(t, true, list[2])
	})

	t.Run("empty list value", func(t *testing.T) {
		v := qdrant.NewValueList(&qdrant.ListValue{Values: []*qdrant.Value{}})
		result := valueToGo(v)
		list, ok := result.([]any)
		assert.True(t, ok, "result should be a []any")
		assert.Empty(t, list)
	})

	t.Run("struct value with fields", func(t *testing.T) {
		v := qdrant.NewValueStruct(&qdrant.Struct{
			Fields: map[string]*qdrant.Value{
				"name": qdrant.NewValueString("test"),
				"age":  qdrant.NewValueInt(25),
			},
		})
		result := valueToGo(v)
		// Reads from a nil map are safe, so no early return is needed here.
		m, ok := result.(map[string]any)
		assert.True(t, ok, "result should be a map[string]any")
		assert.Equal(t, "test", m["name"])
		assert.Equal(t, int64(25), m["age"])
	})

	t.Run("empty struct value", func(t *testing.T) {
		v := qdrant.NewValueStruct(&qdrant.Struct{
			Fields: map[string]*qdrant.Value{},
		})
		result := valueToGo(v)
		m, ok := result.(map[string]any)
		assert.True(t, ok, "result should be a map[string]any")
		assert.Empty(t, m)
	})

	t.Run("nested list within struct", func(t *testing.T) {
		v := qdrant.NewValueStruct(&qdrant.Struct{
			Fields: map[string]*qdrant.Value{
				"tags": qdrant.NewValueFromList(
					qdrant.NewValueString("go"),
					qdrant.NewValueString("rag"),
				),
			},
		})
		result := valueToGo(v)
		m, ok := result.(map[string]any)
		assert.True(t, ok, "result should be a map[string]any")
		tags, ok := m["tags"].([]any)
		assert.True(t, ok, "tags should be a []any")
		assert.Equal(t, []any{"go", "rag"}, tags)
	})

	t.Run("nested struct within struct", func(t *testing.T) {
		inner := qdrant.NewValueStruct(&qdrant.Struct{
			Fields: map[string]*qdrant.Value{
				"key": qdrant.NewValueString("value"),
			},
		})
		v := qdrant.NewValueStruct(&qdrant.Struct{
			Fields: map[string]*qdrant.Value{
				"nested": inner,
			},
		})
		result := valueToGo(v)
		m, ok := result.(map[string]any)
		assert.True(t, ok)
		nested, ok := m["nested"].(map[string]any)
		assert.True(t, ok)
		assert.Equal(t, "value", nested["key"])
	})
}
// --- Point struct tests ---
// TestPointStruct verifies that values assigned to the ID, Vector and Payload
// fields of a Point literal can be read back unchanged.
func TestPointStruct(t *testing.T) {
	t.Run("Point holds ID vector and payload", func(t *testing.T) {
		vec := []float32{0.1, 0.2, 0.3}
		payload := map[string]any{
			"text":   "hello",
			"source": "test.md",
		}
		pt := Point{ID: "test-id-123", Vector: vec, Payload: payload}

		assert.Equal(t, "test-id-123", pt.ID)
		assert.Len(t, pt.Vector, 3)
		assert.Equal(t, float32(0.1), pt.Vector[0])
		assert.Equal(t, "hello", pt.Payload["text"])
		assert.Equal(t, "test.md", pt.Payload["source"])
	})

	t.Run("Point with empty payload", func(t *testing.T) {
		pt := Point{
			ID:      "empty",
			Vector:  []float32{},
			Payload: map[string]any{},
		}

		assert.Equal(t, "empty", pt.ID)
		assert.Empty(t, pt.Vector)
		assert.Empty(t, pt.Payload)
	})
}
// --- SearchResult struct tests ---
// TestSearchResultStruct verifies that values assigned to the ID, Score and
// Payload fields of a SearchResult literal can be read back unchanged.
func TestSearchResultStruct(t *testing.T) {
	t.Run("SearchResult holds all fields", func(t *testing.T) {
		payload := map[string]any{
			"text":     "some text",
			"category": "docs",
		}
		hit := SearchResult{ID: "result-1", Score: 0.95, Payload: payload}

		assert.Equal(t, "result-1", hit.ID)
		assert.Equal(t, float32(0.95), hit.Score)
		assert.Equal(t, "some text", hit.Payload["text"])
	})
}

274
query_test.go Normal file
View file

@ -0,0 +1,274 @@
package rag
import (
"encoding/json"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// --- DefaultQueryConfig tests ---
// TestDefaultQueryConfig pins the Collection, Limit, Threshold and Category
// values returned by DefaultQueryConfig.
func TestDefaultQueryConfig(t *testing.T) {
	t.Run("returns expected default values", func(t *testing.T) {
		got := DefaultQueryConfig()

		var (
			wantLimit     = uint64(5)
			wantThreshold = float32(0.5)
		)
		assert.Equal(t, "hostuk-docs", got.Collection, "default collection should be hostuk-docs")
		assert.Equal(t, wantLimit, got.Limit, "default limit should be 5")
		assert.Equal(t, wantThreshold, got.Threshold, "default threshold should be 0.5")
		assert.Empty(t, got.Category, "default category should be empty")
	})
}
// --- FormatResultsText tests ---
func TestFormatResultsText(t *testing.T) {
t.Run("empty results returns no-results message", func(t *testing.T) {
result := FormatResultsText(nil)
assert.Equal(t, "No results found.", result)
})
t.Run("empty slice returns no-results message", func(t *testing.T) {
result := FormatResultsText([]QueryResult{})
assert.Equal(t, "No results found.", result)
})
t.Run("single result with all fields", func(t *testing.T) {
results := []QueryResult{
{
Text: "Some relevant text about Go.",
Source: "docs/go-intro.md",
Section: "Introduction",
Category: "documentation",
Score: 0.95,
},
}
output := FormatResultsText(results)
assert.Contains(t, output, "Result 1")
assert.Contains(t, output, "score: 0.95")
assert.Contains(t, output, "Source: docs/go-intro.md")
assert.Contains(t, output, "Section: Introduction")
assert.Contains(t, output, "Category: documentation")
assert.Contains(t, output, "Some relevant text about Go.")
})
t.Run("section omitted when empty", func(t *testing.T) {
results := []QueryResult{
{
Text: "No section here.",
Source: "test.md",
Section: "",
Category: "docs",
Score: 0.80,
},
}
output := FormatResultsText(results)
assert.NotContains(t, output, "Section:")
})
t.Run("multiple results numbered correctly", func(t *testing.T) {
results := []QueryResult{
{Text: "First result.", Source: "a.md", Category: "docs", Score: 0.90},
{Text: "Second result.", Source: "b.md", Category: "docs", Score: 0.85},
{Text: "Third result.", Source: "c.md", Category: "docs", Score: 0.80},
}
output := FormatResultsText(results)
assert.Contains(t, output, "Result 1")
assert.Contains(t, output, "Result 2")
assert.Contains(t, output, "Result 3")
// Verify ordering: first result appears before second
idx1 := strings.Index(output, "Result 1")
idx2 := strings.Index(output, "Result 2")
idx3 := strings.Index(output, "Result 3")
assert.Less(t, idx1, idx2)
assert.Less(t, idx2, idx3)
})
t.Run("score formatted to two decimal places", func(t *testing.T) {
results := []QueryResult{
{Text: "Test.", Source: "s.md", Category: "c", Score: 0.123456},
}
output := FormatResultsText(results)
assert.Contains(t, output, "score: 0.12")
})
}
// --- FormatResultsContext tests ---
func TestFormatResultsContext(t *testing.T) {
t.Run("empty results returns empty string", func(t *testing.T) {
result := FormatResultsContext(nil)
assert.Equal(t, "", result)
})
t.Run("empty slice returns empty string", func(t *testing.T) {
result := FormatResultsContext([]QueryResult{})
assert.Equal(t, "", result)
})
t.Run("wraps output in retrieved_context tags", func(t *testing.T) {
results := []QueryResult{
{Text: "Hello world.", Source: "test.md", Section: "Intro", Category: "docs", Score: 0.9},
}
output := FormatResultsContext(results)
assert.True(t, strings.HasPrefix(output, "<retrieved_context>\n"),
"output should start with <retrieved_context> tag")
assert.True(t, strings.HasSuffix(output, "</retrieved_context>"),
"output should end with </retrieved_context> tag")
})
t.Run("wraps each result in document tags with attributes", func(t *testing.T) {
results := []QueryResult{
{
Text: "Content here.",
Source: "file.md",
Section: "My Section",
Category: "documentation",
Score: 0.88,
},
}
output := FormatResultsContext(results)
assert.Contains(t, output, `source="file.md"`)
assert.Contains(t, output, `section="My Section"`)
assert.Contains(t, output, `category="documentation"`)
assert.Contains(t, output, "Content here.")
assert.Contains(t, output, "</document>")
})
t.Run("escapes XML special characters in attributes and text", func(t *testing.T) {
results := []QueryResult{
{
Text: `Text with <tags> & "quotes" in it.`,
Source: `path/with<special>&chars.md`,
Section: `Section "One"`,
Category: "docs",
Score: 0.75,
},
}
output := FormatResultsContext(results)
// The html.EscapeString function escapes <, >, &, " and '
assert.Contains(t, output, "&lt;tags&gt;")
assert.Contains(t, output, "&amp;")
assert.Contains(t, output, "&#34;quotes&#34;")
// Source attribute should also be escaped
assert.Contains(t, output, "path/with&lt;special&gt;&amp;chars.md")
})
t.Run("multiple results each wrapped in document tags", func(t *testing.T) {
results := []QueryResult{
{Text: "First.", Source: "a.md", Section: "", Category: "docs", Score: 0.9},
{Text: "Second.", Source: "b.md", Section: "", Category: "docs", Score: 0.8},
}
output := FormatResultsContext(results)
// Count document tags
assert.Equal(t, 2, strings.Count(output, "<document "))
assert.Equal(t, 2, strings.Count(output, "</document>"))
})
}
// --- FormatResultsJSON tests ---
// TestFormatResultsJSON checks the JSON rendering produced by
// FormatResultsJSON: an empty array for no results, parseable output with the
// expected keys, preservation of special characters, and score precision.
func TestFormatResultsJSON(t *testing.T) {
	t.Run("empty results returns empty JSON array", func(t *testing.T) {
		result := FormatResultsJSON(nil)
		assert.Equal(t, "[]", result)
	})

	t.Run("empty slice returns empty JSON array", func(t *testing.T) {
		result := FormatResultsJSON([]QueryResult{})
		assert.Equal(t, "[]", result)
	})

	t.Run("single result produces valid JSON", func(t *testing.T) {
		results := []QueryResult{
			{
				Text:     "Test content.",
				Source:   "test.md",
				Section:  "Intro",
				Category: "docs",
				Score:    0.9234,
			},
		}
		output := FormatResultsJSON(results)

		// Verify it parses as valid JSON
		var parsed []map[string]any
		err := json.Unmarshal([]byte(output), &parsed)
		require.NoError(t, err, "output should be valid JSON")
		require.Len(t, parsed, 1)
		assert.Equal(t, "test.md", parsed[0]["source"])
		assert.Equal(t, "Intro", parsed[0]["section"])
		assert.Equal(t, "docs", parsed[0]["category"])
		assert.Equal(t, "Test content.", parsed[0]["text"])
		// Compare with a tolerance rather than for exact float equality.
		assert.InDelta(t, 0.9234, parsed[0]["score"], 0.0001)
	})

	t.Run("multiple results produce valid JSON array", func(t *testing.T) {
		results := []QueryResult{
			{Text: "First.", Source: "a.md", Section: "A", Category: "docs", Score: 0.95},
			{Text: "Second.", Source: "b.md", Section: "B", Category: "code", Score: 0.80},
			{Text: "Third.", Source: "c.md", Section: "C", Category: "task", Score: 0.70},
		}
		output := FormatResultsJSON(results)

		var parsed []map[string]any
		err := json.Unmarshal([]byte(output), &parsed)
		require.NoError(t, err, "output should be valid JSON")
		require.Len(t, parsed, 3)
		assert.Equal(t, "First.", parsed[0]["text"])
		assert.Equal(t, "Second.", parsed[1]["text"])
		assert.Equal(t, "Third.", parsed[2]["text"])
	})

	t.Run("special characters are JSON-escaped in text", func(t *testing.T) {
		results := []QueryResult{
			{
				Text:     "Line one\nLine two\twith tab and \"quotes\"",
				Source:   "special.md",
				Section:  "",
				Category: "docs",
				Score:    0.5,
			},
		}
		output := FormatResultsJSON(results)

		var parsed []map[string]any
		err := json.Unmarshal([]byte(output), &parsed)
		require.NoError(t, err, "output should be valid JSON even with special characters")
		// Fail cleanly instead of panicking on parsed[0] if the array does
		// not hold exactly the one result that was passed in.
		require.Len(t, parsed, 1)
		assert.Equal(t, "Line one\nLine two\twith tab and \"quotes\"", parsed[0]["text"])
	})

	t.Run("score formatted to four decimal places", func(t *testing.T) {
		results := []QueryResult{
			{Text: "T.", Source: "s.md", Section: "", Category: "c", Score: 0.123456789},
		}
		output := FormatResultsJSON(results)
		assert.Contains(t, output, "0.1235")
	})
}