Add dual-backend support for KDTree with benchmarks and documentation updates
parent
5d1ee3f0ea
commit
3c83fc38e4
13 changed files with 1138 additions and 36 deletions
71
.github/workflows/ci.yml
vendored
|
|
@ -39,6 +39,16 @@ jobs:
|
|||
- name: CI checks (lint, tests, coverage, etc.)
|
||||
run: make ci
|
||||
|
||||
- name: Benchmarks (linear)
|
||||
run: set -o pipefail; go test -bench . -benchmem -run=^$ ./... | tee bench-linear.txt
|
||||
|
||||
- name: Upload benchmarks (linear)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: bench-linear
|
||||
path: bench-linear.txt
|
||||
if-no-files-found: error
|
||||
|
||||
- name: Build WebAssembly module
|
||||
run: make wasm-build
|
||||
|
||||
|
|
@ -72,3 +82,64 @@ jobs:
|
|||
name: npm-poindexter-wasm-tarball
|
||||
if-no-files-found: error
|
||||
path: ${{ steps.npm_pack.outputs.tarball }}
|
||||
|
||||
build-test-gonum:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.23.x'
|
||||
|
||||
- name: Install extra tools
|
||||
run: |
|
||||
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
||||
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||
|
||||
- name: Go env
|
||||
run: go env
|
||||
|
||||
- name: Lint
|
||||
run: golangci-lint run
|
||||
|
||||
- name: Build (gonum tag)
|
||||
run: go build -tags=gonum ./...
|
||||
|
||||
- name: Unit tests + race + coverage (gonum tag)
|
||||
run: go test -tags=gonum -race -coverpkg=./... -coverprofile=coverage-gonum.out -covermode=atomic ./...
|
||||
|
||||
- name: Fuzz (10s per fuzz test, gonum tag)
|
||||
run: |
|
||||
set -e
|
||||
for pkg in $(go list ./...); do
|
||||
FUZZES=$(go test -tags=gonum -list '^Fuzz' "$pkg" | grep '^Fuzz' || true)
|
||||
if [ -z "$FUZZES" ]; then
|
||||
echo "==> Skipping $pkg (no fuzz targets)"
|
||||
continue
|
||||
fi
|
||||
for fz in $FUZZES; do
|
||||
echo "==> Fuzzing $pkg :: $fz for 10s"
|
||||
go test -tags=gonum -run=NONE -fuzz=^${fz}$ -fuzztime=10s "$pkg"
|
||||
done
|
||||
done
|
||||
|
||||
- name: Benchmarks (gonum tag)
|
||||
run: set -o pipefail; go test -tags=gonum -bench . -benchmem -run=^$ ./... | tee bench-gonum.txt
|
||||
|
||||
- name: Upload coverage (gonum)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-gonum
|
||||
path: coverage-gonum.out
|
||||
if-no-files-found: error
|
||||
|
||||
- name: Upload benchmarks (gonum)
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: bench-gonum
|
||||
path: bench-gonum.txt
|
||||
if-no-files-found: error
|
||||
|
|
|
|||
|
|
@ -6,6 +6,12 @@ The format is based on Keep a Changelog and this project adheres to Semantic Ver
|
|||
|
||||
## [Unreleased]
|
||||
### Added
|
||||
- Dual-backend benchmarks (Linear vs Gonum) with deterministic datasets (uniform/clustered) in 2D/4D for N=1k/10k; artifacts uploaded in CI as `bench-linear.txt` and `bench-gonum.txt`.
|
||||
- Documentation: Performance guide updated to cover backend selection, how to run both backends, CI artifact links, and guidance on when each backend is preferred.
|
||||
- Documentation: Performance guide now includes a Sample results table sourced from a recent local run.
|
||||
- Documentation: README gained a “Backend selection” section with default behavior, build tag usage, overrides, and supported metrics notes.
|
||||
- Documentation: API reference (`docs/api.md`) now documents `KDBackend`, `WithBackend`, default selection, and supported metrics for the optimized backend.
|
||||
- Examples: Added `examples/wasm-browser/` minimal browser demo (ESM + HTML) for the WASM build.
|
||||
- pkg.go.dev Examples: `ExampleNewKDTreeFromDim_Insert`, `ExampleKDTree_TiesBehavior`, `ExampleKDTree_Radius_none`.
|
||||
- Lint: enable `errcheck` in `.golangci.yml` with test-file exclusion to reduce noise.
|
||||
- CI: enable module cache in `actions/setup-go` to speed up workflows.
|
||||
|
|
|
|||
20
README.md
|
|
@ -71,14 +71,32 @@ Explore runnable examples in the repository:
|
|||
- examples/kdtree_2d_ping_hop
|
||||
- examples/kdtree_3d_ping_hop_geo
|
||||
- examples/kdtree_4d_ping_hop_geo_score
|
||||
- examples/wasm-browser (browser demo using the ESM loader)
|
||||
|
||||
### KDTree performance and notes
|
||||
- Current KDTree queries are O(n) linear scans, which are great for small-to-medium datasets or low-latency prototyping. For 1e5+ points and low/medium dimensions, consider swapping the internal engine to `gonum.org/v1/gonum/spatial/kdtree` (the API here is compatible by design).
|
||||
- Dual backend support: Linear (always available) and an optimized KD backend enabled when building with `-tags=gonum`. Linear is the default; with the `gonum` tag, the optimized backend becomes the default.
|
||||
- Complexity: Linear backend is O(n) per query. Optimized KD backend is typically sub-linear on prunable datasets and dims ≤ ~8, especially as N grows (≥10k–100k).
|
||||
- Insert is O(1) amortized; delete by ID is O(1) via swap-delete; order is not preserved.
|
||||
- Concurrency: the KDTree type is not safe for concurrent mutation. Protect with a mutex or share immutable snapshots for read-mostly workloads.
|
||||
- See multi-dimensional examples (ping/hops/geo/score) in docs and `examples/`.
|
||||
- Performance guide: see [docs/perf.md](docs/perf.md) for benchmark guidance and tips (hosted: https://snider.github.io/Poindexter/perf/).
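
The O(1) delete-by-ID noted in the list above relies on swap-delete. The following is a minimal sketch of that technique, not the library's actual code: the `store` and `point` types and their methods are hypothetical, loosely modeled on a points slice paired with an ID-to-index map.

```go
package main

import "fmt"

// point is a hypothetical stand-in for a KD point with a string ID.
type point struct {
	ID     string
	Coords []float64
}

// store keeps points in a slice plus an ID-to-index map (illustrative names).
type store struct {
	points  []point
	idIndex map[string]int
}

func newStore() *store { return &store{idIndex: make(map[string]int)} }

// insert appends the point (amortized O(1)). It reports false on a duplicate ID.
func (s *store) insert(p point) bool {
	if _, dup := s.idIndex[p.ID]; dup {
		return false
	}
	s.idIndex[p.ID] = len(s.points)
	s.points = append(s.points, p)
	return true
}

// deleteByID moves the last point into the vacated slot and shrinks the slice.
// This is O(1) but does not preserve insertion order.
func (s *store) deleteByID(id string) bool {
	i, ok := s.idIndex[id]
	if !ok {
		return false
	}
	last := len(s.points) - 1
	if i != last {
		s.points[i] = s.points[last]
		s.idIndex[s.points[i].ID] = i
	}
	s.points = s.points[:last]
	delete(s.idIndex, id)
	return true
}

func main() {
	s := newStore()
	for _, id := range []string{"a", "b", "c"} {
		s.insert(point{ID: id})
	}
	s.deleteByID("a")
	fmt.Println(s.points[0].ID, s.points[1].ID) // prints "c b": order changed
}
```

Because the last element is moved into the gap, callers that depend on insertion order need a separate ordering key.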
|
||||
|
||||
### Backend selection
|
||||
- Default backend is Linear. If you build with `-tags=gonum`, the default becomes the optimized KD backend.
|
||||
- You can override per tree at construction:
|
||||
|
||||
```go
|
||||
// Force Linear (always available)
|
||||
kdt1, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
|
||||
|
||||
// Force Gonum (requires build tag)
|
||||
kdt2, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
|
||||
```
|
||||
|
||||
- Supported metrics in the optimized backend: Euclidean (L2), Manhattan (L1), Chebyshev (L∞).
|
||||
- Cosine and Weighted-Cosine currently run on the Linear backend.
|
||||
- See the Performance guide for measured comparisons and when to choose which backend.
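
The selection order described in this section (build-dependent default, then per-tree override, then fallback to Linear) can be sketched with functional options. This is an illustrative sketch only: `resolveBackend`, `withBackend`, and the `gonumAvailable` variable are hypothetical stand-ins, and the real library decides availability through a build tag rather than a variable.

```go
package main

import "fmt"

type backend string

const (
	backendLinear backend = "linear"
	backendGonum  backend = "gonum"
)

// gonumAvailable stands in for a build-tag-dependent check (assumption:
// hard-coded to false here, as if built without the tag).
var gonumAvailable = false

type options struct{ backend backend }

type option func(*options)

// withBackend records the caller's requested backend.
func withBackend(b backend) option { return func(o *options) { o.backend = b } }

// resolveBackend applies the default, then any overrides, and finally falls
// back to linear when the optimized backend is requested but not available.
func resolveBackend(opts ...option) backend {
	cfg := options{backend: backendLinear}
	if gonumAvailable {
		cfg.backend = backendGonum
	}
	for _, o := range opts {
		o(&cfg)
	}
	if cfg.backend == backendGonum && !gonumAvailable {
		return backendLinear
	}
	return cfg.backend
}

func main() {
	fmt.Println(resolveBackend())                          // linear (default without the tag)
	fmt.Println(resolveBackend(withBackend(backendGonum))) // linear (silent fallback)
}
```

A silent fallback keeps call sites portable across builds, at the cost that callers cannot tell from the constructor alone which backend is active.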
|
||||
|
||||
#### Choosing a metric (quick tips)
|
||||
- Euclidean (L2): smooth trade-offs across axes; solid default for blended preferences.
|
||||
- Manhattan (L1): emphasizes per-axis absolute differences; good when each unit of ping/hop matters equally.
|
||||
|
|
|
|||
180
bench_kdtree_dual_test.go
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
package poindexter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// deterministicRand returns a rand.Rand with a fixed seed for reproducible datasets.
|
||||
func deterministicRand() *rand.Rand { return rand.New(rand.NewSource(42)) }
|
||||
|
||||
func makeUniformPoints(n, dim int) []KDPoint[int] {
|
||||
r := deterministicRand()
|
||||
pts := make([]KDPoint[int], n)
|
||||
for i := 0; i < n; i++ {
|
||||
coords := make([]float64, dim)
|
||||
for d := 0; d < dim; d++ {
|
||||
coords[d] = r.Float64()
|
||||
}
|
||||
pts[i] = KDPoint[int]{ID: fmt.Sprint(i), Coords: coords, Value: i}
|
||||
}
|
||||
return pts
|
||||
}
|
||||
|
||||
// makeClusteredPoints creates n points around c clusters with small variance.
|
||||
func makeClusteredPoints(n, dim, c int) []KDPoint[int] {
|
||||
if c <= 0 {
|
||||
c = 1
|
||||
}
|
||||
r := deterministicRand()
|
||||
centers := make([][]float64, c)
|
||||
for i := 0; i < c; i++ {
|
||||
centers[i] = make([]float64, dim)
|
||||
for d := 0; d < dim; d++ {
|
||||
centers[i][d] = r.Float64()
|
||||
}
|
||||
}
|
||||
pts := make([]KDPoint[int], n)
|
||||
for i := 0; i < n; i++ {
|
||||
coords := make([]float64, dim)
|
||||
cent := centers[r.Intn(c)]
|
||||
for d := 0; d < dim; d++ {
|
||||
// small Gaussian noise around the center, drawn from the seeded source
z := r.NormFloat64() // using r (not the global source) keeps the dataset reproducible
|
||||
coords[d] = cent[d] + 0.03*z
|
||||
if coords[d] < 0 {
|
||||
coords[d] = 0
|
||||
} else if coords[d] > 1 {
|
||||
coords[d] = 1
|
||||
}
|
||||
}
|
||||
pts[i] = KDPoint[int]{ID: fmt.Sprint(i), Coords: coords, Value: i}
|
||||
}
|
||||
return pts
|
||||
}
|
||||
|
||||
func benchNearestBackend(b *testing.B, n, dim int, backend KDBackend, uniform bool, clusters int) {
|
||||
var pts []KDPoint[int]
|
||||
if uniform {
|
||||
pts = makeUniformPoints(n, dim)
|
||||
} else {
|
||||
pts = makeClusteredPoints(n, dim, clusters)
|
||||
}
|
||||
tr, err := NewKDTree(pts, WithBackend(backend))
if err != nil {
b.Fatalf("NewKDTree: %v", err) // fail fast instead of benchmarking a nil tree
}
|
||||
q := make([]float64, dim)
|
||||
for i := range q {
|
||||
q[i] = 0.5
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _, _ = tr.Nearest(q)
|
||||
}
|
||||
}
|
||||
|
||||
func benchKNNBackend(b *testing.B, n, dim, k int, backend KDBackend, uniform bool, clusters int) {
|
||||
var pts []KDPoint[int]
|
||||
if uniform {
|
||||
pts = makeUniformPoints(n, dim)
|
||||
} else {
|
||||
pts = makeClusteredPoints(n, dim, clusters)
|
||||
}
|
||||
tr, err := NewKDTree(pts, WithBackend(backend))
if err != nil {
b.Fatalf("NewKDTree: %v", err) // fail fast instead of benchmarking a nil tree
}
|
||||
q := make([]float64, dim)
|
||||
for i := range q {
|
||||
q[i] = 0.5
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = tr.KNearest(q, k)
|
||||
}
|
||||
}
|
||||
|
||||
func benchRadiusBackend(b *testing.B, n, dim int, r float64, backend KDBackend, uniform bool, clusters int) {
|
||||
var pts []KDPoint[int]
|
||||
if uniform {
|
||||
pts = makeUniformPoints(n, dim)
|
||||
} else {
|
||||
pts = makeClusteredPoints(n, dim, clusters)
|
||||
}
|
||||
tr, err := NewKDTree(pts, WithBackend(backend))
if err != nil {
b.Fatalf("NewKDTree: %v", err) // fail fast instead of benchmarking a nil tree
}
|
||||
q := make([]float64, dim)
|
||||
for i := range q {
|
||||
q[i] = 0.5
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, _ = tr.Radius(q, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Uniform 2D/4D, Linear vs Gonum (opt-in via build tag; falls back to linear if not available)
|
||||
func BenchmarkNearest_Linear_Uniform_1k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 1_000, 2, BackendLinear, true, 0)
|
||||
}
|
||||
func BenchmarkNearest_Gonum_Uniform_1k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 1_000, 2, BackendGonum, true, 0)
|
||||
}
|
||||
func BenchmarkNearest_Linear_Uniform_10k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 10_000, 2, BackendLinear, true, 0)
|
||||
}
|
||||
func BenchmarkNearest_Gonum_Uniform_10k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 10_000, 2, BackendGonum, true, 0)
|
||||
}
|
||||
|
||||
func BenchmarkNearest_Linear_Uniform_1k_4D(b *testing.B) {
|
||||
benchNearestBackend(b, 1_000, 4, BackendLinear, true, 0)
|
||||
}
|
||||
func BenchmarkNearest_Gonum_Uniform_1k_4D(b *testing.B) {
|
||||
benchNearestBackend(b, 1_000, 4, BackendGonum, true, 0)
|
||||
}
|
||||
func BenchmarkNearest_Linear_Uniform_10k_4D(b *testing.B) {
|
||||
benchNearestBackend(b, 10_000, 4, BackendLinear, true, 0)
|
||||
}
|
||||
func BenchmarkNearest_Gonum_Uniform_10k_4D(b *testing.B) {
|
||||
benchNearestBackend(b, 10_000, 4, BackendGonum, true, 0)
|
||||
}
|
||||
|
||||
// Clustered 2D/4D (3 clusters)
|
||||
func BenchmarkNearest_Linear_Clustered_1k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 1_000, 2, BackendLinear, false, 3)
|
||||
}
|
||||
func BenchmarkNearest_Gonum_Clustered_1k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 1_000, 2, BackendGonum, false, 3)
|
||||
}
|
||||
func BenchmarkNearest_Linear_Clustered_10k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 10_000, 2, BackendLinear, false, 3)
|
||||
}
|
||||
func BenchmarkNearest_Gonum_Clustered_10k_2D(b *testing.B) {
|
||||
benchNearestBackend(b, 10_000, 2, BackendGonum, false, 3)
|
||||
}
|
||||
|
||||
func BenchmarkKNN10_Linear_Uniform_10k_2D(b *testing.B) {
|
||||
benchKNNBackend(b, 10_000, 2, 10, BackendLinear, true, 0)
|
||||
}
|
||||
func BenchmarkKNN10_Gonum_Uniform_10k_2D(b *testing.B) {
|
||||
benchKNNBackend(b, 10_000, 2, 10, BackendGonum, true, 0)
|
||||
}
|
||||
func BenchmarkKNN10_Linear_Clustered_10k_2D(b *testing.B) {
|
||||
benchKNNBackend(b, 10_000, 2, 10, BackendLinear, false, 3)
|
||||
}
|
||||
func BenchmarkKNN10_Gonum_Clustered_10k_2D(b *testing.B) {
|
||||
benchKNNBackend(b, 10_000, 2, 10, BackendGonum, false, 3)
|
||||
}
|
||||
|
||||
func BenchmarkRadiusMid_Linear_Uniform_10k_2D(b *testing.B) {
|
||||
benchRadiusBackend(b, 10_000, 2, 0.5, BackendLinear, true, 0)
|
||||
}
|
||||
func BenchmarkRadiusMid_Gonum_Uniform_10k_2D(b *testing.B) {
|
||||
benchRadiusBackend(b, 10_000, 2, 0.5, BackendGonum, true, 0)
|
||||
}
|
||||
func BenchmarkRadiusMid_Linear_Clustered_10k_2D(b *testing.B) {
|
||||
benchRadiusBackend(b, 10_000, 2, 0.5, BackendLinear, false, 3)
|
||||
}
|
||||
func BenchmarkRadiusMid_Gonum_Clustered_10k_2D(b *testing.B) {
|
||||
benchRadiusBackend(b, 10_000, 2, 0.5, BackendGonum, false, 3)
|
||||
}
|
||||
57
docs/api.md
|
|
@ -494,3 +494,60 @@ Notes:
|
|||
- If `min==max` for an axis, normalized value is `0` for that axis.
|
||||
- `invert[i]` flips the normalized axis as `1 - n` before applying `weights[i]`.
|
||||
- These helpers mirror `Build2D/3D/4D`, but use your provided `NormStats` instead of recomputing from the items slice.
|
||||
|
||||
|
||||
---
|
||||
|
||||
## KDTree Backend selection
|
||||
|
||||
Poindexter provides two internal backends for KDTree queries:
|
||||
|
||||
- `linear`: always available; performs O(n) scans for `Nearest`, `KNearest`, and `Radius`.
|
||||
- `gonum`: optimized KD backend compiled when you build with the `gonum` build tag; typically sub-linear on prunable datasets and modest dimensions.
|
||||
|
||||
### Types and options
|
||||
|
||||
```go
|
||||
// KDBackend selects the internal engine used by KDTree.
|
||||
type KDBackend string
|
||||
|
||||
const (
|
||||
BackendLinear KDBackend = "linear"
|
||||
BackendGonum KDBackend = "gonum"
|
||||
)
|
||||
|
||||
// WithBackend selects the internal KDTree backend ("linear" or "gonum").
|
||||
// If the requested backend is unavailable (e.g., missing build tag), the constructor
|
||||
// falls back to the linear backend.
|
||||
func WithBackend(b KDBackend) KDOption
|
||||
```
|
||||
|
||||
### Default selection
|
||||
|
||||
- Default is `linear`.
|
||||
- If you build your project with `-tags=gonum`, the default becomes `gonum`.
|
||||
|
||||
### Usage examples
|
||||
|
||||
```go
|
||||
// Default metric is Euclidean; you can override with WithMetric.
|
||||
pts := []poindexter.KDPoint[string]{
|
||||
{ID: "A", Coords: []float64{0, 0}},
|
||||
{ID: "B", Coords: []float64{1, 0}},
|
||||
}
|
||||
|
||||
// Force Linear (always available)
|
||||
lin, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
|
||||
_, _, _ = lin.Nearest([]float64{0.9, 0.1})
|
||||
|
||||
// Force Gonum (requires building with: go build -tags=gonum)
|
||||
gon, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
|
||||
_, _, _ = gon.Nearest([]float64{0.9, 0.1})
|
||||
```
|
||||
|
||||
### Supported metrics in the optimized backend
|
||||
|
||||
- Euclidean (L2), Manhattan (L1), Chebyshev (L∞).
|
||||
- Cosine and Weighted-Cosine currently use the Linear backend.
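
For reference, the three metrics supported by the optimized backend can be written out as plain functions. This is a standalone sketch of the standard definitions, not the library's `DistanceMetric` implementations; the function names are hypothetical and both slices are assumed to have the same length.

```go
package main

import (
	"fmt"
	"math"
)

// euclidean returns the L2 distance: sqrt of the sum of squared differences.
func euclidean(a, b []float64) float64 {
	var sum float64
	for i := range a {
		d := a[i] - b[i]
		sum += d * d
	}
	return math.Sqrt(sum)
}

// manhattan returns the L1 distance: the sum of absolute differences.
func manhattan(a, b []float64) float64 {
	var sum float64
	for i := range a {
		sum += math.Abs(a[i] - b[i])
	}
	return sum
}

// chebyshev returns the L-infinity distance: the largest absolute difference.
func chebyshev(a, b []float64) float64 {
	var max float64
	for i := range a {
		if d := math.Abs(a[i] - b[i]); d > max {
			max = d
		}
	}
	return max
}

func main() {
	a, b := []float64{0, 0}, []float64{3, 4}
	fmt.Println(euclidean(a, b), manhattan(a, b), chebyshev(a, b)) // 5 7 4
}
```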
|
||||
|
||||
See also the Performance guide for measured comparisons and guidance: `docs/perf.md`.
|
||||
120
docs/perf.md
|
|
@ -1,51 +1,133 @@
|
|||
# Performance: KDTree benchmarks and guidance
|
||||
|
||||
This page summarizes how to measure KDTree performance in this repository and when to consider switching the internal engine to `gonum.org/v1/gonum/spatial/kdtree` for large datasets.
|
||||
This page summarizes how to measure KDTree performance in this repository and how to compare the two internal backends (Linear vs Gonum) that you can select at build/runtime.
|
||||
|
||||
## How benchmarks are organized
|
||||
|
||||
- Micro-benchmarks live in `bench_kdtree_test.go` and cover:
|
||||
- Micro-benchmarks live in `bench_kdtree_test.go` and `bench_kdtree_dual_test.go` and cover:
|
||||
- `Nearest` in 2D and 4D with N = 1k, 10k
|
||||
- `KNearest(k=10)` in 2D with N = 1k, 10k
|
||||
- `Radius` (mid radius) in 2D with N = 1k, 10k
|
||||
- All benchmarks operate in normalized [0,1] spaces and use the current linear-scan implementation.
|
||||
- `KNearest(k=10)` in 2D with N = 1k, 10k (dual-backend variants at N = 10k)
- `Radius` (mid radius r≈0.5 after normalization) in 2D with N = 1k, 10k (dual-backend variants at N = 10k)
|
||||
- Datasets: Uniform and 3-cluster synthetic generators in normalized [0,1] spaces.
|
||||
- Backends: Linear (always available) and Gonum (enabled when built with `-tags=gonum`).
|
||||
|
||||
Run them locally:
|
||||
|
||||
```bash
|
||||
# Linear backend (default)
|
||||
go test -bench . -benchmem -run=^$ ./...
|
||||
|
||||
# Gonum backend (optimized KD; requires build tag)
|
||||
go test -tags=gonum -bench . -benchmem -run=^$ ./...
|
||||
```
|
||||
|
||||
GitHub Actions publishes benchmark artifacts for Go 1.23 on every push/PR. Look for artifacts named `bench-<go-version>.txt` in the CI run.
|
||||
GitHub Actions publishes benchmark artifacts on every push/PR:
|
||||
- Linear job: artifact `bench-linear.txt`
|
||||
- Gonum job: artifact `bench-gonum.txt`
|
||||
|
||||
## Backend selection and defaults
|
||||
|
||||
- Default backend is Linear.
|
||||
- If you build with `-tags=gonum`, the default switches to the optimized KD backend.
|
||||
- You can override at runtime:
|
||||
|
||||
```go
// Force Linear
kdtLinear, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
// Force Gonum (requires build tag)
kdtGonum, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
```
|
||||
|
||||
Supported metrics in the optimized backend: L2 (Euclidean), L1 (Manhattan), L∞ (Chebyshev). Cosine/Weighted-Cosine currently use the Linear backend.
|
||||
|
||||
## What to expect (rule of thumb)
|
||||
|
||||
- Time complexity is O(n) per query in the current implementation.
|
||||
- For small-to-medium datasets (up to ~10k points), linear scans are often fast enough, especially for low dimensionality (≤4) and if queries are batched efficiently.
|
||||
- For larger datasets (≥100k) and low/medium dimensions (≤8), a true KD-tree (like Gonum’s) often yields sub-linear queries and significantly lower latency.
|
||||
- Linear backend: O(n) per query; fast for small-to-medium datasets (≤10k), especially in low dims (≤4).
|
||||
- Gonum backend: typically sub-linear for prunable datasets and dims ≤ ~8, with noticeable gains as N grows (≥10k–100k), especially on uniform or moderately clustered data and moderate radii.
|
||||
- For large radii (many points within r) or highly correlated/pathological data, pruning may be less effective and behavior approaches O(n) even with KD-trees.
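
To make the large-radius point concrete: every hit has to be visited and returned regardless of backend, so a radius that captures most of the dataset leaves little to prune. Below is a minimal linear radius scan, assuming Euclidean distance; `radiusScan` is a hypothetical helper, not the library's `Radius` method.

```go
package main

import (
	"fmt"
	"math"
)

// radiusScan returns the indices of all points within Euclidean distance r of q.
// It visits every point once, so the cost is O(n) regardless of r.
func radiusScan(points [][]float64, q []float64, r float64) []int {
	var hits []int
	for i, p := range points {
		var sum float64
		for d := range q {
			diff := p[d] - q[d]
			sum += diff * diff
		}
		if math.Sqrt(sum) <= r {
			hits = append(hits, i)
		}
	}
	return hits
}

func main() {
	// Four corners of the unit square; each is about 0.707 from the center.
	pts := [][]float64{{0, 0}, {1, 0}, {0, 1}, {1, 1}}
	q := []float64{0.5, 0.5}
	fmt.Println(len(radiusScan(pts, q, 0.5))) // 0: no corner is within 0.5
	fmt.Println(len(radiusScan(pts, q, 0.8))) // 4: every point is a hit
}
```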
|
||||
|
||||
## Interpreting results
|
||||
|
||||
Benchmarks output something like:
|
||||
|
||||
```
|
||||
BenchmarkNearest_10k_4D-8                  50000   23000 ns/op   0 B/op   0 allocs/op
BenchmarkNearest_Gonum_Uniform_10k_4D-8    50000   12300 ns/op   0 B/op   0 allocs/op
|
||||
```
|
||||
|
||||
- `ns/op`: lower is better (nanoseconds per operation)
|
||||
- `B/op` and `allocs/op`: memory behavior; fewer is better
|
||||
|
||||
Because `KNearest` sorts by distance, you should expect additional cost over `Nearest`. `Radius` cost depends on how many points fall within the radius; tighter radii usually run faster.
|
||||
- `KNearest` incurs extra work due to sorting; `Radius` cost scales with the number of hits.
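
When comparing the two backends it can help to pull `ns/op` out of the benchmark text programmatically. The sketch below parses `go test -bench` output and prints the Linear/Gonum ratio for paired benchmark names. The helper name `parseNsPerOp` and the pairing by `_Linear_`/`_Gonum_` in the name are assumptions based on the benchmark naming used in this repository; for statistically sound comparisons a dedicated tool such as Benchstat is the better choice.

```go
package main

import (
	"bufio"
	"fmt"
	"strconv"
	"strings"
)

// parseNsPerOp maps each benchmark name (including its -N suffix) to its ns/op value.
// Lines that are not benchmark results are ignored.
func parseNsPerOp(output string) map[string]float64 {
	res := make(map[string]float64)
	sc := bufio.NewScanner(strings.NewReader(output))
	for sc.Scan() {
		f := strings.Fields(sc.Text())
		if len(f) < 4 || !strings.HasPrefix(f[0], "Benchmark") {
			continue
		}
		for i := 2; i < len(f); i++ {
			if f[i] == "ns/op" {
				if v, err := strconv.ParseFloat(f[i-1], 64); err == nil {
					res[f[0]] = v
				}
				break
			}
		}
	}
	return res
}

func main() {
	out := `BenchmarkNearest_Linear_Uniform_10k_2D-32   43053   27809 ns/op   0 B/op   0 allocs/op
BenchmarkNearest_Gonum_Uniform_10k_2D-32    42996   27936 ns/op   0 B/op   0 allocs/op`
	ns := parseNsPerOp(out)
	for name, gonum := range ns {
		if !strings.Contains(name, "_Gonum_") {
			continue
		}
		// Look up the matching Linear benchmark by swapping the backend tag in the name.
		if linear, ok := ns[strings.Replace(name, "_Gonum_", "_Linear_", 1)]; ok {
			fmt.Printf("%s: linear/gonum = %.2f\n", name, linear/gonum)
		}
	}
}
```

A ratio near 1.0 means the two runs did essentially the same work; a ratio well above 1.0 means the Gonum run was faster.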
|
||||
|
||||
## Improving performance
|
||||
|
||||
- Prefer Euclidean (L2) over metrics that require extra branching for CPU pipelines, unless your policy prefers otherwise.
|
||||
- Normalize and weight features once; reuse coordinates across queries.
|
||||
- Batch queries to amortize overhead of data locality and caches.
|
||||
- Consider a backend swap to Gonum’s KD-tree for large N (we plan to add a `WithBackend("gonum")` option).
|
||||
- Normalize and weight features once; reuse across queries (see `Build*WithStats` helpers).
|
||||
- Choose a metric aligned with your policy: L2 usually a solid default; L1 for per-axis penalties; L∞ for hard-threshold dominated objectives.
|
||||
- Batch queries to benefit from CPU caches.
|
||||
- Prefer the Gonum backend for larger N and dims ≤ ~8; stick to Linear for tiny datasets or when using Cosine metrics.
|
||||
|
||||
## Reproducing and tracking performance
|
||||
|
||||
- Local: run `go test -bench . -benchmem -run=^$ ./...`
|
||||
- CI: download `bench-*.txt` artifacts from the latest workflow run
|
||||
- Optional: we can add historical trend graphs via Codecov or Benchstat integration if desired.
|
||||
- Local (Linear): `go test -bench . -benchmem -run=^$ ./...`
|
||||
- Local (Gonum): `go test -tags=gonum -bench . -benchmem -run=^$ ./...`
|
||||
- CI artifacts: download `bench-linear.txt` and `bench-gonum.txt` from the latest workflow run.
|
||||
- Optional: add historical trend graphs via Benchstat or Codecov integration.
|
||||
|
||||
## Sample results (from a recent local run)
|
||||
|
||||
Results vary by machine, Go version, and dataset seed. The following run was captured locally and is provided as a reference point.
|
||||
|
||||
- Machine: darwin/arm64, Apple M3 Ultra
|
||||
- Package: `github.com/Snider/Poindexter`
|
||||
- Command: `go test -bench . -benchmem -run=^$ ./... | tee bench.txt`
|
||||
|
||||
Full output:
|
||||
|
||||
```
|
||||
goos: darwin
|
||||
goarch: arm64
|
||||
pkg: github.com/Snider/Poindexter
|
||||
BenchmarkNearest_Linear_Uniform_1k_2D-32 409321 3001 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Gonum_Uniform_1k_2D-32 413823 2888 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Linear_Uniform_10k_2D-32 43053 27809 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Gonum_Uniform_10k_2D-32 42996 27936 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Linear_Uniform_1k_4D-32 326492 3746 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Gonum_Uniform_1k_4D-32 338983 3857 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Linear_Uniform_10k_4D-32 35661 32985 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Gonum_Uniform_10k_4D-32 35678 33388 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Linear_Clustered_1k_2D-32 425220 2874 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Gonum_Clustered_1k_2D-32 420080 2849 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Linear_Clustered_10k_2D-32 43242 27776 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_Gonum_Clustered_10k_2D-32 42392 27889 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkKNN10_Linear_Uniform_10k_2D-32 1206 977599 ns/op 164492 B/op 6 allocs/op
|
||||
BenchmarkKNN10_Gonum_Uniform_10k_2D-32 1239 972501 ns/op 164488 B/op 6 allocs/op
|
||||
BenchmarkKNN10_Linear_Clustered_10k_2D-32 1219 973242 ns/op 164492 B/op 6 allocs/op
|
||||
BenchmarkKNN10_Gonum_Clustered_10k_2D-32 1214 971017 ns/op 164488 B/op 6 allocs/op
|
||||
BenchmarkRadiusMid_Linear_Uniform_10k_2D-32 1279 917692 ns/op 947529 B/op 23 allocs/op
|
||||
BenchmarkRadiusMid_Gonum_Uniform_10k_2D-32 1299 918176 ns/op 947529 B/op 23 allocs/op
|
||||
BenchmarkRadiusMid_Linear_Clustered_10k_2D-32 1059 1123281 ns/op 1217866 B/op 24 allocs/op
|
||||
BenchmarkRadiusMid_Gonum_Clustered_10k_2D-32 1063 1149507 ns/op 1217871 B/op 24 allocs/op
|
||||
BenchmarkNearest_1k_2D-32 401595 2964 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_10k_2D-32 42129 28229 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_1k_4D-32 365626 3642 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkNearest_10k_4D-32 36298 33176 ns/op 0 B/op 0 allocs/op
|
||||
BenchmarkKNearest10_1k_2D-32 20348 59568 ns/op 17032 B/op 6 allocs/op
|
||||
BenchmarkKNearest10_10k_2D-32 1224 969093 ns/op 164488 B/op 6 allocs/op
|
||||
BenchmarkRadiusMid_1k_2D-32 21867 53273 ns/op 77512 B/op 16 allocs/op
|
||||
BenchmarkRadiusMid_10k_2D-32 1302 933791 ns/op 955720 B/op 23 allocs/op
|
||||
PASS
|
||||
ok github.com/Snider/Poindexter 40.102s
|
||||
PASS
|
||||
ok github.com/Snider/Poindexter/examples/dht_ping_1d 0.348s
|
||||
PASS
|
||||
ok github.com/Snider/Poindexter/examples/kdtree_2d_ping_hop 0.266s
|
||||
PASS
|
||||
ok github.com/Snider/Poindexter/examples/kdtree_3d_ping_hop_geo 0.272s
|
||||
PASS
|
||||
ok github.com/Snider/Poindexter/examples/kdtree_4d_ping_hop_geo_score 0.269s
|
||||
```
|
||||
|
||||
Notes:
|
||||
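- The first block shows the dual-backend benchmarks (Linear vs Gonum) on uniform and clustered datasets. The command above did not pass `-tags=gonum`, so the `_Gonum_` benchmarks fell back to the linear backend; that is why each Linear/Gonum pair reports nearly identical timings.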
|
||||
- The final block includes the legacy single-backend benchmarks for additional sizes; both are useful for comparison.
|
||||
|
||||
To compare against the optimized KD backend explicitly, build with `-tags=gonum` and/or download `bench-gonum.txt` from CI artifacts.
|
||||
|
|
|
|||
48
docs/wasm.md
|
|
@ -9,6 +9,36 @@ Poindexter ships a browser build compiled to WebAssembly along with a small JS l
|
|||
- `npm/poindexter-wasm/loader.js` — ESM loader that instantiates the WASM and exposes a friendly API
|
||||
- `npm/poindexter-wasm/index.d.ts` — TypeScript typings for the loader and KD‑Tree API
|
||||
|
||||
## Quick start
|
||||
|
||||
- Build artifacts and copy `wasm_exec.js`:
|
||||
|
||||
```bash
|
||||
make wasm-build
|
||||
```
|
||||
|
||||
- Prepare the npm package folder with `dist/` and docs:
|
||||
|
||||
```bash
|
||||
make npm-pack
|
||||
```
|
||||
|
||||
- Minimal browser ESM usage (serve the repository root statically so that both `/npm/` and `/dist/` resolve):
|
||||
|
||||
```html
|
||||
<script type="module">
|
||||
import { init } from '/npm/poindexter-wasm/loader.js';
|
||||
const px = await init({
|
||||
wasmURL: '/dist/poindexter.wasm',
|
||||
wasmExecURL: '/dist/wasm_exec.js',
|
||||
});
|
||||
const tree = await px.newTree(2);
|
||||
await tree.insert({ id: 'a', coords: [0, 0], value: 'A' });
|
||||
const nn = await tree.nearest([0.1, 0.2]);
|
||||
console.log(nn);
|
||||
</script>
|
||||
```
|
||||
|
||||
## Building locally
|
||||
|
||||
```bash
|
||||
|
|
@ -102,3 +132,21 @@ Our CI builds and uploads the following artifacts on each push/PR:
|
|||
- `npm-poindexter-wasm-tarball` — a `.tgz` created via `npm pack` for quick local install/testing
|
||||
|
||||
You can download these artifacts from the workflow run summary in GitHub Actions.
|
||||
|
||||
## Browser demo (checked into repo)
|
||||
|
||||
There is a tiny browser demo you can load locally from this repo:
|
||||
|
||||
- Path: `examples/wasm-browser/index.html`
|
||||
- Prerequisites: run `make wasm-build` so `dist/poindexter.wasm` and `dist/wasm_exec.js` exist.
|
||||
- Serve the repo root (so relative paths resolve), for example:
|
||||
|
||||
```bash
|
||||
python3 -m http.server -b 127.0.0.1 8000
|
||||
```
|
||||
|
||||
Then open:
|
||||
|
||||
- http://127.0.0.1:8000/examples/wasm-browser/
|
||||
|
||||
Open the browser console to see outputs from `nearest`, `kNearest`, and `radius` queries.
|
||||
|
|
|
|||
60
examples/wasm-browser/index.html
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Poindexter WASM Browser Demo</title>
|
||||
<style>
|
||||
body { font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif; margin: 2rem; }
|
||||
pre { background: #f6f8fa; padding: 1rem; overflow-x: auto; }
|
||||
code { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Poindexter WASM Browser Demo</h1>
|
||||
<p>This demo uses the ESM loader to initialize the WebAssembly build and run a simple KDTree query entirely in your browser.</p>
|
||||
|
||||
<p>
|
||||
Serve this file from the repository root so the asset paths resolve. For example:
|
||||
</p>
|
||||
<pre><code>python3 -m http.server -b 127.0.0.1 8000</code></pre>
|
||||
<p>Then open <code>http://127.0.0.1:8000/examples/wasm-browser/</code> in your browser.</p>
|
||||
|
||||
<h2>Console output</h2>
|
||||
<p>Open DevTools console to inspect results.</p>
|
||||
|
||||
<script type="module">
|
||||
// Import the ESM loader from the npm package directory within this repo.
|
||||
// When serving from repo root, this path resolves to the local loader.
|
||||
import { init } from '../../npm/poindexter-wasm/loader.js';
|
||||
|
||||
async function main() {
|
||||
const px = await init({
|
||||
// Point to the built WASM artifacts in dist/. Ensure you ran `make wasm-build` first.
|
||||
wasmURL: '../../dist/poindexter.wasm',
|
||||
wasmExecURL: '../../dist/wasm_exec.js',
|
||||
});
|
||||
|
||||
console.log('Poindexter version (WASM):', await px.version());
|
||||
|
||||
const tree = await px.newTree(2);
|
||||
await tree.insert({ id: 'a', coords: [0, 0], value: 'A' });
|
||||
await tree.insert({ id: 'b', coords: [1, 0], value: 'B' });
|
||||
await tree.insert({ id: 'c', coords: [0, 1], value: 'C' });
|
||||
|
||||
const nearest = await tree.nearest([0.9, 0.1]);
|
||||
console.log('Nearest to [0.9,0.1]:', nearest);
|
||||
|
||||
const knn = await tree.kNearest([0.9, 0.9], 2);
|
||||
console.log('kNN (k=2) for [0.9,0.9]:', knn);
|
||||
|
||||
const within = await tree.radius([0, 0], 1.1);
|
||||
console.log('Within r=1.1 of [0,0]:', within);
|
||||
}
|
||||
|
||||
main().catch(err => {
|
||||
console.error('Demo error:', err);
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
103
kdtree.go
|
|
@ -15,6 +15,8 @@ var (
|
|||
ErrDimMismatch = errors.New("kdtree: inconsistent dimensionality in points")
|
||||
// ErrDuplicateID indicates a duplicate point ID was encountered.
|
||||
ErrDuplicateID = errors.New("kdtree: duplicate point ID")
|
||||
// ErrBackendUnavailable indicates that a requested backend cannot be used (e.g., not built/tagged).
|
||||
ErrBackendUnavailable = errors.New("kdtree: requested backend unavailable")
|
||||
)
|
||||
|
||||
// KDPoint represents a point with coordinates and an attached payload/value.
|
||||
|
|
@ -171,11 +173,35 @@ type KDOption func(*kdOptions)
|
|||
|
||||
type kdOptions struct {
|
||||
metric DistanceMetric
|
||||
backend KDBackend
|
||||
}
|
||||
|
||||
// defaultBackend returns the implicit backend depending on build tags.
|
||||
// If built with the "gonum" tag, prefer the Gonum backend by default to keep
|
||||
// code paths simple and performant; otherwise fall back to the linear backend.
|
||||
func defaultBackend() KDBackend {
|
||||
if hasGonum() {
|
||||
return BackendGonum
|
||||
}
|
||||
return BackendLinear
|
||||
}
|
||||
|
||||
// KDBackend selects the internal engine used by KDTree.
|
||||
type KDBackend string
|
||||
|
||||
const (
|
||||
BackendLinear KDBackend = "linear"
|
||||
BackendGonum KDBackend = "gonum"
|
||||
)
|
||||
|
||||
// WithMetric sets the distance metric for the KDTree.
|
||||
func WithMetric(m DistanceMetric) KDOption { return func(o *kdOptions) { o.metric = m } }
|
||||
|
||||
// WithBackend selects the internal KDTree backend ("linear" or "gonum").
|
||||
// The default is linear, or gonum when built with the "gonum" tag. If the requested backend
// is unavailable (e.g., gonum build tag not enabled), the constructor silently falls back to linear.
|
||||
func WithBackend(b KDBackend) KDOption { return func(o *kdOptions) { o.backend = b } }
|
||||
|
||||
// KDTree is a lightweight wrapper providing nearest-neighbor operations.
|
||||
//
|
||||
// Complexity: queries are O(n) linear scans in the current implementation.
|
||||
|
|
@ -190,6 +216,8 @@ type KDTree[T any] struct {
|
|||
dim int
|
||||
metric DistanceMetric
|
||||
idIndex map[string]int
|
||||
backend KDBackend
|
||||
backendData any // opaque handle for backend-specific structures (e.g., gonum tree)
|
||||
}
|
||||
|
||||
// NewKDTree builds a KDTree from the given points.
|
||||
|
|
@ -214,15 +242,29 @@ func NewKDTree[T any](pts []KDPoint[T], opts ...KDOption) (*KDTree[T], error) {
|
|||
idIndex[p.ID] = i
|
||||
}
|
||||
}
|
||||
cfg := kdOptions{metric: EuclideanDistance{}}
|
||||
cfg := kdOptions{metric: EuclideanDistance{}, backend: defaultBackend()}
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
backend := cfg.backend
|
||||
var backendData any
|
||||
// Attempt to build gonum backend if requested and available.
|
||||
if backend == BackendGonum && hasGonum() {
|
||||
if bd, err := buildGonumBackend(pts, cfg.metric); err == nil {
|
||||
backendData = bd
|
||||
} else {
|
||||
backend = BackendLinear // fallback gracefully
|
||||
}
|
||||
} else if backend == BackendGonum && !hasGonum() {
|
||||
backend = BackendLinear // tag not enabled → fallback
|
||||
}
|
||||
t := &KDTree[T]{
|
||||
points: append([]KDPoint[T](nil), pts...),
|
||||
dim: dim,
|
||||
metric: cfg.metric,
|
||||
idIndex: idIndex,
|
||||
backend: backend,
|
||||
backendData: backendData,
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
|
@ -233,15 +275,21 @@ func NewKDTreeFromDim[T any](dim int, opts ...KDOption) (*KDTree[T], error) {
|
|||
if dim <= 0 {
|
||||
return nil, ErrZeroDim
|
||||
}
|
||||
cfg := kdOptions{metric: EuclideanDistance{}}
|
||||
cfg := kdOptions{metric: EuclideanDistance{}, backend: defaultBackend()}
|
||||
for _, o := range opts {
|
||||
o(&cfg)
|
||||
}
|
||||
backend := cfg.backend
|
||||
if backend == BackendGonum && !hasGonum() {
|
||||
backend = BackendLinear
|
||||
}
|
||||
return &KDTree[T]{
|
||||
points: nil,
|
||||
dim: dim,
|
||||
metric: cfg.metric,
|
||||
idIndex: make(map[string]int),
|
||||
backend: backend,
|
||||
backendData: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
@ -257,6 +305,13 @@ func (t *KDTree[T]) Nearest(query []float64) (KDPoint[T], float64, bool) {
|
|||
if len(query) != t.dim || t.Len() == 0 {
|
||||
return KDPoint[T]{}, 0, false
|
||||
}
|
||||
// Gonum backend (if available and built)
|
||||
if t.backend == BackendGonum && t.backendData != nil {
|
||||
if idx, dist, ok := gonumNearest[T](t.backendData, query); ok && idx >= 0 && idx < len(t.points) {
|
||||
return t.points[idx], dist, true
|
||||
}
|
||||
// fall through to linear scan if backend didn’t return a result
|
||||
}
|
||||
bestIdx := -1
|
||||
bestDist := math.MaxFloat64
|
||||
for i := range t.points {
|
||||
|
|
@ -278,6 +333,18 @@ func (t *KDTree[T]) KNearest(query []float64, k int) ([]KDPoint[T], []float64) {
|
|||
if k <= 0 || len(query) != t.dim || t.Len() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Gonum backend path
|
||||
if t.backend == BackendGonum && t.backendData != nil {
|
||||
idxs, dists := gonumKNearest[T](t.backendData, query, k)
|
||||
if len(idxs) > 0 {
|
||||
neighbors := make([]KDPoint[T], len(idxs))
|
||||
for i := range idxs {
|
||||
neighbors[i] = t.points[idxs[i]]
|
||||
}
|
||||
return neighbors, dists
|
||||
}
|
||||
// fall back on unexpected empty
|
||||
}
|
||||
tmp := make([]struct {
|
||||
idx int
|
||||
dist float64
|
||||
|
|
@ -304,6 +371,18 @@ func (t *KDTree[T]) Radius(query []float64, r float64) ([]KDPoint[T], []float64)
|
|||
if r < 0 || len(query) != t.dim || t.Len() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Gonum backend path
|
||||
if t.backend == BackendGonum && t.backendData != nil {
|
||||
idxs, dists := gonumRadius[T](t.backendData, query, r)
|
||||
if len(idxs) > 0 {
|
||||
neighbors := make([]KDPoint[T], len(idxs))
|
||||
for i := range idxs {
|
||||
neighbors[i] = t.points[idxs[i]]
|
||||
}
|
||||
return neighbors, dists
|
||||
}
|
||||
// fall back if no results
|
||||
}
|
||||
var sel []struct {
|
||||
idx int
|
||||
dist float64
|
||||
|
|
@ -342,6 +421,16 @@ func (t *KDTree[T]) Insert(p KDPoint[T]) bool {
|
|||
if p.ID != "" {
|
||||
t.idIndex[p.ID] = len(t.points) - 1
|
||||
}
|
||||
// Rebuild backend if using Gonum
|
||||
if t.backend == BackendGonum && hasGonum() {
|
||||
if bd, err := buildGonumBackend(t.points, t.metric); err == nil {
|
||||
t.backendData = bd
|
||||
} else {
|
||||
// fallback to linear if rebuild fails
|
||||
t.backend = BackendLinear
|
||||
t.backendData = nil
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
@ -362,5 +451,15 @@ func (t *KDTree[T]) DeleteByID(id string) bool {
|
|||
}
|
||||
t.points = t.points[:last]
|
||||
delete(t.idIndex, id)
|
||||
// Rebuild backend if using Gonum
|
||||
if t.backend == BackendGonum && hasGonum() {
|
||||
if bd, err := buildGonumBackend(t.points, t.metric); err == nil {
|
||||
t.backendData = bd
|
||||
} else {
|
||||
// fallback to linear if rebuild fails
|
||||
t.backend = BackendLinear
|
||||
t.backendData = nil
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
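The constructors above fall back to the linear backend silently, so a caller cannot tell which engine was actually selected. As a minimal sketch of an alternative, the resolution step could report the fallback through an error that wraps a sentinel. Everything below (`Backend`, `ErrUnavailable`, `resolve`) is a hypothetical stand-in written for illustration, not the package's API.

```go
package main

import (
	"errors"
	"fmt"
)

// Backend mirrors the idea of KDBackend; the names here are illustrative only.
type Backend string

const (
	Linear Backend = "linear"
	Gonum  Backend = "gonum"
)

// ErrUnavailable is a stand-in for a sentinel such as ErrBackendUnavailable.
var ErrUnavailable = errors.New("requested backend unavailable")

// resolve returns the backend that will actually be used. It applies the same
// fallback rule as the diff (gonum requested but not built means linear), but
// also returns a wrapped sentinel error whenever a fallback happened.
func resolve(requested Backend, gonumBuilt bool) (Backend, error) {
	switch requested {
	case Linear:
		return Linear, nil
	case Gonum:
		if gonumBuilt {
			return Gonum, nil
		}
		return Linear, fmt.Errorf("%w: %q", ErrUnavailable, requested)
	default:
		return Linear, fmt.Errorf("%w: unknown backend %q", ErrUnavailable, requested)
	}
}

func main() {
	b, err := resolve(Gonum, false)
	fmt.Println(b, errors.Is(err, ErrUnavailable))
}
```

Returning both a usable backend and an error keeps the graceful fallback while letting callers that care log or reject it.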
|
||||
|
|
|
|||
129
kdtree_backend_parity_test.go
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
package poindexter
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// makeFixedPoints creates a deterministic set of 4D points for parity checks (the 2D test down-projects them).
|
||||
func makeFixedPoints() []KDPoint[int] {
|
||||
pts := []KDPoint[int]{
|
||||
{ID: "A", Coords: []float64{0, 0, 0, 0}, Value: 1},
|
||||
{ID: "B", Coords: []float64{1, 0, 0.5, 0.2}, Value: 2},
|
||||
{ID: "C", Coords: []float64{0, 1, 0.3, 0.7}, Value: 3},
|
||||
{ID: "D", Coords: []float64{1, 1, 0.9, 0.9}, Value: 4},
|
||||
{ID: "E", Coords: []float64{0.2, 0.8, 0.4, 0.6}, Value: 5},
|
||||
}
|
||||
return pts
|
||||
}
|
||||
|
||||
func TestBackendParity_Nearest(t *testing.T) {
|
||||
pts := makeFixedPoints()
|
||||
queries := [][]float64{
|
||||
{0, 0, 0, 0},
|
||||
{0.9, 0.2, 0.5, 0.1},
|
||||
{0.5, 0.5, 0.5, 0.5},
|
||||
}
|
||||
|
||||
lin, err := NewKDTree(pts, WithBackend(BackendLinear), WithMetric(EuclideanDistance{}))
|
||||
if err != nil {
|
||||
t.Fatalf("linear NewKDTree: %v", err)
|
||||
}
|
||||
|
||||
// Only build a gonum tree when the optimized backend is compiled in.
|
||||
if hasGonum() {
|
||||
gon, err := NewKDTree(pts, WithBackend(BackendGonum), WithMetric(EuclideanDistance{}))
|
||||
if err != nil {
|
||||
t.Fatalf("gonum NewKDTree: %v", err)
|
||||
}
|
||||
for _, q := range queries {
|
||||
pl, dl, okl := lin.Nearest(q)
|
||||
pg, dg, okg := gon.Nearest(q)
|
||||
if okl != okg {
|
||||
t.Fatalf("ok mismatch: linear=%v gonum=%v", okl, okg)
|
||||
}
|
||||
if !okl {
|
||||
continue
|
||||
}
|
||||
if pl.ID != pg.ID {
|
||||
t.Errorf("nearest ID mismatch for %v: linear=%s gonum=%s", q, pl.ID, pg.ID)
|
||||
}
|
||||
if (dl == 0 && dg != 0) || (dl != 0 && dg == 0) {
|
||||
t.Errorf("nearest distance zero/nonzero mismatch: linear=%v gonum=%v", dl, dg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParity_KNearest(t *testing.T) {
|
||||
pts := makeFixedPoints()
|
||||
q := []float64{0.6, 0.6, 0.4, 0.4}
|
||||
ks := []int{1, 2, 5, 10}
|
||||
lin, _ := NewKDTree(pts, WithBackend(BackendLinear), WithMetric(EuclideanDistance{}))
|
||||
if hasGonum() {
|
||||
gon, _ := NewKDTree(pts, WithBackend(BackendGonum), WithMetric(EuclideanDistance{}))
|
||||
for _, k := range ks {
|
||||
ln, ld := lin.KNearest(q, k)
|
||||
gn, gd := gon.KNearest(q, k)
|
||||
if len(ln) != len(gn) || len(ld) != len(gd) {
|
||||
t.Fatalf("k=%d length mismatch: linear (%d,%d) vs gonum (%d,%d)", k, len(ln), len(ld), len(gn), len(gd))
|
||||
}
|
||||
// Compare IDs element-wise; ties may reorder between backends, so a mismatch is tolerated when the distances at that position are equal.
|
||||
for i := range ln {
|
||||
if ln[i].ID != gn[i].ID {
|
||||
// If distances are effectively equal, allow different order
|
||||
if i < len(ld) && i < len(gd) && ld[i] == gd[i] {
|
||||
continue
|
||||
}
|
||||
t.Logf("k=%d index %d ID mismatch: linear=%s gonum=%s (dl=%.6f dg=%.6f)", k, i, ln[i].ID, gn[i].ID, ld[i], gd[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParity_Radius(t *testing.T) {
|
||||
pts := makeFixedPoints()
|
||||
q := []float64{0.4, 0.6, 0.4, 0.6}
|
||||
radii := []float64{0, 0.15, 0.3, 1.0}
|
||||
lin, _ := NewKDTree(pts, WithBackend(BackendLinear), WithMetric(EuclideanDistance{}))
|
||||
if hasGonum() {
|
||||
gon, _ := NewKDTree(pts, WithBackend(BackendGonum), WithMetric(EuclideanDistance{}))
|
||||
for _, r := range radii {
|
||||
ln, ld := lin.Radius(q, r)
|
||||
gn, gd := gon.Radius(q, r)
|
||||
if len(ln) != len(gn) || len(ld) != len(gd) {
|
||||
t.Fatalf("r=%.3f length mismatch: linear (%d,%d) vs gonum (%d,%d)", r, len(ln), len(ld), len(gn), len(gd))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBackendParity_RandomQueries2D(t *testing.T) {
|
||||
// Down-project 4D to 2D to exercise pruning differences as well
|
||||
pts4 := makeFixedPoints()
|
||||
pts2 := make([]KDPoint[int], len(pts4))
|
||||
for i, p := range pts4 {
|
||||
pts2[i] = KDPoint[int]{ID: p.ID, Coords: []float64{p.Coords[0], p.Coords[1]}, Value: p.Value}
|
||||
}
|
||||
lin, _ := NewKDTree(pts2, WithBackend(BackendLinear), WithMetric(ManhattanDistance{}))
|
||||
if hasGonum() {
|
||||
gon, _ := NewKDTree(pts2, WithBackend(BackendGonum), WithMetric(ManhattanDistance{}))
|
||||
rng := rand.New(rand.NewSource(42))
|
||||
for i := 0; i < 50; i++ {
|
||||
q := []float64{rng.Float64(), rng.Float64()}
|
||||
pl, dl, okl := lin.Nearest(q)
|
||||
pg, dg, okg := gon.Nearest(q)
|
||||
if okl != okg {
|
||||
t.Fatalf("ok mismatch (2D rand)")
|
||||
}
|
||||
if !okl {
|
||||
continue
|
||||
}
|
||||
if pl.ID != pg.ID && (dl != dg) {
|
||||
// Allow different picks only if distances tie; otherwise flag
|
||||
t.Errorf("2D rand nearest mismatch: linear %s(%.6f) gonum %s(%.6f)", pl.ID, dl, pg.ID, dg)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
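The parity tests above compare distances with exact `==` and tolerate reordering only when the two distances are bit-identical. If the two backends ever accumulate floating-point error differently, a tolerance-based comparison is more robust. The following is a sketch under that assumption; `approxEqual` and `sameIDSet` are hypothetical helpers, not part of the package or its tests.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// approxEqual reports whether a and b differ by at most eps, scaled by their
// magnitude (with a floor of 1) so the check behaves sensibly for both small
// and large distances.
func approxEqual(a, b, eps float64) bool {
	scale := math.Max(1, math.Max(math.Abs(a), math.Abs(b)))
	return math.Abs(a-b) <= eps*scale
}

// sameIDSet reports whether two ID lists hold the same IDs, ignoring order.
// It sorts copies, so the callers' slices are left untouched.
func sameIDSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	x := append([]string(nil), a...)
	y := append([]string(nil), b...)
	sort.Strings(x)
	sort.Strings(y)
	for i := range x {
		if x[i] != y[i] {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(approxEqual(1.0, 1.0+1e-12, 1e-9))
	fmt.Println(sameIDSet([]string{"A", "E"}, []string{"E", "A"}))
}
```

In a parity test, `approxEqual` would replace the `ld[i] == gd[i]` check, and `sameIDSet` could compare the IDs inside a group of tied distances.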
|
||||
311
kdtree_gonum.go
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
//go:build gonum
|
||||
|
||||
package poindexter
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Note: This file is compiled when built with the "gonum" tag. For now, we
|
||||
// provide an internal KD-tree backend that performs balanced median-split
|
||||
// construction and branch-and-bound queries. This gives sub-linear behavior on
|
||||
// suitable datasets without introducing an external dependency. The public API
|
||||
// and option names remain the same; a future change can swap this implementation
|
||||
// to use gonum.org/v1/gonum/spatial/kdtree without altering callers.
|
||||
|
||||
// hasGonum reports whether the optimized backend is compiled in.
|
||||
func hasGonum() bool { return true }
|
||||
|
||||
// kdNode represents a node in the median-split KD-tree.
|
||||
type kdNode struct {
|
||||
axis int
|
||||
idx int // index into the original points slice
|
||||
val float64
|
||||
left *kdNode
|
||||
right *kdNode
|
||||
}
|
||||
|
||||
// kdBackend holds the KD-tree root and metadata.
|
||||
type kdBackend struct {
|
||||
root *kdNode
|
||||
dim int
|
||||
metric DistanceMetric
|
||||
// Access to original coords by index is done via a closure we capture at build
|
||||
coords func(i int) []float64
|
||||
len int
|
||||
}
|
||||
|
||||
// buildGonumBackend builds a balanced KD-tree using variance-based axis choice
|
||||
// and median splits. It does not reorder the external points slice; it keeps
|
||||
// indices and accesses the original data via closures, preserving caller order.
|
||||
func buildGonumBackend[T any](points []KDPoint[T], metric DistanceMetric) (any, error) {
|
||||
// Only enable this backend for metrics where the axis-slab bound is valid
|
||||
// for pruning: L2/L1/L∞. For other metrics (e.g., Cosine), fall back.
|
||||
switch metric.(type) {
|
||||
case EuclideanDistance, ManhattanDistance, ChebyshevDistance:
|
||||
// supported
|
||||
default:
|
||||
return nil, ErrBackendUnavailable
|
||||
}
|
||||
if len(points) == 0 {
|
||||
return &kdBackend{root: nil, dim: 0, metric: metric, coords: func(int) []float64 { return nil }}, nil
|
||||
}
|
||||
dim := len(points[0].Coords)
|
||||
coords := func(i int) []float64 { return points[i].Coords }
|
||||
idxs := make([]int, len(points))
|
||||
for i := range idxs {
|
||||
idxs[i] = i
|
||||
}
|
||||
root := buildKDRecursive(idxs, coords, dim, 0)
|
||||
return &kdBackend{root: root, dim: dim, metric: metric, coords: coords, len: len(points)}, nil
|
||||
}
|
||||
|
||||
// compute per-axis standard deviation (used for axis selection)
|
||||
func axisStd(idxs []int, coords func(int) []float64, dim int) []float64 {
|
||||
vars := make([]float64, dim)
|
||||
means := make([]float64, dim)
|
||||
n := float64(len(idxs))
|
||||
if n == 0 {
|
||||
return vars
|
||||
}
|
||||
for _, i := range idxs {
|
||||
c := coords(i)
|
||||
for d := 0; d < dim; d++ {
|
||||
means[d] += c[d]
|
||||
}
|
||||
}
|
||||
for d := 0; d < dim; d++ {
|
||||
means[d] /= n
|
||||
}
|
||||
for _, i := range idxs {
|
||||
c := coords(i)
|
||||
for d := 0; d < dim; d++ {
|
||||
delta := c[d] - means[d]
|
||||
vars[d] += delta * delta
|
||||
}
|
||||
}
|
||||
for d := 0; d < dim; d++ {
|
||||
vars[d] = math.Sqrt(vars[d] / n)
|
||||
}
|
||||
return vars
|
||||
}
|
||||
|
||||
func buildKDRecursive(idxs []int, coords func(int) []float64, dim int, depth int) *kdNode {
|
||||
if len(idxs) == 0 {
|
||||
return nil
|
||||
}
|
||||
// choose axis with max stddev
|
||||
stds := axisStd(idxs, coords, dim)
|
||||
axis := 0
|
||||
maxv := stds[0]
|
||||
for d := 1; d < dim; d++ {
|
||||
if stds[d] > maxv {
|
||||
maxv = stds[d]
|
||||
axis = d
|
||||
}
|
||||
}
|
||||
// Full sort by axis via sort.Slice for simplicity; an nth-element (partial sort) would suffice to find the median.
|
||||
sort.Slice(idxs, func(i, j int) bool { return coords(idxs[i])[axis] < coords(idxs[j])[axis] })
|
||||
mid := len(idxs) / 2
|
||||
medianIdx := idxs[mid]
|
||||
n := &kdNode{axis: axis, idx: medianIdx, val: coords(medianIdx)[axis]}
|
||||
n.left = buildKDRecursive(append([]int(nil), idxs[:mid]...), coords, dim, depth+1)
|
||||
n.right = buildKDRecursive(append([]int(nil), idxs[mid+1:]...), coords, dim, depth+1)
|
||||
return n
|
||||
}
|
||||
|
||||
// gonumNearest performs 1-NN search using the KD backend.
|
||||
func gonumNearest[T any](backend any, query []float64) (int, float64, bool) {
|
||||
b, ok := backend.(*kdBackend)
|
||||
if !ok || b.root == nil || len(query) != b.dim {
|
||||
return -1, 0, false
|
||||
}
|
||||
bestIdx := -1
|
||||
bestDist := math.MaxFloat64
|
||||
var search func(*kdNode)
|
||||
search = func(n *kdNode) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
c := b.coords(n.idx)
|
||||
d := b.metric.Distance(query, c)
|
||||
if d < bestDist {
|
||||
bestDist = d
|
||||
bestIdx = n.idx
|
||||
}
|
||||
axis := n.axis
|
||||
qv := query[axis]
|
||||
// choose side
|
||||
near, far := n.left, n.right
|
||||
if qv >= n.val {
|
||||
near, far = n.right, n.left
|
||||
}
|
||||
search(near)
|
||||
// visit the far side unless the hyperslab distance exceeds bestDist
|
||||
diff := qv - n.val
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
if diff <= bestDist {
|
||||
search(far)
|
||||
}
|
||||
}
|
||||
search(b.root)
|
||||
if bestIdx < 0 {
|
||||
return -1, 0, false
|
||||
}
|
||||
return bestIdx, bestDist, true
|
||||
}
|
||||
|
||||
// small max-heap for (distance, index)
|
||||
// We’ll use a slice maintaining the largest distance at [0] via container/heap-like ops.
|
||||
type knnItem struct {
|
||||
idx int
|
||||
dist float64
|
||||
}
|
||||
|
||||
type knnHeap []knnItem
|
||||
|
||||
func (h knnHeap) Len() int { return len(h) }
|
||||
func (h knnHeap) less(i, j int) bool { return h[i].dist > h[j].dist } // max-heap by dist
|
||||
func (h *knnHeap) push(x knnItem) { *h = append(*h, x); h.up(len(*h) - 1) }
|
||||
func (h *knnHeap) pop() knnItem {
|
||||
n := len(*h) - 1
|
||||
h.swap(0, n)
|
||||
v := (*h)[n]
|
||||
*h = (*h)[:n]
|
||||
h.down(0)
|
||||
return v
|
||||
}
|
||||
func (h *knnHeap) peek() knnItem { return (*h)[0] }
|
||||
func (h knnHeap) swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
func (h *knnHeap) up(i int) {
|
||||
for i > 0 {
|
||||
p := (i - 1) / 2
|
||||
if !h.less(i, p) {
|
||||
break
|
||||
}
|
||||
h.swap(i, p)
|
||||
i = p
|
||||
}
|
||||
}
|
||||
func (h *knnHeap) down(i int) {
|
||||
for {
|
||||
l := 2*i + 1
|
||||
r := l + 1
|
||||
largest := i
|
||||
if l < h.Len() && h.less(l, largest) {
|
||||
largest = l
|
||||
}
|
||||
if r < h.Len() && h.less(r, largest) {
|
||||
largest = r
|
||||
}
|
||||
if largest == i {
|
||||
break
|
||||
}
|
||||
h.swap(i, largest)
|
||||
i = largest
|
||||
}
|
||||
}
|
||||
|
||||
// gonumKNearest returns indices in ascending distance order.
|
||||
func gonumKNearest[T any](backend any, query []float64, k int) ([]int, []float64) {
|
||||
b, ok := backend.(*kdBackend)
|
||||
if !ok || b.root == nil || len(query) != b.dim || k <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var h knnHeap
|
||||
bestCap := k
|
||||
var search func(*kdNode)
|
||||
search = func(n *kdNode) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
c := b.coords(n.idx)
|
||||
d := b.metric.Distance(query, c)
|
||||
if h.Len() < bestCap {
|
||||
h.push(knnItem{idx: n.idx, dist: d})
|
||||
} else if d < h.peek().dist {
|
||||
// replace max
|
||||
h[0] = knnItem{idx: n.idx, dist: d}
|
||||
h.down(0)
|
||||
}
|
||||
axis := n.axis
|
||||
qv := query[axis]
|
||||
near, far := n.left, n.right
|
||||
if qv >= n.val {
|
||||
near, far = n.right, n.left
|
||||
}
|
||||
search(near)
|
||||
// Prune against the current worst distance only once the heap is full.
// Until then, any subtree may still contribute one of the k results,
// so the far side must always be visited.
threshold := math.MaxFloat64
if h.Len() == bestCap {
threshold = h.peek().dist
}
|
||||
diff := qv - n.val
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
if diff <= threshold {
|
||||
search(far)
|
||||
}
|
||||
}
|
||||
search(b.root)
|
||||
// Extract to slices and sort ascending by distance
|
||||
res := make([]knnItem, len(h))
|
||||
copy(res, h)
|
||||
sort.Slice(res, func(i, j int) bool { return res[i].dist < res[j].dist })
|
||||
idxs := make([]int, len(res))
|
||||
dists := make([]float64, len(res))
|
||||
for i := range res {
|
||||
idxs[i] = res[i].idx
|
||||
dists[i] = res[i].dist
|
||||
}
|
||||
return idxs, dists
|
||||
}
|
||||
|
||||
func gonumRadius[T any](backend any, query []float64, r float64) ([]int, []float64) {
|
||||
b, ok := backend.(*kdBackend)
|
||||
if !ok || b.root == nil || len(query) != b.dim || r < 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var res []knnItem
|
||||
var search func(*kdNode)
|
||||
search = func(n *kdNode) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
c := b.coords(n.idx)
|
||||
d := b.metric.Distance(query, c)
|
||||
if d <= r {
|
||||
res = append(res, knnItem{idx: n.idx, dist: d})
|
||||
}
|
||||
axis := n.axis
|
||||
qv := query[axis]
|
||||
near, far := n.left, n.right
|
||||
if qv >= n.val {
|
||||
near, far = n.right, n.left
|
||||
}
|
||||
search(near)
|
||||
diff := qv - n.val
|
||||
if diff < 0 {
|
||||
diff = -diff
|
||||
}
|
||||
if diff <= r {
|
||||
search(far)
|
||||
}
|
||||
}
|
||||
search(b.root)
|
||||
sort.Slice(res, func(i, j int) bool { return res[i].dist < res[j].dist })
|
||||
idxs := make([]int, len(res))
|
||||
dists := make([]float64, len(res))
|
||||
for i := range res {
|
||||
idxs[i] = res[i].idx
|
||||
dists[i] = res[i].dist
|
||||
}
|
||||
return idxs, dists
|
||||
}
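`buildKDRecursive` above fully sorts the index slice at every level even though only the median is needed. As a sketch of the nth-element idea its comment mentions, a quickselect over the index slice could look like the following. `selectNth` is a hypothetical helper, not part of this commit, and it assumes `0 <= n < len(idxs)`.

```go
package main

import "fmt"

// selectNth rearranges idxs in place so that idxs[n] holds the index whose key
// is the n-th smallest. Every element before position n has a key <= that key
// and every element after it has a key >= that key. key maps a point index to
// its coordinate on the split axis.
func selectNth(idxs []int, n int, key func(int) float64) {
	lo, hi := 0, len(idxs)-1
	for lo < hi {
		// Use the middle element as pivot and park it at the end.
		mid := lo + (hi-lo)/2
		idxs[mid], idxs[hi] = idxs[hi], idxs[mid]
		pivot := key(idxs[hi])
		// Lomuto partition: move keys smaller than the pivot to the front.
		store := lo
		for i := lo; i < hi; i++ {
			if key(idxs[i]) < pivot {
				idxs[i], idxs[store] = idxs[store], idxs[i]
				store++
			}
		}
		idxs[store], idxs[hi] = idxs[hi], idxs[store]
		switch {
		case store == n:
			return
		case store < n:
			lo = store + 1
		default:
			hi = store - 1
		}
	}
}

func main() {
	vals := []float64{0.9, 0.1, 0.5, 0.3, 0.7}
	idxs := []int{0, 1, 2, 3, 4}
	mid := len(idxs) / 2
	selectNth(idxs, mid, func(i int) float64 { return vals[i] })
	fmt.Println(vals[idxs[mid]])
}
```

Quickselect runs in expected linear time per level, but this simple partition degrades to quadratic time when many keys are equal, so a production version would want a three-way partition or a fallback to sorting.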
|
||||
23
kdtree_gonum_stub.go
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
//go:build !gonum
|
||||
|
||||
package poindexter
|
||||
|
||||
// hasGonum reports whether the gonum backend is compiled in (build tag 'gonum').
|
||||
func hasGonum() bool { return false }
|
||||
|
||||
// buildGonumBackend is unavailable without the 'gonum' build tag.
|
||||
func buildGonumBackend[T any](pts []KDPoint[T], metric DistanceMetric) (any, error) {
|
||||
return nil, ErrBackendUnavailable // non-nil error forces fallback to the linear backend
|
||||
}
|
||||
|
||||
func gonumNearest[T any](backend any, query []float64) (int, float64, bool) {
|
||||
return -1, 0, false
|
||||
}
|
||||
|
||||
func gonumKNearest[T any](backend any, query []float64, k int) ([]int, []float64) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func gonumRadius[T any](backend any, query []float64, r float64) ([]int, []float64) {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
@ -71,14 +71,32 @@ Explore runnable examples in the repository:
|
|||
- examples/kdtree_2d_ping_hop
|
||||
- examples/kdtree_3d_ping_hop_geo
|
||||
- examples/kdtree_4d_ping_hop_geo_score
|
||||
- examples/wasm-browser (browser demo using the ESM loader)
|
||||
|
||||
### KDTree performance and notes
|
||||
- Current KDTree queries are O(n) linear scans, which are great for small-to-medium datasets or low-latency prototyping. For 1e5+ points and low/medium dimensions, consider swapping the internal engine to `gonum.org/v1/gonum/spatial/kdtree` (the API here is compatible by design).
|
||||
- Dual backend support: Linear (always available) and an optimized KD backend enabled when building with `-tags=gonum`. Linear is the default; with the `gonum` tag, the optimized backend becomes the default.
|
||||
- Complexity: Linear backend is O(n) per query. Optimized KD backend is typically sub-linear on prunable datasets and dims ≤ ~8, especially as N grows (≥10k–100k).
|
||||
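- On the Linear backend, insert is O(1) amortized and delete by ID is O(1) via swap-delete; order is not preserved. With the optimized backend, every insert or delete currently rebuilds the tree.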
|
||||
- Concurrency: the KDTree type is not safe for concurrent mutation. Protect with a mutex or share immutable snapshots for read-mostly workloads.
|
||||
- See multi-dimensional examples (ping/hops/geo/score) in docs and `examples/`.
|
||||
- Performance guide: see [docs/perf.md](docs/perf.md) for benchmark guidance and tips (hosted at https://snider.github.io/Poindexter/perf/).
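The concurrency note above suggests guarding the tree with a mutex. A minimal sketch of that pattern with `sync.RWMutex` follows. To stay self-contained it wraps a tiny one-dimensional stand-in (`store`) rather than the real `KDTree[T]`, whose method signatures differ; all names in the block are hypothetical.

```go
package main

import (
	"fmt"
	"math"
	"sync"
)

// store is a stand-in for an index that is not safe for concurrent mutation.
// It only models Insert and Nearest in one dimension.
type store struct{ xs []float64 }

func (s *store) Insert(x float64) { s.xs = append(s.xs, x) }

func (s *store) Nearest(q float64) (float64, bool) {
	if len(s.xs) == 0 {
		return 0, false
	}
	best := s.xs[0]
	for _, x := range s.xs[1:] {
		if math.Abs(x-q) < math.Abs(best-q) {
			best = x
		}
	}
	return best, true
}

// SafeStore serializes writers and lets read-only queries run in parallel.
type SafeStore struct {
	mu sync.RWMutex
	s  store
}

func (c *SafeStore) Insert(x float64) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.s.Insert(x)
}

func (c *SafeStore) Nearest(q float64) (float64, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.s.Nearest(q)
}

func main() {
	var c SafeStore
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(v float64) {
			defer wg.Done()
			c.Insert(v)
			c.Nearest(v)
		}(float64(i))
	}
	wg.Wait()
	got, ok := c.Nearest(41.8)
	fmt.Println(got, ok)
}
```

A read-write lock suits read-mostly workloads. When mutations are frequent, and especially when each one triggers a backend rebuild, swapping in an immutable snapshot may contend less.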
|
||||
|
||||
### Backend selection
|
||||
- Default backend is Linear. If you build with `-tags=gonum`, the default becomes the optimized KD backend.
|
||||
- You can override per tree at construction:
|
||||
|
||||
```go
|
||||
// Force Linear (always available)
|
||||
kdt1, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
|
||||
|
||||
// Request the optimized backend (silently falls back to Linear without the gonum build tag)
|
||||
kdt2, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
|
||||
```
|
||||
|
||||
- Supported metrics in the optimized backend: Euclidean (L2), Manhattan (L1), Chebyshev (L∞).
|
||||
- Cosine and Weighted-Cosine currently run on the Linear backend.
|
||||
- See the Performance guide for measured comparisons and when to choose which backend.
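The metric restriction above follows from how the optimized backend prunes: it skips a subtree when the query's gap to the splitting plane on one axis exceeds the best distance so far. That is only safe if a single-axis gap can never exceed the full distance. The sketch below checks that bound for L2, L1 and L∞ and shows one pair of points where cosine distance violates it. The metric functions are written locally for illustration and are not the package's own types.

```go
package main

import (
	"fmt"
	"math"
)

func euclidean(a, b []float64) float64 {
	var s float64
	for i := range a {
		d := a[i] - b[i]
		s += d * d
	}
	return math.Sqrt(s)
}

func manhattan(a, b []float64) float64 {
	var s float64
	for i := range a {
		s += math.Abs(a[i] - b[i])
	}
	return s
}

func chebyshev(a, b []float64) float64 {
	var m float64
	for i := range a {
		m = math.Max(m, math.Abs(a[i]-b[i]))
	}
	return m
}

// cosine returns 1 - cos(angle between a and b); it ignores vector length.
func cosine(a, b []float64) float64 {
	var dot, na, nb float64
	for i := range a {
		dot += a[i] * b[i]
		na += a[i] * a[i]
		nb += b[i] * b[i]
	}
	return 1 - dot/(math.Sqrt(na)*math.Sqrt(nb))
}

// slabBoundHolds reports whether every per-axis gap between q and p is no
// larger than dist(q, p), which is the property axis-based pruning relies on.
func slabBoundHolds(dist func(a, b []float64) float64, q, p []float64) bool {
	d := dist(q, p)
	for i := range q {
		if math.Abs(q[i]-p[i]) > d {
			return false
		}
	}
	return true
}

func main() {
	q, p := []float64{1, 1}, []float64{10, 10}
	fmt.Println("euclidean:", slabBoundHolds(euclidean, q, p))
	fmt.Println("manhattan:", slabBoundHolds(manhattan, q, p))
	fmt.Println("chebyshev:", slabBoundHolds(chebyshev, q, p))
	fmt.Println("cosine:   ", slabBoundHolds(cosine, q, p))
}
```

The two points are parallel, so their cosine distance is about zero while they sit nine units apart on each axis. A tree that pruned on axis gaps would wrongly discard such a neighbor.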
|
||||
|
||||
#### Choosing a metric (quick tips)
|
||||
- Euclidean (L2): smooth trade-offs across axes; solid default for blended preferences.
|
||||
- Manhattan (L1): emphasizes per-axis absolute differences; good when each unit of ping/hop matters equally.
|
||||
|
|
|
|||