test: add Phase 1 pure-function unit tests (18.4% -> 38.8% coverage)
Add comprehensive unit tests for all pure functions that require no external services (Qdrant, Ollama). Coverage of testable functions now at 100% for FormatResults*, Default*Config, EmbedDimension, valueToGo, Model, ChunkID, and ChunkMarkdown edge cases. New test files: - query_test.go: FormatResultsText, FormatResultsContext, FormatResultsJSON, DefaultQueryConfig (18 tests) - ollama_test.go: DefaultOllamaConfig, EmbedDimension, Model (8 tests) - qdrant_test.go: DefaultQdrantConfig, pointIDToString, valueToGo, Point/SearchResult structs (24 tests) Extended chunk_test.go with edge cases: - Empty input, whitespace-only, single newline - Headers with no body content - Unicode/emoji text with rune-safe overlap verification - Very long single paragraph - Config boundary conditions (zero/negative size, overlap >= size) - Sequential chunk indexing - ChunkID rune truncation with multibyte characters - DefaultChunkConfig, DefaultIngestConfig Co-Authored-By: Charon <developers@lethean.io>
This commit is contained in:
parent
034b9da45f
commit
acb987a01d
4 changed files with 874 additions and 1 deletions
232
chunk_test.go
232
chunk_test.go
|
|
@ -110,7 +110,225 @@ func TestShouldProcess_Good_MarkdownFiles(t *testing.T) {
|
|||
assert.False(t, ShouldProcess("doc"))
|
||||
}
|
||||
|
||||
// Helper function
|
||||
// --- Additional chunk edge cases ---
|
||||
|
||||
func TestChunkMarkdown_Edge_EmptyInput(t *testing.T) {
|
||||
t.Run("empty string returns no chunks", func(t *testing.T) {
|
||||
chunks := ChunkMarkdown("", DefaultChunkConfig())
|
||||
assert.Empty(t, chunks)
|
||||
})
|
||||
|
||||
t.Run("whitespace only returns no chunks", func(t *testing.T) {
|
||||
chunks := ChunkMarkdown(" \n\n \t \n ", DefaultChunkConfig())
|
||||
assert.Empty(t, chunks)
|
||||
})
|
||||
|
||||
t.Run("single newline returns no chunks", func(t *testing.T) {
|
||||
chunks := ChunkMarkdown("\n", DefaultChunkConfig())
|
||||
assert.Empty(t, chunks)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChunkMarkdown_Edge_OnlyHeadersNoContent(t *testing.T) {
|
||||
t.Run("single header with no body", func(t *testing.T) {
|
||||
text := "## Just a Header\n"
|
||||
chunks := ChunkMarkdown(text, DefaultChunkConfig())
|
||||
|
||||
assert.Len(t, chunks, 1)
|
||||
assert.Equal(t, "Just a Header", chunks[0].Section)
|
||||
assert.Contains(t, chunks[0].Text, "Just a Header")
|
||||
})
|
||||
|
||||
t.Run("multiple headers with no body content", func(t *testing.T) {
|
||||
text := "## Header One\n\n## Header Two\n\n## Header Three\n"
|
||||
chunks := ChunkMarkdown(text, DefaultChunkConfig())
|
||||
|
||||
// Each header becomes its own section
|
||||
assert.GreaterOrEqual(t, len(chunks), 2, "should produce at least two chunks for separate sections")
|
||||
})
|
||||
|
||||
t.Run("header hierarchy with minimal content", func(t *testing.T) {
|
||||
text := "# Top Level\n\n## Sub Section\n\n### Sub Sub\n"
|
||||
chunks := ChunkMarkdown(text, DefaultChunkConfig())
|
||||
|
||||
assert.NotEmpty(t, chunks, "should produce at least one chunk")
|
||||
})
|
||||
}
|
||||
|
||||
func TestChunkMarkdown_Edge_UnicodeAndEmoji(t *testing.T) {
|
||||
t.Run("unicode text chunked correctly", func(t *testing.T) {
|
||||
text := "## Unicode Section\n\nThis section has unicode: \u00e9\u00e0\u00fc\u00f1\u00f6\u00e4\u00df \u4e16\u754c \u041f\u0440\u0438\u0432\u0435\u0442 \u0645\u0631\u062d\u0628\u0627\n"
|
||||
chunks := ChunkMarkdown(text, DefaultChunkConfig())
|
||||
|
||||
assert.Len(t, chunks, 1)
|
||||
assert.Contains(t, chunks[0].Text, "\u00e9\u00e0\u00fc")
|
||||
assert.Contains(t, chunks[0].Text, "\u4e16\u754c")
|
||||
assert.Equal(t, "Unicode Section", chunks[0].Section)
|
||||
})
|
||||
|
||||
t.Run("emoji text chunked correctly", func(t *testing.T) {
|
||||
text := "## Emoji Section\n\nHello world! \U0001f600\U0001f680\U0001f30d\U0001f4da\n\nMore text with \u2764\ufe0f and \U0001f525 emojis.\n"
|
||||
chunks := ChunkMarkdown(text, DefaultChunkConfig())
|
||||
|
||||
assert.NotEmpty(t, chunks)
|
||||
assert.Contains(t, chunks[0].Text, "\U0001f600")
|
||||
assert.Equal(t, "Emoji Section", chunks[0].Section)
|
||||
})
|
||||
|
||||
t.Run("rune-safe overlap with multibyte characters", func(t *testing.T) {
|
||||
// Create text with multibyte characters that exceeds chunk size
|
||||
// Each CJK character is 3 bytes in UTF-8 but 1 rune
|
||||
para1 := "\u6d4b\u8bd5" + repeatString("\u4e16\u754c", 100) // ~200+ runes of CJK
|
||||
para2 := "\u8fd4\u56de" + repeatString("\u4f60\u597d", 100) // ~200+ runes of CJK
|
||||
text := "## CJK\n\n" + para1 + "\n\n" + para2 + "\n"
|
||||
|
||||
cfg := ChunkConfig{Size: 150, Overlap: 30}
|
||||
chunks := ChunkMarkdown(text, cfg)
|
||||
|
||||
// Should not panic or produce corrupt text
|
||||
assert.NotEmpty(t, chunks)
|
||||
for _, chunk := range chunks {
|
||||
assert.NotEmpty(t, chunk.Text)
|
||||
// Verify no partial rune corruption by round-tripping through []rune
|
||||
runes := []rune(chunk.Text)
|
||||
assert.Equal(t, chunk.Text, string(runes), "text should survive rune round-trip without corruption")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChunkMarkdown_Edge_VeryLongSingleParagraph(t *testing.T) {
|
||||
t.Run("long paragraph without headers splits into multiple chunks", func(t *testing.T) {
|
||||
// Create a very long single paragraph (no section headers)
|
||||
longText := repeatString("This is a very long sentence that should be split across multiple chunks. ", 100)
|
||||
|
||||
cfg := ChunkConfig{Size: 200, Overlap: 20}
|
||||
chunks := ChunkMarkdown(longText, cfg)
|
||||
|
||||
// The paragraph is one big block — chunking depends on paragraph splitting
|
||||
// Since there are no double newlines, the whole thing is one paragraph
|
||||
// The chunker should still produce at least one chunk
|
||||
assert.NotEmpty(t, chunks)
|
||||
for _, chunk := range chunks {
|
||||
assert.NotEmpty(t, chunk.Text)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("long paragraph with line breaks produces chunks", func(t *testing.T) {
|
||||
// Create long text with paragraph breaks so chunking can split
|
||||
var parts []string
|
||||
for i := 0; i < 50; i++ {
|
||||
parts = append(parts, "This is paragraph number that contains some meaningful text for testing purposes.")
|
||||
}
|
||||
longText := "## Long Content\n\n" + joinParagraphs(parts)
|
||||
|
||||
cfg := ChunkConfig{Size: 300, Overlap: 30}
|
||||
chunks := ChunkMarkdown(longText, cfg)
|
||||
|
||||
assert.Greater(t, len(chunks), 1, "long text should produce multiple chunks")
|
||||
for _, chunk := range chunks {
|
||||
assert.NotEmpty(t, chunk.Text)
|
||||
assert.Equal(t, "Long Content", chunk.Section)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChunkMarkdown_Edge_ConfigBoundaries(t *testing.T) {
|
||||
t.Run("zero chunk size uses default 500", func(t *testing.T) {
|
||||
text := "## Section\n\nSome content.\n"
|
||||
cfg := ChunkConfig{Size: 0, Overlap: 0}
|
||||
chunks := ChunkMarkdown(text, cfg)
|
||||
|
||||
assert.NotEmpty(t, chunks, "should still produce chunks with zero size (uses default)")
|
||||
})
|
||||
|
||||
t.Run("negative chunk size uses default 500", func(t *testing.T) {
|
||||
text := "## Section\n\nSome content.\n"
|
||||
cfg := ChunkConfig{Size: -1, Overlap: 0}
|
||||
chunks := ChunkMarkdown(text, cfg)
|
||||
|
||||
assert.NotEmpty(t, chunks)
|
||||
})
|
||||
|
||||
t.Run("overlap equal to size resets to zero", func(t *testing.T) {
|
||||
// When overlap >= size, it resets to 0
|
||||
text := "## S\n\n" + repeatString("Word. ", 200)
|
||||
cfg := ChunkConfig{Size: 100, Overlap: 100}
|
||||
chunks := ChunkMarkdown(text, cfg)
|
||||
|
||||
assert.NotEmpty(t, chunks)
|
||||
})
|
||||
|
||||
t.Run("negative overlap resets to zero", func(t *testing.T) {
|
||||
text := "## S\n\n" + repeatString("Word. ", 200)
|
||||
cfg := ChunkConfig{Size: 100, Overlap: -5}
|
||||
chunks := ChunkMarkdown(text, cfg)
|
||||
|
||||
assert.NotEmpty(t, chunks)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChunkMarkdown_Edge_ChunkIndexing(t *testing.T) {
|
||||
t.Run("chunk indices are sequential starting from zero", func(t *testing.T) {
|
||||
text := "## Section One\n\nContent one.\n\n## Section Two\n\nContent two.\n\n## Section Three\n\nContent three.\n"
|
||||
chunks := ChunkMarkdown(text, DefaultChunkConfig())
|
||||
|
||||
for i, chunk := range chunks {
|
||||
assert.Equal(t, i, chunk.Index, "chunk index should be sequential")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChunkID_Edge_LongText(t *testing.T) {
|
||||
t.Run("long text is truncated to first 100 runes for ID", func(t *testing.T) {
|
||||
longText := repeatString("a", 500)
|
||||
id1 := ChunkID("test.md", 0, longText)
|
||||
|
||||
// Same first 100 characters, different tail — should produce same ID
|
||||
longText2 := repeatString("a", 100) + repeatString("b", 400)
|
||||
id2 := ChunkID("test.md", 0, longText2)
|
||||
|
||||
assert.Equal(t, id1, id2, "IDs should match when first 100 runes are identical")
|
||||
})
|
||||
|
||||
t.Run("unicode text uses rune count not byte count", func(t *testing.T) {
|
||||
// 100 CJK characters (3 bytes each in UTF-8) = 100 runes
|
||||
runeText := repeatString("\u4e16", 100)
|
||||
id1 := ChunkID("test.md", 0, runeText)
|
||||
|
||||
// Same 100 CJK chars plus more — should produce same ID
|
||||
longerText := repeatString("\u4e16", 100) + repeatString("\u754c", 50)
|
||||
id2 := ChunkID("test.md", 0, longerText)
|
||||
|
||||
assert.Equal(t, id1, id2, "IDs should match when first 100 runes are identical (CJK)")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultChunkConfig(t *testing.T) {
|
||||
t.Run("returns expected default values", func(t *testing.T) {
|
||||
cfg := DefaultChunkConfig()
|
||||
|
||||
assert.Equal(t, 500, cfg.Size, "default chunk size should be 500")
|
||||
assert.Equal(t, 50, cfg.Overlap, "default chunk overlap should be 50")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultIngestConfig(t *testing.T) {
|
||||
t.Run("returns expected default values", func(t *testing.T) {
|
||||
cfg := DefaultIngestConfig()
|
||||
|
||||
assert.Equal(t, "hostuk-docs", cfg.Collection, "default collection should be hostuk-docs")
|
||||
assert.Equal(t, 100, cfg.BatchSize, "default batch size should be 100")
|
||||
assert.False(t, cfg.Recreate, "recreate should be false by default")
|
||||
assert.False(t, cfg.Verbose, "verbose should be false by default")
|
||||
assert.Empty(t, cfg.Directory, "directory should be empty by default")
|
||||
|
||||
// Nested ChunkConfig should match defaults
|
||||
assert.Equal(t, DefaultChunkConfig().Size, cfg.Chunk.Size)
|
||||
assert.Equal(t, DefaultChunkConfig().Overlap, cfg.Chunk.Overlap)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper: repeat a string n times
|
||||
func repeatString(s string, n int) string {
|
||||
result := ""
|
||||
for i := 0; i < n; i++ {
|
||||
|
|
@ -118,3 +336,15 @@ func repeatString(s string, n int) string {
|
|||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Helper: join paragraphs with double newlines
|
||||
func joinParagraphs(parts []string) string {
|
||||
result := ""
|
||||
for i, p := range parts {
|
||||
if i > 0 {
|
||||
result += "\n\n"
|
||||
}
|
||||
result += p
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
|
|
|||
88
ollama_test.go
Normal file
88
ollama_test.go
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// --- DefaultOllamaConfig tests ---
|
||||
|
||||
func TestDefaultOllamaConfig(t *testing.T) {
|
||||
t.Run("returns expected default values", func(t *testing.T) {
|
||||
cfg := DefaultOllamaConfig()
|
||||
|
||||
assert.Equal(t, "localhost", cfg.Host, "default host should be localhost")
|
||||
assert.Equal(t, 11434, cfg.Port, "default port should be 11434")
|
||||
assert.Equal(t, "nomic-embed-text", cfg.Model, "default model should be nomic-embed-text")
|
||||
})
|
||||
}
|
||||
|
||||
// --- EmbedDimension tests ---
|
||||
|
||||
func TestEmbedDimension(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected uint64
|
||||
}{
|
||||
{
|
||||
name: "nomic-embed-text returns 768",
|
||||
model: "nomic-embed-text",
|
||||
expected: 768,
|
||||
},
|
||||
{
|
||||
name: "mxbai-embed-large returns 1024",
|
||||
model: "mxbai-embed-large",
|
||||
expected: 1024,
|
||||
},
|
||||
{
|
||||
name: "all-minilm returns 384",
|
||||
model: "all-minilm",
|
||||
expected: 384,
|
||||
},
|
||||
{
|
||||
name: "unknown model defaults to 768",
|
||||
model: "some-unknown-model",
|
||||
expected: 768,
|
||||
},
|
||||
{
|
||||
name: "empty model name defaults to 768",
|
||||
model: "",
|
||||
expected: 768,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Construct OllamaClient directly with the config to avoid needing a live server.
|
||||
// We only set the config field; client is nil but EmbedDimension does not use it.
|
||||
client := &OllamaClient{
|
||||
config: OllamaConfig{Model: tc.model},
|
||||
}
|
||||
|
||||
dim := client.EmbedDimension()
|
||||
assert.Equal(t, tc.expected, dim)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- Model tests ---
|
||||
|
||||
func TestModel(t *testing.T) {
|
||||
t.Run("returns the configured model name", func(t *testing.T) {
|
||||
client := &OllamaClient{
|
||||
config: OllamaConfig{Model: "nomic-embed-text"},
|
||||
}
|
||||
|
||||
assert.Equal(t, "nomic-embed-text", client.Model())
|
||||
})
|
||||
|
||||
t.Run("returns custom model name", func(t *testing.T) {
|
||||
client := &OllamaClient{
|
||||
config: OllamaConfig{Model: "custom-model"},
|
||||
}
|
||||
|
||||
assert.Equal(t, "custom-model", client.Model())
|
||||
})
|
||||
}
|
||||
281
qdrant_test.go
Normal file
281
qdrant_test.go
Normal file
|
|
@ -0,0 +1,281 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/qdrant/go-client/qdrant"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// --- DefaultQdrantConfig tests ---
|
||||
|
||||
func TestDefaultQdrantConfig(t *testing.T) {
|
||||
t.Run("returns expected default values", func(t *testing.T) {
|
||||
cfg := DefaultQdrantConfig()
|
||||
|
||||
assert.Equal(t, "localhost", cfg.Host, "default host should be localhost")
|
||||
assert.Equal(t, 6334, cfg.Port, "default gRPC port should be 6334")
|
||||
assert.False(t, cfg.UseTLS, "TLS should be disabled by default")
|
||||
assert.Empty(t, cfg.APIKey, "API key should be empty by default")
|
||||
})
|
||||
}
|
||||
|
||||
// --- pointIDToString tests ---
|
||||
|
||||
func TestPointIDToString(t *testing.T) {
|
||||
t.Run("nil ID returns empty string", func(t *testing.T) {
|
||||
result := pointIDToString(nil)
|
||||
assert.Equal(t, "", result)
|
||||
})
|
||||
|
||||
t.Run("numeric ID returns string representation", func(t *testing.T) {
|
||||
id := qdrant.NewIDNum(42)
|
||||
result := pointIDToString(id)
|
||||
assert.Equal(t, "42", result)
|
||||
})
|
||||
|
||||
t.Run("numeric ID zero", func(t *testing.T) {
|
||||
id := qdrant.NewIDNum(0)
|
||||
result := pointIDToString(id)
|
||||
assert.Equal(t, "0", result)
|
||||
})
|
||||
|
||||
t.Run("large numeric ID", func(t *testing.T) {
|
||||
id := qdrant.NewIDNum(18446744073709551615) // max uint64
|
||||
result := pointIDToString(id)
|
||||
assert.Equal(t, "18446744073709551615", result)
|
||||
})
|
||||
|
||||
t.Run("UUID ID returns UUID string", func(t *testing.T) {
|
||||
uuid := "550e8400-e29b-41d4-a716-446655440000"
|
||||
id := qdrant.NewIDUUID(uuid)
|
||||
result := pointIDToString(id)
|
||||
assert.Equal(t, uuid, result)
|
||||
})
|
||||
|
||||
t.Run("empty UUID returns empty string", func(t *testing.T) {
|
||||
id := qdrant.NewIDUUID("")
|
||||
result := pointIDToString(id)
|
||||
assert.Equal(t, "", result)
|
||||
})
|
||||
|
||||
t.Run("string ID via NewID returns the UUID", func(t *testing.T) {
|
||||
// NewID creates a UUID-type PointId
|
||||
uuid := "abc-123-def"
|
||||
id := qdrant.NewID(uuid)
|
||||
result := pointIDToString(id)
|
||||
assert.Equal(t, uuid, result)
|
||||
})
|
||||
}
|
||||
|
||||
// --- valueToGo tests ---
|
||||
|
||||
func TestValueToGo(t *testing.T) {
|
||||
t.Run("nil value returns nil", func(t *testing.T) {
|
||||
result := valueToGo(nil)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("string value", func(t *testing.T) {
|
||||
v := qdrant.NewValueString("hello world")
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, "hello world", result)
|
||||
})
|
||||
|
||||
t.Run("empty string value", func(t *testing.T) {
|
||||
v := qdrant.NewValueString("")
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, "", result)
|
||||
})
|
||||
|
||||
t.Run("integer value", func(t *testing.T) {
|
||||
v := qdrant.NewValueInt(42)
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, int64(42), result)
|
||||
})
|
||||
|
||||
t.Run("negative integer value", func(t *testing.T) {
|
||||
v := qdrant.NewValueInt(-100)
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, int64(-100), result)
|
||||
})
|
||||
|
||||
t.Run("zero integer value", func(t *testing.T) {
|
||||
v := qdrant.NewValueInt(0)
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, int64(0), result)
|
||||
})
|
||||
|
||||
t.Run("double value", func(t *testing.T) {
|
||||
v := qdrant.NewValueDouble(3.14)
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, float64(3.14), result)
|
||||
})
|
||||
|
||||
t.Run("negative double value", func(t *testing.T) {
|
||||
v := qdrant.NewValueDouble(-2.718)
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, float64(-2.718), result)
|
||||
})
|
||||
|
||||
t.Run("bool true value", func(t *testing.T) {
|
||||
v := qdrant.NewValueBool(true)
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, true, result)
|
||||
})
|
||||
|
||||
t.Run("bool false value", func(t *testing.T) {
|
||||
v := qdrant.NewValueBool(false)
|
||||
result := valueToGo(v)
|
||||
assert.Equal(t, false, result)
|
||||
})
|
||||
|
||||
t.Run("null value returns nil", func(t *testing.T) {
|
||||
v := qdrant.NewValueNull()
|
||||
result := valueToGo(v)
|
||||
// NullValue is not handled by the switch, falls to default -> nil
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("list value with mixed types", func(t *testing.T) {
|
||||
v := qdrant.NewValueFromList(
|
||||
qdrant.NewValueString("alpha"),
|
||||
qdrant.NewValueInt(99),
|
||||
qdrant.NewValueBool(true),
|
||||
)
|
||||
result := valueToGo(v)
|
||||
|
||||
list, ok := result.([]any)
|
||||
assert.True(t, ok, "result should be a []any")
|
||||
assert.Len(t, list, 3)
|
||||
assert.Equal(t, "alpha", list[0])
|
||||
assert.Equal(t, int64(99), list[1])
|
||||
assert.Equal(t, true, list[2])
|
||||
})
|
||||
|
||||
t.Run("empty list value", func(t *testing.T) {
|
||||
v := qdrant.NewValueList(&qdrant.ListValue{Values: []*qdrant.Value{}})
|
||||
result := valueToGo(v)
|
||||
|
||||
list, ok := result.([]any)
|
||||
assert.True(t, ok, "result should be a []any")
|
||||
assert.Empty(t, list)
|
||||
})
|
||||
|
||||
t.Run("struct value with fields", func(t *testing.T) {
|
||||
v := qdrant.NewValueStruct(&qdrant.Struct{
|
||||
Fields: map[string]*qdrant.Value{
|
||||
"name": qdrant.NewValueString("test"),
|
||||
"age": qdrant.NewValueInt(25),
|
||||
},
|
||||
})
|
||||
result := valueToGo(v)
|
||||
|
||||
m, ok := result.(map[string]any)
|
||||
assert.True(t, ok, "result should be a map[string]any")
|
||||
assert.Equal(t, "test", m["name"])
|
||||
assert.Equal(t, int64(25), m["age"])
|
||||
})
|
||||
|
||||
t.Run("empty struct value", func(t *testing.T) {
|
||||
v := qdrant.NewValueStruct(&qdrant.Struct{
|
||||
Fields: map[string]*qdrant.Value{},
|
||||
})
|
||||
result := valueToGo(v)
|
||||
|
||||
m, ok := result.(map[string]any)
|
||||
assert.True(t, ok, "result should be a map[string]any")
|
||||
assert.Empty(t, m)
|
||||
})
|
||||
|
||||
t.Run("nested list within struct", func(t *testing.T) {
|
||||
v := qdrant.NewValueStruct(&qdrant.Struct{
|
||||
Fields: map[string]*qdrant.Value{
|
||||
"tags": qdrant.NewValueFromList(
|
||||
qdrant.NewValueString("go"),
|
||||
qdrant.NewValueString("rag"),
|
||||
),
|
||||
},
|
||||
})
|
||||
result := valueToGo(v)
|
||||
|
||||
m, ok := result.(map[string]any)
|
||||
assert.True(t, ok, "result should be a map[string]any")
|
||||
|
||||
tags, ok := m["tags"].([]any)
|
||||
assert.True(t, ok, "tags should be a []any")
|
||||
assert.Equal(t, []any{"go", "rag"}, tags)
|
||||
})
|
||||
|
||||
t.Run("nested struct within struct", func(t *testing.T) {
|
||||
inner := qdrant.NewValueStruct(&qdrant.Struct{
|
||||
Fields: map[string]*qdrant.Value{
|
||||
"key": qdrant.NewValueString("value"),
|
||||
},
|
||||
})
|
||||
v := qdrant.NewValueStruct(&qdrant.Struct{
|
||||
Fields: map[string]*qdrant.Value{
|
||||
"nested": inner,
|
||||
},
|
||||
})
|
||||
result := valueToGo(v)
|
||||
|
||||
m, ok := result.(map[string]any)
|
||||
assert.True(t, ok)
|
||||
nested, ok := m["nested"].(map[string]any)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value", nested["key"])
|
||||
})
|
||||
}
|
||||
|
||||
// --- Point struct tests ---
|
||||
|
||||
func TestPointStruct(t *testing.T) {
|
||||
t.Run("Point holds ID vector and payload", func(t *testing.T) {
|
||||
p := Point{
|
||||
ID: "test-id-123",
|
||||
Vector: []float32{0.1, 0.2, 0.3},
|
||||
Payload: map[string]any{
|
||||
"text": "hello",
|
||||
"source": "test.md",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "test-id-123", p.ID)
|
||||
assert.Len(t, p.Vector, 3)
|
||||
assert.Equal(t, float32(0.1), p.Vector[0])
|
||||
assert.Equal(t, "hello", p.Payload["text"])
|
||||
assert.Equal(t, "test.md", p.Payload["source"])
|
||||
})
|
||||
|
||||
t.Run("Point with empty payload", func(t *testing.T) {
|
||||
p := Point{
|
||||
ID: "empty",
|
||||
Vector: []float32{},
|
||||
Payload: map[string]any{},
|
||||
}
|
||||
|
||||
assert.Equal(t, "empty", p.ID)
|
||||
assert.Empty(t, p.Vector)
|
||||
assert.Empty(t, p.Payload)
|
||||
})
|
||||
}
|
||||
|
||||
// --- SearchResult struct tests ---
|
||||
|
||||
func TestSearchResultStruct(t *testing.T) {
|
||||
t.Run("SearchResult holds all fields", func(t *testing.T) {
|
||||
sr := SearchResult{
|
||||
ID: "result-1",
|
||||
Score: 0.95,
|
||||
Payload: map[string]any{
|
||||
"text": "some text",
|
||||
"category": "docs",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, "result-1", sr.ID)
|
||||
assert.Equal(t, float32(0.95), sr.Score)
|
||||
assert.Equal(t, "some text", sr.Payload["text"])
|
||||
})
|
||||
}
|
||||
274
query_test.go
Normal file
274
query_test.go
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
package rag
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- DefaultQueryConfig tests ---
|
||||
|
||||
func TestDefaultQueryConfig(t *testing.T) {
|
||||
t.Run("returns expected default values", func(t *testing.T) {
|
||||
cfg := DefaultQueryConfig()
|
||||
|
||||
assert.Equal(t, "hostuk-docs", cfg.Collection, "default collection should be hostuk-docs")
|
||||
assert.Equal(t, uint64(5), cfg.Limit, "default limit should be 5")
|
||||
assert.Equal(t, float32(0.5), cfg.Threshold, "default threshold should be 0.5")
|
||||
assert.Empty(t, cfg.Category, "default category should be empty")
|
||||
})
|
||||
}
|
||||
|
||||
// --- FormatResultsText tests ---
|
||||
|
||||
func TestFormatResultsText(t *testing.T) {
|
||||
t.Run("empty results returns no-results message", func(t *testing.T) {
|
||||
result := FormatResultsText(nil)
|
||||
assert.Equal(t, "No results found.", result)
|
||||
})
|
||||
|
||||
t.Run("empty slice returns no-results message", func(t *testing.T) {
|
||||
result := FormatResultsText([]QueryResult{})
|
||||
assert.Equal(t, "No results found.", result)
|
||||
})
|
||||
|
||||
t.Run("single result with all fields", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
Text: "Some relevant text about Go.",
|
||||
Source: "docs/go-intro.md",
|
||||
Section: "Introduction",
|
||||
Category: "documentation",
|
||||
Score: 0.95,
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatResultsText(results)
|
||||
|
||||
assert.Contains(t, output, "Result 1")
|
||||
assert.Contains(t, output, "score: 0.95")
|
||||
assert.Contains(t, output, "Source: docs/go-intro.md")
|
||||
assert.Contains(t, output, "Section: Introduction")
|
||||
assert.Contains(t, output, "Category: documentation")
|
||||
assert.Contains(t, output, "Some relevant text about Go.")
|
||||
})
|
||||
|
||||
t.Run("section omitted when empty", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
Text: "No section here.",
|
||||
Source: "test.md",
|
||||
Section: "",
|
||||
Category: "docs",
|
||||
Score: 0.80,
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatResultsText(results)
|
||||
|
||||
assert.NotContains(t, output, "Section:")
|
||||
})
|
||||
|
||||
t.Run("multiple results numbered correctly", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{Text: "First result.", Source: "a.md", Category: "docs", Score: 0.90},
|
||||
{Text: "Second result.", Source: "b.md", Category: "docs", Score: 0.85},
|
||||
{Text: "Third result.", Source: "c.md", Category: "docs", Score: 0.80},
|
||||
}
|
||||
|
||||
output := FormatResultsText(results)
|
||||
|
||||
assert.Contains(t, output, "Result 1")
|
||||
assert.Contains(t, output, "Result 2")
|
||||
assert.Contains(t, output, "Result 3")
|
||||
// Verify ordering: first result appears before second
|
||||
idx1 := strings.Index(output, "Result 1")
|
||||
idx2 := strings.Index(output, "Result 2")
|
||||
idx3 := strings.Index(output, "Result 3")
|
||||
assert.Less(t, idx1, idx2)
|
||||
assert.Less(t, idx2, idx3)
|
||||
})
|
||||
|
||||
t.Run("score formatted to two decimal places", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{Text: "Test.", Source: "s.md", Category: "c", Score: 0.123456},
|
||||
}
|
||||
|
||||
output := FormatResultsText(results)
|
||||
|
||||
assert.Contains(t, output, "score: 0.12")
|
||||
})
|
||||
}
|
||||
|
||||
// --- FormatResultsContext tests ---
|
||||
|
||||
func TestFormatResultsContext(t *testing.T) {
|
||||
t.Run("empty results returns empty string", func(t *testing.T) {
|
||||
result := FormatResultsContext(nil)
|
||||
assert.Equal(t, "", result)
|
||||
})
|
||||
|
||||
t.Run("empty slice returns empty string", func(t *testing.T) {
|
||||
result := FormatResultsContext([]QueryResult{})
|
||||
assert.Equal(t, "", result)
|
||||
})
|
||||
|
||||
t.Run("wraps output in retrieved_context tags", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{Text: "Hello world.", Source: "test.md", Section: "Intro", Category: "docs", Score: 0.9},
|
||||
}
|
||||
|
||||
output := FormatResultsContext(results)
|
||||
|
||||
assert.True(t, strings.HasPrefix(output, "<retrieved_context>\n"),
|
||||
"output should start with <retrieved_context> tag")
|
||||
assert.True(t, strings.HasSuffix(output, "</retrieved_context>"),
|
||||
"output should end with </retrieved_context> tag")
|
||||
})
|
||||
|
||||
t.Run("wraps each result in document tags with attributes", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
Text: "Content here.",
|
||||
Source: "file.md",
|
||||
Section: "My Section",
|
||||
Category: "documentation",
|
||||
Score: 0.88,
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatResultsContext(results)
|
||||
|
||||
assert.Contains(t, output, `source="file.md"`)
|
||||
assert.Contains(t, output, `section="My Section"`)
|
||||
assert.Contains(t, output, `category="documentation"`)
|
||||
assert.Contains(t, output, "Content here.")
|
||||
assert.Contains(t, output, "</document>")
|
||||
})
|
||||
|
||||
t.Run("escapes XML special characters in attributes and text", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
Text: `Text with <tags> & "quotes" in it.`,
|
||||
Source: `path/with<special>&chars.md`,
|
||||
Section: `Section "One"`,
|
||||
Category: "docs",
|
||||
Score: 0.75,
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatResultsContext(results)
|
||||
|
||||
// The html.EscapeString function escapes <, >, &, " and '
|
||||
assert.Contains(t, output, "<tags>")
|
||||
assert.Contains(t, output, "&")
|
||||
assert.Contains(t, output, ""quotes"")
|
||||
// Source attribute should also be escaped
|
||||
assert.Contains(t, output, "path/with<special>&chars.md")
|
||||
})
|
||||
|
||||
t.Run("multiple results each wrapped in document tags", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{Text: "First.", Source: "a.md", Section: "", Category: "docs", Score: 0.9},
|
||||
{Text: "Second.", Source: "b.md", Section: "", Category: "docs", Score: 0.8},
|
||||
}
|
||||
|
||||
output := FormatResultsContext(results)
|
||||
|
||||
// Count document tags
|
||||
assert.Equal(t, 2, strings.Count(output, "<document "))
|
||||
assert.Equal(t, 2, strings.Count(output, "</document>"))
|
||||
})
|
||||
}
|
||||
|
||||
// --- FormatResultsJSON tests ---
|
||||
|
||||
func TestFormatResultsJSON(t *testing.T) {
|
||||
t.Run("empty results returns empty JSON array", func(t *testing.T) {
|
||||
result := FormatResultsJSON(nil)
|
||||
assert.Equal(t, "[]", result)
|
||||
})
|
||||
|
||||
t.Run("empty slice returns empty JSON array", func(t *testing.T) {
|
||||
result := FormatResultsJSON([]QueryResult{})
|
||||
assert.Equal(t, "[]", result)
|
||||
})
|
||||
|
||||
t.Run("single result produces valid JSON", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
Text: "Test content.",
|
||||
Source: "test.md",
|
||||
Section: "Intro",
|
||||
Category: "docs",
|
||||
Score: 0.9234,
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatResultsJSON(results)
|
||||
|
||||
// Verify it parses as valid JSON
|
||||
var parsed []map[string]any
|
||||
err := json.Unmarshal([]byte(output), &parsed)
|
||||
require.NoError(t, err, "output should be valid JSON")
|
||||
require.Len(t, parsed, 1)
|
||||
|
||||
assert.Equal(t, "test.md", parsed[0]["source"])
|
||||
assert.Equal(t, "Intro", parsed[0]["section"])
|
||||
assert.Equal(t, "docs", parsed[0]["category"])
|
||||
assert.Equal(t, "Test content.", parsed[0]["text"])
|
||||
// Score is formatted to 4 decimal places
|
||||
assert.InDelta(t, 0.9234, parsed[0]["score"], 0.0001)
|
||||
})
|
||||
|
||||
t.Run("multiple results produce valid JSON array", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{Text: "First.", Source: "a.md", Section: "A", Category: "docs", Score: 0.95},
|
||||
{Text: "Second.", Source: "b.md", Section: "B", Category: "code", Score: 0.80},
|
||||
{Text: "Third.", Source: "c.md", Section: "C", Category: "task", Score: 0.70},
|
||||
}
|
||||
|
||||
output := FormatResultsJSON(results)
|
||||
|
||||
var parsed []map[string]any
|
||||
err := json.Unmarshal([]byte(output), &parsed)
|
||||
require.NoError(t, err, "output should be valid JSON")
|
||||
require.Len(t, parsed, 3)
|
||||
|
||||
assert.Equal(t, "First.", parsed[0]["text"])
|
||||
assert.Equal(t, "Second.", parsed[1]["text"])
|
||||
assert.Equal(t, "Third.", parsed[2]["text"])
|
||||
})
|
||||
|
||||
t.Run("special characters are JSON-escaped in text", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{
|
||||
Text: "Line one\nLine two\twith tab and \"quotes\"",
|
||||
Source: "special.md",
|
||||
Section: "",
|
||||
Category: "docs",
|
||||
Score: 0.5,
|
||||
},
|
||||
}
|
||||
|
||||
output := FormatResultsJSON(results)
|
||||
|
||||
var parsed []map[string]any
|
||||
err := json.Unmarshal([]byte(output), &parsed)
|
||||
require.NoError(t, err, "output should be valid JSON even with special characters")
|
||||
assert.Equal(t, "Line one\nLine two\twith tab and \"quotes\"", parsed[0]["text"])
|
||||
})
|
||||
|
||||
t.Run("score formatted to four decimal places", func(t *testing.T) {
|
||||
results := []QueryResult{
|
||||
{Text: "T.", Source: "s.md", Section: "", Category: "c", Score: 0.123456789},
|
||||
}
|
||||
|
||||
output := FormatResultsJSON(results)
|
||||
|
||||
assert.Contains(t, output, "0.1235")
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue