diff --git a/chunk_test.go b/chunk_test.go index 87fd5c0..6bd2c4a 100644 --- a/chunk_test.go +++ b/chunk_test.go @@ -110,7 +110,225 @@ func TestShouldProcess_Good_MarkdownFiles(t *testing.T) { assert.False(t, ShouldProcess("doc")) } -// Helper function +// --- Additional chunk edge cases --- + +func TestChunkMarkdown_Edge_EmptyInput(t *testing.T) { + t.Run("empty string returns no chunks", func(t *testing.T) { + chunks := ChunkMarkdown("", DefaultChunkConfig()) + assert.Empty(t, chunks) + }) + + t.Run("whitespace only returns no chunks", func(t *testing.T) { + chunks := ChunkMarkdown(" \n\n \t \n ", DefaultChunkConfig()) + assert.Empty(t, chunks) + }) + + t.Run("single newline returns no chunks", func(t *testing.T) { + chunks := ChunkMarkdown("\n", DefaultChunkConfig()) + assert.Empty(t, chunks) + }) +} + +func TestChunkMarkdown_Edge_OnlyHeadersNoContent(t *testing.T) { + t.Run("single header with no body", func(t *testing.T) { + text := "## Just a Header\n" + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.Len(t, chunks, 1) + assert.Equal(t, "Just a Header", chunks[0].Section) + assert.Contains(t, chunks[0].Text, "Just a Header") + }) + + t.Run("multiple headers with no body content", func(t *testing.T) { + text := "## Header One\n\n## Header Two\n\n## Header Three\n" + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + // Each header becomes its own section + assert.GreaterOrEqual(t, len(chunks), 2, "should produce at least two chunks for separate sections") + }) + + t.Run("header hierarchy with minimal content", func(t *testing.T) { + text := "# Top Level\n\n## Sub Section\n\n### Sub Sub\n" + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.NotEmpty(t, chunks, "should produce at least one chunk") + }) +} + +func TestChunkMarkdown_Edge_UnicodeAndEmoji(t *testing.T) { + t.Run("unicode text chunked correctly", func(t *testing.T) { + text := "## Unicode Section\n\nThis section has unicode: \u00e9\u00e0\u00fc\u00f1\u00f6\u00e4\u00df \u4e16\u754c \u041f\u0440\u0438\u0432\u0435\u0442 \u0645\u0631\u062d\u0628\u0627\n" + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.Len(t, chunks, 1) + assert.Contains(t, chunks[0].Text, "\u00e9\u00e0\u00fc") + assert.Contains(t, chunks[0].Text, "\u4e16\u754c") + assert.Equal(t, "Unicode Section", chunks[0].Section) + }) + + t.Run("emoji text chunked correctly", func(t *testing.T) { + text := "## Emoji Section\n\nHello world! \U0001f600\U0001f680\U0001f30d\U0001f4da\n\nMore text with \u2764\ufe0f and \U0001f525 emojis.\n" + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + assert.NotEmpty(t, chunks) + assert.Contains(t, chunks[0].Text, "\U0001f600") + assert.Equal(t, "Emoji Section", chunks[0].Section) + }) + + t.Run("rune-safe overlap with multibyte characters", func(t *testing.T) { + // Create text with multibyte characters that exceeds chunk size + // Each CJK character is 3 bytes in UTF-8 but 1 rune + para1 := "\u6d4b\u8bd5" + repeatString("\u4e16\u754c", 100) // ~200+ runes of CJK + para2 := "\u8fd4\u56de" + repeatString("\u4f60\u597d", 100) // ~200+ runes of CJK + text := "## CJK\n\n" + para1 + "\n\n" + para2 + "\n" + + cfg := ChunkConfig{Size: 150, Overlap: 30} + chunks := ChunkMarkdown(text, cfg) + + // Should not panic or produce corrupt text + assert.NotEmpty(t, chunks) + for _, chunk := range chunks { + assert.NotEmpty(t, chunk.Text) + // Verify no partial rune corruption by round-tripping through []rune + runes := []rune(chunk.Text) + assert.Equal(t, chunk.Text, string(runes), "text should survive rune round-trip without corruption") + } + }) +} + +func TestChunkMarkdown_Edge_VeryLongSingleParagraph(t *testing.T) { + t.Run("long paragraph without headers splits into multiple chunks", func(t *testing.T) { + // Create a very long single paragraph (no section headers) + longText := repeatString("This is a very long sentence that should be split across multiple chunks. ", 100) + + cfg := ChunkConfig{Size: 200, Overlap: 20} + chunks := ChunkMarkdown(longText, cfg) + + // The paragraph is one big block — chunking depends on paragraph splitting + // Since there are no double newlines, the whole thing is one paragraph + // The chunker should still produce at least one chunk + assert.NotEmpty(t, chunks) + for _, chunk := range chunks { + assert.NotEmpty(t, chunk.Text) + } + }) + + t.Run("long paragraph with line breaks produces chunks", func(t *testing.T) { + // Create long text with paragraph breaks so chunking can split + var parts []string + for i := 0; i < 50; i++ { + parts = append(parts, "This is paragraph number that contains some meaningful text for testing purposes.") + } + longText := "## Long Content\n\n" + joinParagraphs(parts) + + cfg := ChunkConfig{Size: 300, Overlap: 30} + chunks := ChunkMarkdown(longText, cfg) + + assert.Greater(t, len(chunks), 1, "long text should produce multiple chunks") + for _, chunk := range chunks { + assert.NotEmpty(t, chunk.Text) + assert.Equal(t, "Long Content", chunk.Section) + } + }) +} + +func TestChunkMarkdown_Edge_ConfigBoundaries(t *testing.T) { + t.Run("zero chunk size uses default 500", func(t *testing.T) { + text := "## Section\n\nSome content.\n" + cfg := ChunkConfig{Size: 0, Overlap: 0} + chunks := ChunkMarkdown(text, cfg) + + assert.NotEmpty(t, chunks, "should still produce chunks with zero size (uses default)") + }) + + t.Run("negative chunk size uses default 500", func(t *testing.T) { + text := "## Section\n\nSome content.\n" + cfg := ChunkConfig{Size: -1, Overlap: 0} + chunks := ChunkMarkdown(text, cfg) + + assert.NotEmpty(t, chunks) + }) + + t.Run("overlap equal to size resets to zero", func(t *testing.T) { + // When overlap >= size, it resets to 0 + text := "## S\n\n" + repeatString("Word. ", 200) + cfg := ChunkConfig{Size: 100, Overlap: 100} + chunks := ChunkMarkdown(text, cfg) + + assert.NotEmpty(t, chunks) + }) + + t.Run("negative overlap resets to zero", func(t *testing.T) { + text := "## S\n\n" + repeatString("Word. ", 200) + cfg := ChunkConfig{Size: 100, Overlap: -5} + chunks := ChunkMarkdown(text, cfg) + + assert.NotEmpty(t, chunks) + }) +} + +func TestChunkMarkdown_Edge_ChunkIndexing(t *testing.T) { + t.Run("chunk indices are sequential starting from zero", func(t *testing.T) { + text := "## Section One\n\nContent one.\n\n## Section Two\n\nContent two.\n\n## Section Three\n\nContent three.\n" + chunks := ChunkMarkdown(text, DefaultChunkConfig()) + + for i, chunk := range chunks { + assert.Equal(t, i, chunk.Index, "chunk index should be sequential") + } + }) +} + +func TestChunkID_Edge_LongText(t *testing.T) { + t.Run("long text is truncated to first 100 runes for ID", func(t *testing.T) { + longText := repeatString("a", 500) + id1 := ChunkID("test.md", 0, longText) + + // Same first 100 characters, different tail — should produce same ID + longText2 := repeatString("a", 100) + repeatString("b", 400) + id2 := ChunkID("test.md", 0, longText2) + + assert.Equal(t, id1, id2, "IDs should match when first 100 runes are identical") + }) + + t.Run("unicode text uses rune count not byte count", func(t *testing.T) { + // 100 CJK characters (3 bytes each in UTF-8) = 100 runes + runeText := repeatString("\u4e16", 100) + id1 := ChunkID("test.md", 0, runeText) + + // Same 100 CJK chars plus more — should produce same ID + longerText := repeatString("\u4e16", 100) + repeatString("\u754c", 50) + id2 := ChunkID("test.md", 0, longerText) + + assert.Equal(t, id1, id2, "IDs should match when first 100 runes are identical (CJK)") + }) +} + +func TestDefaultChunkConfig(t *testing.T) { + t.Run("returns expected default values", func(t *testing.T) { + cfg := DefaultChunkConfig() + + assert.Equal(t, 500, cfg.Size, "default chunk size should be 500") + assert.Equal(t, 50, cfg.Overlap, "default chunk overlap should be 50") + }) +} + +func TestDefaultIngestConfig(t *testing.T) { + t.Run("returns expected default values", func(t *testing.T) { + cfg := DefaultIngestConfig() + + assert.Equal(t, "hostuk-docs", cfg.Collection, "default collection should be hostuk-docs") + assert.Equal(t, 100, cfg.BatchSize, "default batch size should be 100") + assert.False(t, cfg.Recreate, "recreate should be false by default") + assert.False(t, cfg.Verbose, "verbose should be false by default") + assert.Empty(t, cfg.Directory, "directory should be empty by default") + + // Nested ChunkConfig should match defaults + assert.Equal(t, DefaultChunkConfig().Size, cfg.Chunk.Size) + assert.Equal(t, DefaultChunkConfig().Overlap, cfg.Chunk.Overlap) + }) +} + +// Helper: repeat a string n times func repeatString(s string, n int) string { result := "" for i := 0; i < n; i++ { @@ -118,3 +336,15 @@ func repeatString(s string, n int) string { } return result } + +// Helper: join paragraphs with double newlines +func joinParagraphs(parts []string) string { + result := "" + for i, p := range parts { + if i > 0 { + result += "\n\n" + } + result += p + } + return result +} diff --git a/ollama_test.go b/ollama_test.go new file mode 100644 index 0000000..c5b4b9e --- /dev/null +++ b/ollama_test.go @@ -0,0 +1,88 @@ +package rag + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// --- DefaultOllamaConfig tests --- + +func TestDefaultOllamaConfig(t *testing.T) { + t.Run("returns expected default values", func(t *testing.T) { + cfg := DefaultOllamaConfig() + + assert.Equal(t, "localhost", cfg.Host, "default host should be localhost") + assert.Equal(t, 11434, cfg.Port, "default port should be 11434") + assert.Equal(t, "nomic-embed-text", cfg.Model, "default model should be nomic-embed-text") + }) +} + +// --- EmbedDimension tests --- + +func TestEmbedDimension(t *testing.T) { + tests := []struct { + name string + model string + expected uint64 + }{ + { + name: "nomic-embed-text returns 768", + model: "nomic-embed-text", + expected: 768, + }, + { + name: "mxbai-embed-large returns 1024", + model: "mxbai-embed-large", + expected: 1024, + }, + { + name: "all-minilm returns 384", + model: "all-minilm", + expected: 384, + }, + { + name: "unknown model defaults to 768", + model: "some-unknown-model", + expected: 768, + }, + { + name: "empty model name defaults to 768", + model: "", + expected: 768, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Construct OllamaClient directly with the config to avoid needing a live server. + // We only set the config field; client is nil but EmbedDimension does not use it. + client := &OllamaClient{ + config: OllamaConfig{Model: tc.model}, + } + + dim := client.EmbedDimension() + assert.Equal(t, tc.expected, dim) + }) + } +} + +// --- Model tests --- + +func TestModel(t *testing.T) { + t.Run("returns the configured model name", func(t *testing.T) { + client := &OllamaClient{ + config: OllamaConfig{Model: "nomic-embed-text"}, + } + + assert.Equal(t, "nomic-embed-text", client.Model()) + }) + + t.Run("returns custom model name", func(t *testing.T) { + client := &OllamaClient{ + config: OllamaConfig{Model: "custom-model"}, + } + + assert.Equal(t, "custom-model", client.Model()) + }) +} diff --git a/qdrant_test.go b/qdrant_test.go new file mode 100644 index 0000000..f8dffc5 --- /dev/null +++ b/qdrant_test.go @@ -0,0 +1,281 @@ +package rag + +import ( + "testing" + + "github.com/qdrant/go-client/qdrant" + "github.com/stretchr/testify/assert" +) + +// --- DefaultQdrantConfig tests --- + +func TestDefaultQdrantConfig(t *testing.T) { + t.Run("returns expected default values", func(t *testing.T) { + cfg := DefaultQdrantConfig() + + assert.Equal(t, "localhost", cfg.Host, "default host should be localhost") + assert.Equal(t, 6334, cfg.Port, "default gRPC port should be 6334") + assert.False(t, cfg.UseTLS, "TLS should be disabled by default") + assert.Empty(t, cfg.APIKey, "API key should be empty by default") + }) +} + +// --- pointIDToString tests --- + +func TestPointIDToString(t *testing.T) { + t.Run("nil ID returns empty string", func(t *testing.T) { + result := pointIDToString(nil) + assert.Equal(t, "", result) + }) + + t.Run("numeric ID returns string representation", func(t *testing.T) { + id := qdrant.NewIDNum(42) + result := pointIDToString(id) + assert.Equal(t, "42", result) + }) + + t.Run("numeric ID zero", func(t *testing.T) { + id := qdrant.NewIDNum(0) + result := pointIDToString(id) + assert.Equal(t, "0", result) + }) + + t.Run("large numeric ID", func(t *testing.T) { + id := qdrant.NewIDNum(18446744073709551615) // max uint64 + result := pointIDToString(id) + assert.Equal(t, "18446744073709551615", result) + }) + + t.Run("UUID ID returns UUID string", func(t *testing.T) { + uuid := "550e8400-e29b-41d4-a716-446655440000" + id := qdrant.NewIDUUID(uuid) + result := pointIDToString(id) + assert.Equal(t, uuid, result) + }) + + t.Run("empty UUID returns empty string", func(t *testing.T) { + id := qdrant.NewIDUUID("") + result := pointIDToString(id) + assert.Equal(t, "", result) + }) + + t.Run("string ID via NewID returns the UUID", func(t *testing.T) { + // NewID creates a UUID-type PointId + uuid := "abc-123-def" + id := qdrant.NewID(uuid) + result := pointIDToString(id) + assert.Equal(t, uuid, result) + }) +} + +// --- valueToGo tests --- + +func TestValueToGo(t *testing.T) { + t.Run("nil value returns nil", func(t *testing.T) { + result := valueToGo(nil) + assert.Nil(t, result) + }) + + t.Run("string value", func(t *testing.T) { + v := qdrant.NewValueString("hello world") + result := valueToGo(v) + assert.Equal(t, "hello world", result) + }) + + t.Run("empty string value", func(t *testing.T) { + v := qdrant.NewValueString("") + result := valueToGo(v) + assert.Equal(t, "", result) + }) + + t.Run("integer value", func(t *testing.T) { + v := qdrant.NewValueInt(42) + result := valueToGo(v) + assert.Equal(t, int64(42), result) + }) + + t.Run("negative integer value", func(t *testing.T) { + v := qdrant.NewValueInt(-100) + result := valueToGo(v) + assert.Equal(t, int64(-100), result) + }) + + t.Run("zero integer value", func(t *testing.T) { + v := qdrant.NewValueInt(0) + result := valueToGo(v) + assert.Equal(t, int64(0), result) + }) + + t.Run("double value", func(t *testing.T) { + v := qdrant.NewValueDouble(3.14) + result := valueToGo(v) + assert.Equal(t, float64(3.14), result) + }) + + t.Run("negative double value", func(t *testing.T) { + v := qdrant.NewValueDouble(-2.718) + result := valueToGo(v) + assert.Equal(t, float64(-2.718), result) + }) + + t.Run("bool true value", func(t *testing.T) { + v := qdrant.NewValueBool(true) + result := valueToGo(v) + assert.Equal(t, true, result) + }) + + t.Run("bool false value", func(t *testing.T) { + v := qdrant.NewValueBool(false) + result := valueToGo(v) + assert.Equal(t, false, result) + }) + + t.Run("null value returns nil", func(t *testing.T) { + v := qdrant.NewValueNull() + result := valueToGo(v) + // NullValue is not handled by the switch, falls to default -> nil + assert.Nil(t, result) + }) + + t.Run("list value with mixed types", func(t *testing.T) { + v := qdrant.NewValueFromList( + qdrant.NewValueString("alpha"), + qdrant.NewValueInt(99), + qdrant.NewValueBool(true), + ) + result := valueToGo(v) + + list, ok := result.([]any) + assert.True(t, ok, "result should be a []any") + assert.Len(t, list, 3) + assert.Equal(t, "alpha", list[0]) + assert.Equal(t, int64(99), list[1]) + assert.Equal(t, true, list[2]) + }) + + t.Run("empty list value", func(t *testing.T) { + v := qdrant.NewValueList(&qdrant.ListValue{Values: []*qdrant.Value{}}) + result := valueToGo(v) + + list, ok := result.([]any) + assert.True(t, ok, "result should be a []any") + assert.Empty(t, list) + }) + + t.Run("struct value with fields", func(t *testing.T) { + v := qdrant.NewValueStruct(&qdrant.Struct{ + Fields: map[string]*qdrant.Value{ + "name": qdrant.NewValueString("test"), + "age": qdrant.NewValueInt(25), + }, + }) + result := valueToGo(v) + + m, ok := result.(map[string]any) + assert.True(t, ok, "result should be a map[string]any") + assert.Equal(t, "test", m["name"]) + assert.Equal(t, int64(25), m["age"]) + }) + + t.Run("empty struct value", func(t *testing.T) { + v := qdrant.NewValueStruct(&qdrant.Struct{ + Fields: map[string]*qdrant.Value{}, + }) + result := valueToGo(v) + + m, ok := result.(map[string]any) + assert.True(t, ok, "result should be a map[string]any") + assert.Empty(t, m) + }) + + t.Run("nested list within struct", func(t *testing.T) { + v := qdrant.NewValueStruct(&qdrant.Struct{ + Fields: map[string]*qdrant.Value{ + "tags": qdrant.NewValueFromList( + qdrant.NewValueString("go"), + qdrant.NewValueString("rag"), + ), + }, + }) + result := valueToGo(v) + + m, ok := result.(map[string]any) + assert.True(t, ok, "result should be a map[string]any") + + tags, ok := m["tags"].([]any) + assert.True(t, ok, "tags should be a []any") + assert.Equal(t, []any{"go", "rag"}, tags) + }) + + t.Run("nested struct within struct", func(t *testing.T) { + inner := qdrant.NewValueStruct(&qdrant.Struct{ + Fields: map[string]*qdrant.Value{ + "key": qdrant.NewValueString("value"), + }, + }) + v := qdrant.NewValueStruct(&qdrant.Struct{ + Fields: map[string]*qdrant.Value{ + "nested": inner, + }, + }) + result := valueToGo(v) + + m, ok := result.(map[string]any) + assert.True(t, ok) + nested, ok := m["nested"].(map[string]any) + assert.True(t, ok) + assert.Equal(t, "value", nested["key"]) + }) +} + +// --- Point struct tests --- + +func TestPointStruct(t *testing.T) { + t.Run("Point holds ID vector and payload", func(t *testing.T) { + p := Point{ + ID: "test-id-123", + Vector: []float32{0.1, 0.2, 0.3}, + Payload: map[string]any{ + "text": "hello", + "source": "test.md", + }, + } + + assert.Equal(t, "test-id-123", p.ID) + assert.Len(t, p.Vector, 3) + assert.Equal(t, float32(0.1), p.Vector[0]) + assert.Equal(t, "hello", p.Payload["text"]) + assert.Equal(t, "test.md", p.Payload["source"]) + }) + + t.Run("Point with empty payload", func(t *testing.T) { + p := Point{ + ID: "empty", + Vector: []float32{}, + Payload: map[string]any{}, + } + + assert.Equal(t, "empty", p.ID) + assert.Empty(t, p.Vector) + assert.Empty(t, p.Payload) + }) +} + +// --- SearchResult struct tests --- + +func TestSearchResultStruct(t *testing.T) { + t.Run("SearchResult holds all fields", func(t *testing.T) { + sr := SearchResult{ + ID: "result-1", + Score: 0.95, + Payload: map[string]any{ + "text": "some text", + "category": "docs", + }, + } + + assert.Equal(t, "result-1", sr.ID) + assert.Equal(t, float32(0.95), sr.Score) + assert.Equal(t, "some text", sr.Payload["text"]) + }) +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..fbbc7f8 --- /dev/null +++ b/query_test.go @@ -0,0 +1,274 @@ +package rag + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- DefaultQueryConfig tests --- + +func TestDefaultQueryConfig(t *testing.T) { + t.Run("returns expected default values", func(t *testing.T) { + cfg := DefaultQueryConfig() + + assert.Equal(t, "hostuk-docs", cfg.Collection, "default collection should be hostuk-docs") + assert.Equal(t, uint64(5), cfg.Limit, "default limit should be 5") + assert.Equal(t, float32(0.5), cfg.Threshold, "default threshold should be 0.5") + assert.Empty(t, cfg.Category, "default category should be empty") + }) +} + +// --- FormatResultsText tests --- + +func TestFormatResultsText(t *testing.T) { + t.Run("empty results returns no-results message", func(t *testing.T) { + result := FormatResultsText(nil) + assert.Equal(t, "No results found.", result) + }) + + t.Run("empty slice returns no-results message", func(t *testing.T) { + result := FormatResultsText([]QueryResult{}) + assert.Equal(t, "No results found.", result) + }) + + t.Run("single result with all fields", func(t *testing.T) { + results := []QueryResult{ + { + Text: "Some relevant text about Go.", + Source: "docs/go-intro.md", + Section: "Introduction", + Category: "documentation", + Score: 0.95, + }, + } + + output := FormatResultsText(results) + + assert.Contains(t, output, "Result 1") + assert.Contains(t, output, "score: 0.95") + assert.Contains(t, output, "Source: docs/go-intro.md") + assert.Contains(t, output, "Section: Introduction") + assert.Contains(t, output, "Category: documentation") + assert.Contains(t, output, "Some relevant text about Go.") + }) + + t.Run("section omitted when empty", func(t *testing.T) { + results := []QueryResult{ + { + Text: "No section here.", + Source: "test.md", + Section: "", + Category: "docs", + Score: 0.80, + }, + } + + output := FormatResultsText(results) + + assert.NotContains(t, output, "Section:") + }) + + t.Run("multiple results numbered correctly", func(t *testing.T) { + results := []QueryResult{ + {Text: "First result.", Source: "a.md", Category: "docs", Score: 0.90}, + {Text: "Second result.", Source: "b.md", Category: "docs", Score: 0.85}, + {Text: "Third result.", Source: "c.md", Category: "docs", Score: 0.80}, + } + + output := FormatResultsText(results) + + assert.Contains(t, output, "Result 1") + assert.Contains(t, output, "Result 2") + assert.Contains(t, output, "Result 3") + // Verify ordering: first result appears before second + idx1 := strings.Index(output, "Result 1") + idx2 := strings.Index(output, "Result 2") + idx3 := strings.Index(output, "Result 3") + assert.Less(t, idx1, idx2) + assert.Less(t, idx2, idx3) + }) + + t.Run("score formatted to two decimal places", func(t *testing.T) { + results := []QueryResult{ + {Text: "Test.", Source: "s.md", Category: "c", Score: 0.123456}, + } + + output := FormatResultsText(results) + + assert.Contains(t, output, "score: 0.12") + }) +} + +// --- FormatResultsContext tests --- + +func TestFormatResultsContext(t *testing.T) { + t.Run("empty results returns empty string", func(t *testing.T) { + result := FormatResultsContext(nil) + assert.Equal(t, "", result) + }) + + t.Run("empty slice returns empty string", func(t *testing.T) { + result := FormatResultsContext([]QueryResult{}) + assert.Equal(t, "", result) + }) + + t.Run("wraps output in retrieved_context tags", func(t *testing.T) { + results := []QueryResult{ + {Text: "Hello world.", Source: "test.md", Section: "Intro", Category: "docs", Score: 0.9}, + } + + output := FormatResultsContext(results) + + assert.True(t, strings.HasPrefix(output, "\n"), + "output should start with tag") + assert.True(t, strings.HasSuffix(output, ""), + "output should end with tag") + }) + + t.Run("wraps each result in document tags with attributes", func(t *testing.T) { + results := []QueryResult{ + { + Text: "Content here.", + Source: "file.md", + Section: "My Section", + Category: "documentation", + Score: 0.88, + }, + } + + output := FormatResultsContext(results) + + assert.Contains(t, output, `source="file.md"`) + assert.Contains(t, output, `section="My Section"`) + assert.Contains(t, output, `category="documentation"`) + assert.Contains(t, output, "Content here.") + assert.Contains(t, output, "") + }) + + t.Run("escapes XML special characters in attributes and text", func(t *testing.T) { + results := []QueryResult{ + { + Text: `Text with & "quotes" in it.`, + Source: `path/with&chars.md`, + Section: `Section "One"`, + Category: "docs", + Score: 0.75, + }, + } + + output := FormatResultsContext(results) + + // The html.EscapeString function escapes <, >, &, " and ' + assert.Contains(t, output, "<tags>") + assert.Contains(t, output, "&") + assert.Contains(t, output, ""quotes"") + // Source attribute should also be escaped + assert.Contains(t, output, "path/with<special>&chars.md") + }) + + t.Run("multiple results each wrapped in document tags", func(t *testing.T) { + results := []QueryResult{ + {Text: "First.", Source: "a.md", Section: "", Category: "docs", Score: 0.9}, + {Text: "Second.", Source: "b.md", Section: "", Category: "docs", Score: 0.8}, + } + + output := FormatResultsContext(results) + + // Count document tags + assert.Equal(t, 2, strings.Count(output, "")) + }) +} + +// --- FormatResultsJSON tests --- + +func TestFormatResultsJSON(t *testing.T) { + t.Run("empty results returns empty JSON array", func(t *testing.T) { + result := FormatResultsJSON(nil) + assert.Equal(t, "[]", result) + }) + + t.Run("empty slice returns empty JSON array", func(t *testing.T) { + result := FormatResultsJSON([]QueryResult{}) + assert.Equal(t, "[]", result) + }) + + t.Run("single result produces valid JSON", func(t *testing.T) { + results := []QueryResult{ + { + Text: "Test content.", + Source: "test.md", + Section: "Intro", + Category: "docs", + Score: 0.9234, + }, + } + + output := FormatResultsJSON(results) + + // Verify it parses as valid JSON + var parsed []map[string]any + err := json.Unmarshal([]byte(output), &parsed) + require.NoError(t, err, "output should be valid JSON") + require.Len(t, parsed, 1) + + assert.Equal(t, "test.md", parsed[0]["source"]) + assert.Equal(t, "Intro", parsed[0]["section"]) + assert.Equal(t, "docs", parsed[0]["category"]) + assert.Equal(t, "Test content.", parsed[0]["text"]) + // Score is formatted to 4 decimal places + assert.InDelta(t, 0.9234, parsed[0]["score"], 0.0001) + }) + + t.Run("multiple results produce valid JSON array", func(t *testing.T) { + results := []QueryResult{ + {Text: "First.", Source: "a.md", Section: "A", Category: "docs", Score: 0.95}, + {Text: "Second.", Source: "b.md", Section: "B", Category: "code", Score: 0.80}, + {Text: "Third.", Source: "c.md", Section: "C", Category: "task", Score: 0.70}, + } + + output := FormatResultsJSON(results) + + var parsed []map[string]any + err := json.Unmarshal([]byte(output), &parsed) + require.NoError(t, err, "output should be valid JSON") + require.Len(t, parsed, 3) + + assert.Equal(t, "First.", parsed[0]["text"]) + assert.Equal(t, "Second.", parsed[1]["text"]) + assert.Equal(t, "Third.", parsed[2]["text"]) + }) + + t.Run("special characters are JSON-escaped in text", func(t *testing.T) { + results := []QueryResult{ + { + Text: "Line one\nLine two\twith tab and \"quotes\"", + Source: "special.md", + Section: "", + Category: "docs", + Score: 0.5, + }, + } + + output := FormatResultsJSON(results) + + var parsed []map[string]any + err := json.Unmarshal([]byte(output), &parsed) + require.NoError(t, err, "output should be valid JSON even with special characters") + assert.Equal(t, "Line one\nLine two\twith tab and \"quotes\"", parsed[0]["text"]) + }) + + t.Run("score formatted to four decimal places", func(t *testing.T) { + results := []QueryResult{ + {Text: "T.", Source: "s.md", Section: "", Category: "c", Score: 0.123456789}, + } + + output := FormatResultsJSON(results) + + assert.Contains(t, output, "0.1235") + }) +}