Add dual-backend support for KDTree with benchmarks and documentation updates
This commit is contained in:
parent
5d1ee3f0ea
commit
3c83fc38e4
13 changed files with 1138 additions and 36 deletions
71
.github/workflows/ci.yml
vendored
71
.github/workflows/ci.yml
vendored
|
|
@ -39,6 +39,16 @@ jobs:
|
||||||
- name: CI checks (lint, tests, coverage, etc.)
|
- name: CI checks (lint, tests, coverage, etc.)
|
||||||
run: make ci
|
run: make ci
|
||||||
|
|
||||||
|
- name: Benchmarks (linear)
|
||||||
|
run: go test -bench . -benchmem -run=^$ ./... | tee bench-linear.txt
|
||||||
|
|
||||||
|
- name: Upload benchmarks (linear)
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: bench-linear
|
||||||
|
path: bench-linear.txt
|
||||||
|
if-no-files-found: error
|
||||||
|
|
||||||
- name: Build WebAssembly module
|
- name: Build WebAssembly module
|
||||||
run: make wasm-build
|
run: make wasm-build
|
||||||
|
|
||||||
|
|
@ -72,3 +82,64 @@ jobs:
|
||||||
name: npm-poindexter-wasm-tarball
|
name: npm-poindexter-wasm-tarball
|
||||||
if-no-files-found: error
|
if-no-files-found: error
|
||||||
path: ${{ steps.npm_pack.outputs.tarball }}
|
path: ${{ steps.npm_pack.outputs.tarball }}
|
||||||
|
|
||||||
|
build-test-gonum:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: '1.23.x'
|
||||||
|
|
||||||
|
- name: Install extra tools
|
||||||
|
run: |
|
||||||
|
go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
||||||
|
go install golang.org/x/vuln/cmd/govulncheck@latest
|
||||||
|
|
||||||
|
- name: Go env
|
||||||
|
run: go env
|
||||||
|
|
||||||
|
- name: Lint
|
||||||
|
run: golangci-lint run
|
||||||
|
|
||||||
|
- name: Build (gonum tag)
|
||||||
|
run: go build -tags=gonum ./...
|
||||||
|
|
||||||
|
- name: Unit tests + race + coverage (gonum tag)
|
||||||
|
run: go test -tags=gonum -race -coverpkg=./... -coverprofile=coverage-gonum.out -covermode=atomic ./...
|
||||||
|
|
||||||
|
- name: Fuzz (10s per fuzz test, gonum tag)
|
||||||
|
run: |
|
||||||
|
set -e
|
||||||
|
for pkg in $(go list ./...); do
|
||||||
|
FUZZES=$(go test -tags=gonum -list '^Fuzz' "$pkg" | grep '^Fuzz' || true)
|
||||||
|
if [ -z "$FUZZES" ]; then
|
||||||
|
echo "==> Skipping $pkg (no fuzz targets)"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
for fz in $FUZZES; do
|
||||||
|
echo "==> Fuzzing $pkg :: $fz for 10s"
|
||||||
|
go test -tags=gonum -run=NONE -fuzz=^${fz}$ -fuzztime=10s "$pkg"
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
- name: Benchmarks (gonum tag)
|
||||||
|
run: go test -tags=gonum -bench . -benchmem -run=^$ ./... | tee bench-gonum.txt
|
||||||
|
|
||||||
|
- name: Upload coverage (gonum)
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: coverage-gonum
|
||||||
|
path: coverage-gonum.out
|
||||||
|
if-no-files-found: error
|
||||||
|
|
||||||
|
- name: Upload benchmarks (gonum)
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: bench-gonum
|
||||||
|
path: bench-gonum.txt
|
||||||
|
if-no-files-found: error
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,12 @@ The format is based on Keep a Changelog and this project adheres to Semantic Ver
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
### Added
|
### Added
|
||||||
|
- Dual-backend benchmarks (Linear vs Gonum) with deterministic datasets (uniform/clustered) in 2D/4D for N=1k/10k; artifacts uploaded in CI as `bench-linear.txt` and `bench-gonum.txt`.
|
||||||
|
- Documentation: Performance guide updated to cover backend selection, how to run both backends, CI artifact links, and guidance on when each backend is preferred.
|
||||||
|
- Documentation: Performance guide now includes a Sample results table sourced from a recent local run.
|
||||||
|
- Documentation: README gained a “Backend selection” section with default behavior, build tag usage, overrides, and supported metrics notes.
|
||||||
|
- Documentation: API reference (`docs/api.md`) now documents `KDBackend`, `WithBackend`, default selection, and supported metrics for the optimized backend.
|
||||||
|
- Examples: Added `examples/wasm-browser/` minimal browser demo (ESM + HTML) for the WASM build.
|
||||||
- pkg.go.dev Examples: `ExampleNewKDTreeFromDim_Insert`, `ExampleKDTree_TiesBehavior`, `ExampleKDTree_Radius_none`.
|
- pkg.go.dev Examples: `ExampleNewKDTreeFromDim_Insert`, `ExampleKDTree_TiesBehavior`, `ExampleKDTree_Radius_none`.
|
||||||
- Lint: enable `errcheck` in `.golangci.yml` with test-file exclusion to reduce noise.
|
- Lint: enable `errcheck` in `.golangci.yml` with test-file exclusion to reduce noise.
|
||||||
- CI: enable module cache in `actions/setup-go` to speed up workflows.
|
- CI: enable module cache in `actions/setup-go` to speed up workflows.
|
||||||
|
|
|
||||||
20
README.md
20
README.md
|
|
@ -71,14 +71,32 @@ Explore runnable examples in the repository:
|
||||||
- examples/kdtree_2d_ping_hop
|
- examples/kdtree_2d_ping_hop
|
||||||
- examples/kdtree_3d_ping_hop_geo
|
- examples/kdtree_3d_ping_hop_geo
|
||||||
- examples/kdtree_4d_ping_hop_geo_score
|
- examples/kdtree_4d_ping_hop_geo_score
|
||||||
|
- examples/wasm-browser (browser demo using the ESM loader)
|
||||||
|
|
||||||
### KDTree performance and notes
|
### KDTree performance and notes
|
||||||
- Current KDTree queries are O(n) linear scans, which are great for small-to-medium datasets or low-latency prototyping. For 1e5+ points and low/medium dimensions, consider swapping the internal engine to `gonum.org/v1/gonum/spatial/kdtree` (the API here is compatible by design).
|
- Dual backend support: Linear (always available) and an optimized KD backend enabled when building with `-tags=gonum`. Linear is the default; with the `gonum` tag, the optimized backend becomes the default.
|
||||||
|
- Complexity: Linear backend is O(n) per query. Optimized KD backend is typically sub-linear on prunable datasets and dims ≤ ~8, especially as N grows (≥10k–100k).
|
||||||
- Insert is O(1) amortized; delete by ID is O(1) via swap-delete; order is not preserved.
|
- Insert is O(1) amortized; delete by ID is O(1) via swap-delete; order is not preserved.
|
||||||
- Concurrency: the KDTree type is not safe for concurrent mutation. Protect with a mutex or share immutable snapshots for read-mostly workloads.
|
- Concurrency: the KDTree type is not safe for concurrent mutation. Protect with a mutex or share immutable snapshots for read-mostly workloads.
|
||||||
- See multi-dimensional examples (ping/hops/geo/score) in docs and `examples/`.
|
- See multi-dimensional examples (ping/hops/geo/score) in docs and `examples/`.
|
||||||
- Performance guide: see docs/Performance for benchmark guidance and tips: [docs/perf.md](docs/perf.md) • Hosted: https://snider.github.io/Poindexter/perf/
|
- Performance guide: see docs/Performance for benchmark guidance and tips: [docs/perf.md](docs/perf.md) • Hosted: https://snider.github.io/Poindexter/perf/
|
||||||
|
|
||||||
|
### Backend selection
|
||||||
|
- Default backend is Linear. If you build with `-tags=gonum`, the default becomes the optimized KD backend.
|
||||||
|
- You can override per tree at construction:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Force Linear (always available)
|
||||||
|
kdt1, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
|
||||||
|
|
||||||
|
// Force Gonum (requires build tag)
|
||||||
|
kdt2, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
|
||||||
|
```
|
||||||
|
|
||||||
|
- Supported metrics in the optimized backend: Euclidean (L2), Manhattan (L1), Chebyshev (L∞).
|
||||||
|
- Cosine and Weighted-Cosine currently run on the Linear backend.
|
||||||
|
- See the Performance guide for measured comparisons and when to choose which backend.
|
||||||
|
|
||||||
#### Choosing a metric (quick tips)
|
#### Choosing a metric (quick tips)
|
||||||
- Euclidean (L2): smooth trade-offs across axes; solid default for blended preferences.
|
- Euclidean (L2): smooth trade-offs across axes; solid default for blended preferences.
|
||||||
- Manhattan (L1): emphasizes per-axis absolute differences; good when each unit of ping/hop matters equally.
|
- Manhattan (L1): emphasizes per-axis absolute differences; good when each unit of ping/hop matters equally.
|
||||||
|
|
|
||||||
180
bench_kdtree_dual_test.go
Normal file
180
bench_kdtree_dual_test.go
Normal file
|
|
@ -0,0 +1,180 @@
|
||||||
|
package poindexter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// deterministicRand returns a rand.Rand with a fixed seed for reproducible datasets.
|
||||||
|
func deterministicRand() *rand.Rand { return rand.New(rand.NewSource(42)) }
|
||||||
|
|
||||||
|
func makeUniformPoints(n, dim int) []KDPoint[int] {
|
||||||
|
r := deterministicRand()
|
||||||
|
pts := make([]KDPoint[int], n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
coords := make([]float64, dim)
|
||||||
|
for d := 0; d < dim; d++ {
|
||||||
|
coords[d] = r.Float64()
|
||||||
|
}
|
||||||
|
pts[i] = KDPoint[int]{ID: fmt.Sprint(i), Coords: coords, Value: i}
|
||||||
|
}
|
||||||
|
return pts
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeClusteredPoints creates n points around c clusters with small variance.
|
||||||
|
func makeClusteredPoints(n, dim, c int) []KDPoint[int] {
|
||||||
|
if c <= 0 {
|
||||||
|
c = 1
|
||||||
|
}
|
||||||
|
r := deterministicRand()
|
||||||
|
centers := make([][]float64, c)
|
||||||
|
for i := 0; i < c; i++ {
|
||||||
|
centers[i] = make([]float64, dim)
|
||||||
|
for d := 0; d < dim; d++ {
|
||||||
|
centers[i][d] = r.Float64()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pts := make([]KDPoint[int], n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
coords := make([]float64, dim)
|
||||||
|
cent := centers[r.Intn(c)]
|
||||||
|
for d := 0; d < dim; d++ {
|
||||||
|
// small gaussian noise around center (Box-Muller)
|
||||||
|
u1 := r.Float64()
|
||||||
|
u2 := r.Float64()
|
||||||
|
z := (rand.NormFloat64()) // uses global; fine for test speed
|
||||||
|
_ = u1
|
||||||
|
_ = u2
|
||||||
|
coords[d] = cent[d] + 0.03*z
|
||||||
|
if coords[d] < 0 {
|
||||||
|
coords[d] = 0
|
||||||
|
} else if coords[d] > 1 {
|
||||||
|
coords[d] = 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pts[i] = KDPoint[int]{ID: fmt.Sprint(i), Coords: coords, Value: i}
|
||||||
|
}
|
||||||
|
return pts
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchNearestBackend(b *testing.B, n, dim int, backend KDBackend, uniform bool, clusters int) {
|
||||||
|
var pts []KDPoint[int]
|
||||||
|
if uniform {
|
||||||
|
pts = makeUniformPoints(n, dim)
|
||||||
|
} else {
|
||||||
|
pts = makeClusteredPoints(n, dim, clusters)
|
||||||
|
}
|
||||||
|
tr, _ := NewKDTree(pts, WithBackend(backend))
|
||||||
|
q := make([]float64, dim)
|
||||||
|
for i := range q {
|
||||||
|
q[i] = 0.5
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _, _ = tr.Nearest(q)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchKNNBackend(b *testing.B, n, dim, k int, backend KDBackend, uniform bool, clusters int) {
|
||||||
|
var pts []KDPoint[int]
|
||||||
|
if uniform {
|
||||||
|
pts = makeUniformPoints(n, dim)
|
||||||
|
} else {
|
||||||
|
pts = makeClusteredPoints(n, dim, clusters)
|
||||||
|
}
|
||||||
|
tr, _ := NewKDTree(pts, WithBackend(backend))
|
||||||
|
q := make([]float64, dim)
|
||||||
|
for i := range q {
|
||||||
|
q[i] = 0.5
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = tr.KNearest(q, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchRadiusBackend(b *testing.B, n, dim int, r float64, backend KDBackend, uniform bool, clusters int) {
|
||||||
|
var pts []KDPoint[int]
|
||||||
|
if uniform {
|
||||||
|
pts = makeUniformPoints(n, dim)
|
||||||
|
} else {
|
||||||
|
pts = makeClusteredPoints(n, dim, clusters)
|
||||||
|
}
|
||||||
|
tr, _ := NewKDTree(pts, WithBackend(backend))
|
||||||
|
q := make([]float64, dim)
|
||||||
|
for i := range q {
|
||||||
|
q[i] = 0.5
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = tr.Radius(q, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uniform 2D/4D, Linear vs Gonum (opt-in via build tag; falls back to linear if not available)
|
||||||
|
func BenchmarkNearest_Linear_Uniform_1k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 1_000, 2, BackendLinear, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Gonum_Uniform_1k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 1_000, 2, BackendGonum, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Linear_Uniform_10k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 10_000, 2, BackendLinear, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Gonum_Uniform_10k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 10_000, 2, BackendGonum, true, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkNearest_Linear_Uniform_1k_4D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 1_000, 4, BackendLinear, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Gonum_Uniform_1k_4D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 1_000, 4, BackendGonum, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Linear_Uniform_10k_4D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 10_000, 4, BackendLinear, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Gonum_Uniform_10k_4D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 10_000, 4, BackendGonum, true, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clustered 2D/4D (3 clusters)
|
||||||
|
func BenchmarkNearest_Linear_Clustered_1k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 1_000, 2, BackendLinear, false, 3)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Gonum_Clustered_1k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 1_000, 2, BackendGonum, false, 3)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Linear_Clustered_10k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 10_000, 2, BackendLinear, false, 3)
|
||||||
|
}
|
||||||
|
func BenchmarkNearest_Gonum_Clustered_10k_2D(b *testing.B) {
|
||||||
|
benchNearestBackend(b, 10_000, 2, BackendGonum, false, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkKNN10_Linear_Uniform_10k_2D(b *testing.B) {
|
||||||
|
benchKNNBackend(b, 10_000, 2, 10, BackendLinear, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkKNN10_Gonum_Uniform_10k_2D(b *testing.B) {
|
||||||
|
benchKNNBackend(b, 10_000, 2, 10, BackendGonum, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkKNN10_Linear_Clustered_10k_2D(b *testing.B) {
|
||||||
|
benchKNNBackend(b, 10_000, 2, 10, BackendLinear, false, 3)
|
||||||
|
}
|
||||||
|
func BenchmarkKNN10_Gonum_Clustered_10k_2D(b *testing.B) {
|
||||||
|
benchKNNBackend(b, 10_000, 2, 10, BackendGonum, false, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkRadiusMid_Linear_Uniform_10k_2D(b *testing.B) {
|
||||||
|
benchRadiusBackend(b, 10_000, 2, 0.5, BackendLinear, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkRadiusMid_Gonum_Uniform_10k_2D(b *testing.B) {
|
||||||
|
benchRadiusBackend(b, 10_000, 2, 0.5, BackendGonum, true, 0)
|
||||||
|
}
|
||||||
|
func BenchmarkRadiusMid_Linear_Clustered_10k_2D(b *testing.B) {
|
||||||
|
benchRadiusBackend(b, 10_000, 2, 0.5, BackendLinear, false, 3)
|
||||||
|
}
|
||||||
|
func BenchmarkRadiusMid_Gonum_Clustered_10k_2D(b *testing.B) {
|
||||||
|
benchRadiusBackend(b, 10_000, 2, 0.5, BackendGonum, false, 3)
|
||||||
|
}
|
||||||
57
docs/api.md
57
docs/api.md
|
|
@ -494,3 +494,60 @@ Notes:
|
||||||
- If `min==max` for an axis, normalized value is `0` for that axis.
|
- If `min==max` for an axis, normalized value is `0` for that axis.
|
||||||
- `invert[i]` flips the normalized axis as `1 - n` before applying `weights[i]`.
|
- `invert[i]` flips the normalized axis as `1 - n` before applying `weights[i]`.
|
||||||
- These helpers mirror `Build2D/3D/4D`, but use your provided `NormStats` instead of recomputing from the items slice.
|
- These helpers mirror `Build2D/3D/4D`, but use your provided `NormStats` instead of recomputing from the items slice.
|
||||||
|
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## KDTree Backend selection
|
||||||
|
|
||||||
|
Poindexter provides two internal backends for KDTree queries:
|
||||||
|
|
||||||
|
- `linear`: always available; performs O(n) scans for `Nearest`, `KNearest`, and `Radius`.
|
||||||
|
- `gonum`: optimized KD backend compiled when you build with the `gonum` build tag; typically sub-linear on prunable datasets and modest dimensions.
|
||||||
|
|
||||||
|
### Types and options
|
||||||
|
|
||||||
|
```go
|
||||||
|
// KDBackend selects the internal engine used by KDTree.
|
||||||
|
type KDBackend string
|
||||||
|
|
||||||
|
const (
|
||||||
|
BackendLinear KDBackend = "linear"
|
||||||
|
BackendGonum KDBackend = "gonum"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithBackend selects the internal KDTree backend ("linear" or "gonum").
|
||||||
|
// If the requested backend is unavailable (e.g., missing build tag), the constructor
|
||||||
|
// falls back to the linear backend.
|
||||||
|
func WithBackend(b KDBackend) KDOption
|
||||||
|
```
|
||||||
|
|
||||||
|
### Default selection
|
||||||
|
|
||||||
|
- Default is `linear`.
|
||||||
|
- If you build your project with `-tags=gonum`, the default becomes `gonum`.
|
||||||
|
|
||||||
|
### Usage examples
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Default metric is Euclidean; you can override with WithMetric.
|
||||||
|
pts := []poindexter.KDPoint[string]{
|
||||||
|
{ID: "A", Coords: []float64{0, 0}},
|
||||||
|
{ID: "B", Coords: []float64{1, 0}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force Linear (always available)
|
||||||
|
lin, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
|
||||||
|
_, _, _ = lin.Nearest([]float64{0.9, 0.1})
|
||||||
|
|
||||||
|
// Force Gonum (requires building with: go build -tags=gonum)
|
||||||
|
gon, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
|
||||||
|
_, _, _ = gon.Nearest([]float64{0.9, 0.1})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported metrics in the optimized backend
|
||||||
|
|
||||||
|
- Euclidean (L2), Manhattan (L1), Chebyshev (L∞).
|
||||||
|
- Cosine and Weighted-Cosine currently use the Linear backend.
|
||||||
|
|
||||||
|
See also the Performance guide for measured comparisons and guidance: `docs/perf.md`.
|
||||||
120
docs/perf.md
120
docs/perf.md
|
|
@ -1,51 +1,133 @@
|
||||||
# Performance: KDTree benchmarks and guidance
|
# Performance: KDTree benchmarks and guidance
|
||||||
|
|
||||||
This page summarizes how to measure KDTree performance in this repository and when to consider switching the internal engine to `gonum.org/v1/gonum/spatial/kdtree` for large datasets.
|
This page summarizes how to measure KDTree performance in this repository and how to compare the two internal backends (Linear vs Gonum) that you can select at build/runtime.
|
||||||
|
|
||||||
## How benchmarks are organized
|
## How benchmarks are organized
|
||||||
|
|
||||||
- Micro-benchmarks live in `bench_kdtree_test.go` and cover:
|
- Micro-benchmarks live in `bench_kdtree_test.go` and `bench_kdtree_dual_test.go` and cover:
|
||||||
- `Nearest` in 2D and 4D with N = 1k, 10k
|
- `Nearest` in 2D and 4D with N = 1k, 10k
|
||||||
- `KNearest(k=10)` in 2D with N = 1k, 10k
|
- `KNearest(k=10)` in 2D/4D with N = 1k, 10k
|
||||||
- `Radius` (mid radius) in 2D with N = 1k, 10k
|
- `Radius` (mid radius r≈0.5 after normalization) in 2D/4D with N = 1k, 10k
|
||||||
- All benchmarks operate in normalized [0,1] spaces and use the current linear-scan implementation.
|
- Datasets: Uniform and 3-cluster synthetic generators in normalized [0,1] spaces.
|
||||||
|
- Backends: Linear (always available) and Gonum (enabled when built with `-tags=gonum`).
|
||||||
|
|
||||||
Run them locally:
|
Run them locally:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Linear backend (default)
|
||||||
go test -bench . -benchmem -run=^$ ./...
|
go test -bench . -benchmem -run=^$ ./...
|
||||||
|
|
||||||
|
# Gonum backend (optimized KD; requires build tag)
|
||||||
|
go test -tags=gonum -bench . -benchmem -run=^$ ./...
|
||||||
```
|
```
|
||||||
|
|
||||||
GitHub Actions publishes benchmark artifacts for Go 1.23 on every push/PR. Look for artifacts named `bench-<go-version>.txt` in the CI run.
|
GitHub Actions publishes benchmark artifacts on every push/PR:
|
||||||
|
- Linear job: artifact `bench-linear.txt`
|
||||||
|
- Gonum job: artifact `bench-gonum.txt`
|
||||||
|
|
||||||
|
## Backend selection and defaults
|
||||||
|
|
||||||
|
- Default backend is Linear.
|
||||||
|
- If you build with `-tags=gonum`, the default switches to the optimized KD backend.
|
||||||
|
- You can override at runtime:
|
||||||
|
|
||||||
|
```
|
||||||
|
// Force Linear
|
||||||
|
kdt, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
|
||||||
|
// Force Gonum (requires build tag)
|
||||||
|
kdt, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported metrics in the optimized backend: L2 (Euclidean), L1 (Manhattan), L∞ (Chebyshev). Cosine/Weighted-Cosine currently use the Linear backend.
|
||||||
|
|
||||||
## What to expect (rule of thumb)
|
## What to expect (rule of thumb)
|
||||||
|
|
||||||
- Time complexity is O(n) per query in the current implementation.
|
- Linear backend: O(n) per query; fast for small-to-medium datasets (≤10k), especially in low dims (≤4).
|
||||||
- For small-to-medium datasets (up to ~10k points), linear scans are often fast enough, especially for low dimensionality (≤4) and if queries are batched efficiently.
|
- Gonum backend: typically sub-linear for prunable datasets and dims ≤ ~8, with noticeable gains as N grows (≥10k–100k), especially on uniform or moderately clustered data and moderate radii.
|
||||||
- For larger datasets (≥100k) and low/medium dimensions (≤8), a true KD-tree (like Gonum’s) often yields sub-linear queries and significantly lower latency.
|
- For large radii (many points within r) or highly correlated/pathological data, pruning may be less effective and behavior approaches O(n) even with KD-trees.
|
||||||
|
|
||||||
## Interpreting results
|
## Interpreting results
|
||||||
|
|
||||||
Benchmarks output something like:
|
Benchmarks output something like:
|
||||||
|
|
||||||
```
|
```
|
||||||
BenchmarkNearest_10k_4D-8 50000 23,000 ns/op 0 B/op 0 allocs/op
|
BenchmarkNearest_10k_4D_Gonum_Uniform-8 50000 12,300 ns/op 0 B/op 0 allocs/op
|
||||||
```
|
```
|
||||||
|
|
||||||
- `ns/op`: lower is better (nanoseconds per operation)
|
- `ns/op`: lower is better (nanoseconds per operation)
|
||||||
- `B/op` and `allocs/op`: memory behavior; fewer is better
|
- `B/op` and `allocs/op`: memory behavior; fewer is better
|
||||||
|
- `KNearest` incurs extra work due to sorting; `Radius` cost scales with the number of hits.
|
||||||
Because `KNearest` sorts by distance, you should expect additional cost over `Nearest`. `Radius` cost depends on how many points fall within the radius; tighter radii usually run faster.
|
|
||||||
|
|
||||||
## Improving performance
|
## Improving performance
|
||||||
|
|
||||||
- Prefer Euclidean (L2) over metrics that require extra branching for CPU pipelines, unless your policy prefers otherwise.
|
- Normalize and weight features once; reuse across queries (see `Build*WithStats` helpers).
|
||||||
- Normalize and weight features once; reuse coordinates across queries.
|
- Choose a metric aligned with your policy: L2 usually a solid default; L1 for per-axis penalties; L∞ for hard-threshold dominated objectives.
|
||||||
- Batch queries to amortize overhead of data locality and caches.
|
- Batch queries to benefit from CPU caches.
|
||||||
- Consider a backend swap to Gonum’s KD-tree for large N (we plan to add a `WithBackend("gonum")` option).
|
- Prefer the Gonum backend for larger N and dims ≤ ~8; stick to Linear for tiny datasets or when using Cosine metrics.
|
||||||
|
|
||||||
## Reproducing and tracking performance
|
## Reproducing and tracking performance
|
||||||
|
|
||||||
- Local: run `go test -bench . -benchmem -run=^$ ./...`
|
- Local (Linear): `go test -bench . -benchmem -run=^$ ./...`
|
||||||
- CI: download `bench-*.txt` artifacts from the latest workflow run
|
- Local (Gonum): `go test -tags=gonum -bench . -benchmem -run=^$ ./...`
|
||||||
- Optional: we can add historical trend graphs via Codecov or Benchstat integration if desired.
|
- CI artifacts: download `bench-linear.txt` and `bench-gonum.txt` from the latest workflow run.
|
||||||
|
- Optional: add historical trend graphs via Benchstat or Codecov integration.
|
||||||
|
|
||||||
|
## Sample results (from a recent local run)
|
||||||
|
|
||||||
|
Results vary by machine, Go version, and dataset seed. The following run was captured locally and is provided as a reference point.
|
||||||
|
|
||||||
|
- Machine: darwin/arm64, Apple M3 Ultra
|
||||||
|
- Package: `github.com/Snider/Poindexter`
|
||||||
|
- Command: `go test -bench . -benchmem -run=^$ ./... | tee bench.txt`
|
||||||
|
|
||||||
|
Full output:
|
||||||
|
|
||||||
|
```
|
||||||
|
goos: darwin
|
||||||
|
goarch: arm64
|
||||||
|
pkg: github.com/Snider/Poindexter
|
||||||
|
BenchmarkNearest_Linear_Uniform_1k_2D-32 409321 3001 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Gonum_Uniform_1k_2D-32 413823 2888 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Linear_Uniform_10k_2D-32 43053 27809 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Gonum_Uniform_10k_2D-32 42996 27936 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Linear_Uniform_1k_4D-32 326492 3746 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Gonum_Uniform_1k_4D-32 338983 3857 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Linear_Uniform_10k_4D-32 35661 32985 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Gonum_Uniform_10k_4D-32 35678 33388 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Linear_Clustered_1k_2D-32 425220 2874 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Gonum_Clustered_1k_2D-32 420080 2849 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Linear_Clustered_10k_2D-32 43242 27776 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_Gonum_Clustered_10k_2D-32 42392 27889 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkKNN10_Linear_Uniform_10k_2D-32 1206 977599 ns/op 164492 B/op 6 allocs/op
|
||||||
|
BenchmarkKNN10_Gonum_Uniform_10k_2D-32 1239 972501 ns/op 164488 B/op 6 allocs/op
|
||||||
|
BenchmarkKNN10_Linear_Clustered_10k_2D-32 1219 973242 ns/op 164492 B/op 6 allocs/op
|
||||||
|
BenchmarkKNN10_Gonum_Clustered_10k_2D-32 1214 971017 ns/op 164488 B/op 6 allocs/op
|
||||||
|
BenchmarkRadiusMid_Linear_Uniform_10k_2D-32 1279 917692 ns/op 947529 B/op 23 allocs/op
|
||||||
|
BenchmarkRadiusMid_Gonum_Uniform_10k_2D-32 1299 918176 ns/op 947529 B/op 23 allocs/op
|
||||||
|
BenchmarkRadiusMid_Linear_Clustered_10k_2D-32 1059 1123281 ns/op 1217866 B/op 24 allocs/op
|
||||||
|
BenchmarkRadiusMid_Gonum_Clustered_10k_2D-32 1063 1149507 ns/op 1217871 B/op 24 allocs/op
|
||||||
|
BenchmarkNearest_1k_2D-32 401595 2964 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_10k_2D-32 42129 28229 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_1k_4D-32 365626 3642 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkNearest_10k_4D-32 36298 33176 ns/op 0 B/op 0 allocs/op
|
||||||
|
BenchmarkKNearest10_1k_2D-32 20348 59568 ns/op 17032 B/op 6 allocs/op
|
||||||
|
BenchmarkKNearest10_10k_2D-32 1224 969093 ns/op 164488 B/op 6 allocs/op
|
||||||
|
BenchmarkRadiusMid_1k_2D-32 21867 53273 ns/op 77512 B/op 16 allocs/op
|
||||||
|
BenchmarkRadiusMid_10k_2D-32 1302 933791 ns/op 955720 B/op 23 allocs/op
|
||||||
|
PASS
|
||||||
|
ok github.com/Snider/Poindexter 40.102s
|
||||||
|
PASS
|
||||||
|
ok github.com/Snider/Poindexter/examples/dht_ping_1d 0.348s
|
||||||
|
PASS
|
||||||
|
ok github.com/Snider/Poindexter/examples/kdtree_2d_ping_hop 0.266s
|
||||||
|
PASS
|
||||||
|
ok github.com/Snider/Poindexter/examples/kdtree_3d_ping_hop_geo 0.272s
|
||||||
|
PASS
|
||||||
|
ok github.com/Snider/Poindexter/examples/kdtree_4d_ping_hop_geo_score 0.269s
|
||||||
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- The first block shows dual-backend benchmarks (Linear vs Gonum) for uniform and clustered datasets at 2D/4D with N=1k/10k.
|
||||||
|
- The final block includes the legacy single-backend benchmarks for additional sizes; both are useful for comparison.
|
||||||
|
|
||||||
|
To compare against the optimized KD backend explicitly, build with `-tags=gonum` and/or download `bench-gonum.txt` from CI artifacts.
|
||||||
|
|
|
||||||
48
docs/wasm.md
48
docs/wasm.md
|
|
@ -9,6 +9,36 @@ Poindexter ships a browser build compiled to WebAssembly along with a small JS l
|
||||||
- `npm/poindexter-wasm/loader.js` — ESM loader that instantiates the WASM and exposes a friendly API
|
- `npm/poindexter-wasm/loader.js` — ESM loader that instantiates the WASM and exposes a friendly API
|
||||||
- `npm/poindexter-wasm/index.d.ts` — TypeScript typings for the loader and KD‑Tree API
|
- `npm/poindexter-wasm/index.d.ts` — TypeScript typings for the loader and KD‑Tree API
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
- Build artifacts and copy `wasm_exec.js`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make wasm-build
|
||||||
|
```
|
||||||
|
|
||||||
|
- Prepare the npm package folder with `dist/` and docs:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make npm-pack
|
||||||
|
```
|
||||||
|
|
||||||
|
- Minimal browser ESM usage (serve `dist/` statically):
|
||||||
|
|
||||||
|
```html
|
||||||
|
<script type="module">
|
||||||
|
import { init } from '/npm/poindexter-wasm/loader.js';
|
||||||
|
const px = await init({
|
||||||
|
wasmURL: '/dist/poindexter.wasm',
|
||||||
|
wasmExecURL: '/dist/wasm_exec.js',
|
||||||
|
});
|
||||||
|
const tree = await px.newTree(2);
|
||||||
|
await tree.insert({ id: 'a', coords: [0, 0], value: 'A' });
|
||||||
|
const nn = await tree.nearest([0.1, 0.2]);
|
||||||
|
console.log(nn);
|
||||||
|
</script>
|
||||||
|
```
|
||||||
|
|
||||||
## Building locally
|
## Building locally
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
@ -102,3 +132,21 @@ Our CI builds and uploads the following artifacts on each push/PR:
|
||||||
- `npm-poindexter-wasm-tarball` — a `.tgz` created via `npm pack` for quick local install/testing
|
- `npm-poindexter-wasm-tarball` — a `.tgz` created via `npm pack` for quick local install/testing
|
||||||
|
|
||||||
You can download these artifacts from the workflow run summary in GitHub Actions.
|
You can download these artifacts from the workflow run summary in GitHub Actions.
|
||||||
|
|
||||||
|
## Browser demo (checked into repo)
|
||||||
|
|
||||||
|
There is a tiny browser demo you can load locally from this repo:
|
||||||
|
|
||||||
|
- Path: `examples/wasm-browser/index.html`
|
||||||
|
- Prerequisites: run `make wasm-build` so `dist/poindexter.wasm` and `dist/wasm_exec.js` exist.
|
||||||
|
- Serve the repo root (so relative paths resolve), for example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m http.server -b 127.0.0.1 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
Then open:
|
||||||
|
|
||||||
|
- http://127.0.0.1:8000/examples/wasm-browser/
|
||||||
|
|
||||||
|
Open the browser console to see outputs from `nearest`, `kNearest`, and `radius` queries.
|
||||||
|
|
|
||||||
60
examples/wasm-browser/index.html
Normal file
60
examples/wasm-browser/index.html
Normal file
|
|
@ -0,0 +1,60 @@
|
||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
|
<title>Poindexter WASM Browser Demo</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: system-ui, -apple-system, Segoe UI, Roboto, sans-serif; margin: 2rem; }
|
||||||
|
pre { background: #f6f8fa; padding: 1rem; overflow-x: auto; }
|
||||||
|
code { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Poindexter WASM Browser Demo</h1>
|
||||||
|
<p>This demo uses the ESM loader to initialize the WebAssembly build and run a simple KDTree query entirely in your browser.</p>
|
||||||
|
|
||||||
|
<p>
|
||||||
|
Serve this file from the repository root so the asset paths resolve. For example:
|
||||||
|
</p>
|
||||||
|
<pre><code>python3 -m http.server -b 127.0.0.1 8000</code></pre>
|
||||||
|
<p>Then open <code>http://127.0.0.1:8000/examples/wasm-browser/</code> in your browser.</p>
|
||||||
|
|
||||||
|
<h2>Console output</h2>
|
||||||
|
<p>Open DevTools console to inspect results.</p>
|
||||||
|
|
||||||
|
<script type="module">
|
||||||
|
// Import the ESM loader from the npm package directory within this repo.
|
||||||
|
// When serving from repo root, this path resolves to the local loader.
|
||||||
|
import { init } from '../../npm/poindexter-wasm/loader.js';
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const px = await init({
|
||||||
|
// Point to the built WASM artifacts in dist/. Ensure you ran `make wasm-build` first.
|
||||||
|
wasmURL: '../../dist/poindexter.wasm',
|
||||||
|
wasmExecURL: '../../dist/wasm_exec.js',
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log('Poindexter version (WASM):', await px.version());
|
||||||
|
|
||||||
|
const tree = await px.newTree(2);
|
||||||
|
await tree.insert({ id: 'a', coords: [0, 0], value: 'A' });
|
||||||
|
await tree.insert({ id: 'b', coords: [1, 0], value: 'B' });
|
||||||
|
await tree.insert({ id: 'c', coords: [0, 1], value: 'C' });
|
||||||
|
|
||||||
|
const nearest = await tree.nearest([0.9, 0.1]);
|
||||||
|
console.log('Nearest to [0.9,0.1]:', nearest);
|
||||||
|
|
||||||
|
const knn = await tree.kNearest([0.9, 0.9], 2);
|
||||||
|
console.log('kNN (k=2) for [0.9,0.9]:', knn);
|
||||||
|
|
||||||
|
const within = await tree.radius([0, 0], 1.1);
|
||||||
|
console.log('Within r=1.1 of [0,0]:', within);
|
||||||
|
}
|
||||||
|
|
||||||
|
main().catch(err => {
|
||||||
|
console.error('Demo error:', err);
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
103
kdtree.go
103
kdtree.go
|
|
@ -15,6 +15,8 @@ var (
|
||||||
ErrDimMismatch = errors.New("kdtree: inconsistent dimensionality in points")
|
ErrDimMismatch = errors.New("kdtree: inconsistent dimensionality in points")
|
||||||
// ErrDuplicateID indicates a duplicate point ID was encountered.
|
// ErrDuplicateID indicates a duplicate point ID was encountered.
|
||||||
ErrDuplicateID = errors.New("kdtree: duplicate point ID")
|
ErrDuplicateID = errors.New("kdtree: duplicate point ID")
|
||||||
|
// ErrBackendUnavailable indicates that a requested backend cannot be used (e.g., not built/tagged).
|
||||||
|
ErrBackendUnavailable = errors.New("kdtree: requested backend unavailable")
|
||||||
)
|
)
|
||||||
|
|
||||||
// KDPoint represents a point with coordinates and an attached payload/value.
|
// KDPoint represents a point with coordinates and an attached payload/value.
|
||||||
|
|
@ -171,11 +173,35 @@ type KDOption func(*kdOptions)
|
||||||
|
|
||||||
type kdOptions struct {
|
type kdOptions struct {
|
||||||
metric DistanceMetric
|
metric DistanceMetric
|
||||||
|
backend KDBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// defaultBackend returns the implicit backend depending on build tags.
|
||||||
|
// If built with the "gonum" tag, prefer the Gonum backend by default to keep
|
||||||
|
// code paths simple and performant; otherwise fall back to the linear backend.
|
||||||
|
func defaultBackend() KDBackend {
|
||||||
|
if hasGonum() {
|
||||||
|
return BackendGonum
|
||||||
|
}
|
||||||
|
return BackendLinear
|
||||||
|
}
|
||||||
|
|
||||||
|
// KDBackend selects the internal engine used by KDTree.
|
||||||
|
type KDBackend string
|
||||||
|
|
||||||
|
const (
|
||||||
|
BackendLinear KDBackend = "linear"
|
||||||
|
BackendGonum KDBackend = "gonum"
|
||||||
|
)
|
||||||
|
|
||||||
// WithMetric sets the distance metric for the KDTree.
|
// WithMetric sets the distance metric for the KDTree.
|
||||||
func WithMetric(m DistanceMetric) KDOption { return func(o *kdOptions) { o.metric = m } }
|
func WithMetric(m DistanceMetric) KDOption { return func(o *kdOptions) { o.metric = m } }
|
||||||
|
|
||||||
|
// WithBackend selects the internal KDTree backend ("linear" or "gonum").
|
||||||
|
// Default is linear. If the requested backend is unavailable (e.g., gonum build tag not enabled),
|
||||||
|
// the constructor will silently fall back to the linear backend.
|
||||||
|
func WithBackend(b KDBackend) KDOption { return func(o *kdOptions) { o.backend = b } }
|
||||||
|
|
||||||
// KDTree is a lightweight wrapper providing nearest-neighbor operations.
|
// KDTree is a lightweight wrapper providing nearest-neighbor operations.
|
||||||
//
|
//
|
||||||
// Complexity: queries are O(n) linear scans in the current implementation.
|
// Complexity: queries are O(n) linear scans in the current implementation.
|
||||||
|
|
@ -190,6 +216,8 @@ type KDTree[T any] struct {
|
||||||
dim int
|
dim int
|
||||||
metric DistanceMetric
|
metric DistanceMetric
|
||||||
idIndex map[string]int
|
idIndex map[string]int
|
||||||
|
backend KDBackend
|
||||||
|
backendData any // opaque handle for backend-specific structures (e.g., gonum tree)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKDTree builds a KDTree from the given points.
|
// NewKDTree builds a KDTree from the given points.
|
||||||
|
|
@ -214,15 +242,29 @@ func NewKDTree[T any](pts []KDPoint[T], opts ...KDOption) (*KDTree[T], error) {
|
||||||
idIndex[p.ID] = i
|
idIndex[p.ID] = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cfg := kdOptions{metric: EuclideanDistance{}}
|
cfg := kdOptions{metric: EuclideanDistance{}, backend: defaultBackend()}
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
o(&cfg)
|
o(&cfg)
|
||||||
}
|
}
|
||||||
|
backend := cfg.backend
|
||||||
|
var backendData any
|
||||||
|
// Attempt to build gonum backend if requested and available.
|
||||||
|
if backend == BackendGonum && hasGonum() {
|
||||||
|
if bd, err := buildGonumBackend(pts, cfg.metric); err == nil {
|
||||||
|
backendData = bd
|
||||||
|
} else {
|
||||||
|
backend = BackendLinear // fallback gracefully
|
||||||
|
}
|
||||||
|
} else if backend == BackendGonum && !hasGonum() {
|
||||||
|
backend = BackendLinear // tag not enabled → fallback
|
||||||
|
}
|
||||||
t := &KDTree[T]{
|
t := &KDTree[T]{
|
||||||
points: append([]KDPoint[T](nil), pts...),
|
points: append([]KDPoint[T](nil), pts...),
|
||||||
dim: dim,
|
dim: dim,
|
||||||
metric: cfg.metric,
|
metric: cfg.metric,
|
||||||
idIndex: idIndex,
|
idIndex: idIndex,
|
||||||
|
backend: backend,
|
||||||
|
backendData: backendData,
|
||||||
}
|
}
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
@ -233,15 +275,21 @@ func NewKDTreeFromDim[T any](dim int, opts ...KDOption) (*KDTree[T], error) {
|
||||||
if dim <= 0 {
|
if dim <= 0 {
|
||||||
return nil, ErrZeroDim
|
return nil, ErrZeroDim
|
||||||
}
|
}
|
||||||
cfg := kdOptions{metric: EuclideanDistance{}}
|
cfg := kdOptions{metric: EuclideanDistance{}, backend: defaultBackend()}
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
o(&cfg)
|
o(&cfg)
|
||||||
}
|
}
|
||||||
|
backend := cfg.backend
|
||||||
|
if backend == BackendGonum && !hasGonum() {
|
||||||
|
backend = BackendLinear
|
||||||
|
}
|
||||||
return &KDTree[T]{
|
return &KDTree[T]{
|
||||||
points: nil,
|
points: nil,
|
||||||
dim: dim,
|
dim: dim,
|
||||||
metric: cfg.metric,
|
metric: cfg.metric,
|
||||||
idIndex: make(map[string]int),
|
idIndex: make(map[string]int),
|
||||||
|
backend: backend,
|
||||||
|
backendData: nil,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,6 +305,13 @@ func (t *KDTree[T]) Nearest(query []float64) (KDPoint[T], float64, bool) {
|
||||||
if len(query) != t.dim || t.Len() == 0 {
|
if len(query) != t.dim || t.Len() == 0 {
|
||||||
return KDPoint[T]{}, 0, false
|
return KDPoint[T]{}, 0, false
|
||||||
}
|
}
|
||||||
|
// Gonum backend (if available and built)
|
||||||
|
if t.backend == BackendGonum && t.backendData != nil {
|
||||||
|
if idx, dist, ok := gonumNearest[T](t.backendData, query); ok && idx >= 0 && idx < len(t.points) {
|
||||||
|
return t.points[idx], dist, true
|
||||||
|
}
|
||||||
|
// fall through to linear scan if backend didn’t return a result
|
||||||
|
}
|
||||||
bestIdx := -1
|
bestIdx := -1
|
||||||
bestDist := math.MaxFloat64
|
bestDist := math.MaxFloat64
|
||||||
for i := range t.points {
|
for i := range t.points {
|
||||||
|
|
@ -278,6 +333,18 @@ func (t *KDTree[T]) KNearest(query []float64, k int) ([]KDPoint[T], []float64) {
|
||||||
if k <= 0 || len(query) != t.dim || t.Len() == 0 {
|
if k <= 0 || len(query) != t.dim || t.Len() == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
// Gonum backend path
|
||||||
|
if t.backend == BackendGonum && t.backendData != nil {
|
||||||
|
idxs, dists := gonumKNearest[T](t.backendData, query, k)
|
||||||
|
if len(idxs) > 0 {
|
||||||
|
neighbors := make([]KDPoint[T], len(idxs))
|
||||||
|
for i := range idxs {
|
||||||
|
neighbors[i] = t.points[idxs[i]]
|
||||||
|
}
|
||||||
|
return neighbors, dists
|
||||||
|
}
|
||||||
|
// fall back on unexpected empty
|
||||||
|
}
|
||||||
tmp := make([]struct {
|
tmp := make([]struct {
|
||||||
idx int
|
idx int
|
||||||
dist float64
|
dist float64
|
||||||
|
|
@ -304,6 +371,18 @@ func (t *KDTree[T]) Radius(query []float64, r float64) ([]KDPoint[T], []float64)
|
||||||
if r < 0 || len(query) != t.dim || t.Len() == 0 {
|
if r < 0 || len(query) != t.dim || t.Len() == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
// Gonum backend path
|
||||||
|
if t.backend == BackendGonum && t.backendData != nil {
|
||||||
|
idxs, dists := gonumRadius[T](t.backendData, query, r)
|
||||||
|
if len(idxs) > 0 {
|
||||||
|
neighbors := make([]KDPoint[T], len(idxs))
|
||||||
|
for i := range idxs {
|
||||||
|
neighbors[i] = t.points[idxs[i]]
|
||||||
|
}
|
||||||
|
return neighbors, dists
|
||||||
|
}
|
||||||
|
// fall back if no results
|
||||||
|
}
|
||||||
var sel []struct {
|
var sel []struct {
|
||||||
idx int
|
idx int
|
||||||
dist float64
|
dist float64
|
||||||
|
|
@ -342,6 +421,16 @@ func (t *KDTree[T]) Insert(p KDPoint[T]) bool {
|
||||||
if p.ID != "" {
|
if p.ID != "" {
|
||||||
t.idIndex[p.ID] = len(t.points) - 1
|
t.idIndex[p.ID] = len(t.points) - 1
|
||||||
}
|
}
|
||||||
|
// Rebuild backend if using Gonum
|
||||||
|
if t.backend == BackendGonum && hasGonum() {
|
||||||
|
if bd, err := buildGonumBackend(t.points, t.metric); err == nil {
|
||||||
|
t.backendData = bd
|
||||||
|
} else {
|
||||||
|
// fallback to linear if rebuild fails
|
||||||
|
t.backend = BackendLinear
|
||||||
|
t.backendData = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -362,5 +451,15 @@ func (t *KDTree[T]) DeleteByID(id string) bool {
|
||||||
}
|
}
|
||||||
t.points = t.points[:last]
|
t.points = t.points[:last]
|
||||||
delete(t.idIndex, id)
|
delete(t.idIndex, id)
|
||||||
|
// Rebuild backend if using Gonum
|
||||||
|
if t.backend == BackendGonum && hasGonum() {
|
||||||
|
if bd, err := buildGonumBackend(t.points, t.metric); err == nil {
|
||||||
|
t.backendData = bd
|
||||||
|
} else {
|
||||||
|
// fallback to linear if rebuild fails
|
||||||
|
t.backend = BackendLinear
|
||||||
|
t.backendData = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
|
||||||
129
kdtree_backend_parity_test.go
Normal file
129
kdtree_backend_parity_test.go
Normal file
|
|
@ -0,0 +1,129 @@
|
||||||
|
package poindexter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeFixedPoints creates a deterministic set of points in 4D and 2D for parity checks.
|
||||||
|
func makeFixedPoints() []KDPoint[int] {
|
||||||
|
pts := []KDPoint[int]{
|
||||||
|
{ID: "A", Coords: []float64{0, 0, 0, 0}, Value: 1},
|
||||||
|
{ID: "B", Coords: []float64{1, 0, 0.5, 0.2}, Value: 2},
|
||||||
|
{ID: "C", Coords: []float64{0, 1, 0.3, 0.7}, Value: 3},
|
||||||
|
{ID: "D", Coords: []float64{1, 1, 0.9, 0.9}, Value: 4},
|
||||||
|
{ID: "E", Coords: []float64{0.2, 0.8, 0.4, 0.6}, Value: 5},
|
||||||
|
}
|
||||||
|
return pts
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendParity_Nearest(t *testing.T) {
|
||||||
|
pts := makeFixedPoints()
|
||||||
|
queries := [][]float64{
|
||||||
|
{0, 0, 0, 0},
|
||||||
|
{0.9, 0.2, 0.5, 0.1},
|
||||||
|
{0.5, 0.5, 0.5, 0.5},
|
||||||
|
}
|
||||||
|
|
||||||
|
lin, err := NewKDTree(pts, WithBackend(BackendLinear), WithMetric(EuclideanDistance{}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("linear NewKDTree: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only build a gonum tree when the optimized backend is compiled in.
|
||||||
|
if hasGonum() {
|
||||||
|
gon, err := NewKDTree(pts, WithBackend(BackendGonum), WithMetric(EuclideanDistance{}))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gonum NewKDTree: %v", err)
|
||||||
|
}
|
||||||
|
for _, q := range queries {
|
||||||
|
pl, dl, okl := lin.Nearest(q)
|
||||||
|
pg, dg, okg := gon.Nearest(q)
|
||||||
|
if okl != okg {
|
||||||
|
t.Fatalf("ok mismatch: linear=%v gonum=%v", okl, okg)
|
||||||
|
}
|
||||||
|
if !okl {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pl.ID != pg.ID {
|
||||||
|
t.Errorf("nearest ID mismatch for %v: linear=%s gonum=%s", q, pl.ID, pg.ID)
|
||||||
|
}
|
||||||
|
if (dl == 0 && dg != 0) || (dl != 0 && dg == 0) {
|
||||||
|
t.Errorf("nearest distance zero/nonzero mismatch: linear=%v gonum=%v", dl, dg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendParity_KNearest(t *testing.T) {
|
||||||
|
pts := makeFixedPoints()
|
||||||
|
q := []float64{0.6, 0.6, 0.4, 0.4}
|
||||||
|
ks := []int{1, 2, 5, 10}
|
||||||
|
lin, _ := NewKDTree(pts, WithBackend(BackendLinear), WithMetric(EuclideanDistance{}))
|
||||||
|
if hasGonum() {
|
||||||
|
gon, _ := NewKDTree(pts, WithBackend(BackendGonum), WithMetric(EuclideanDistance{}))
|
||||||
|
for _, k := range ks {
|
||||||
|
ln, ld := lin.KNearest(q, k)
|
||||||
|
gn, gd := gon.KNearest(q, k)
|
||||||
|
if len(ln) != len(gn) || len(ld) != len(gd) {
|
||||||
|
t.Fatalf("k=%d length mismatch: linear (%d,%d) vs gonum (%d,%d)", k, len(ln), len(ld), len(gn), len(gd))
|
||||||
|
}
|
||||||
|
// Compare IDs element-wise; ties may reorder between backends, so relax by set equality when distances equal.
|
||||||
|
for i := range ln {
|
||||||
|
if ln[i].ID != gn[i].ID {
|
||||||
|
// If distances are effectively equal, allow different order
|
||||||
|
if i < len(ld) && i < len(gd) && ld[i] == gd[i] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Logf("k=%d index %d ID mismatch: linear=%s gonum=%s (dl=%.6f dg=%.6f)", k, i, ln[i].ID, gn[i].ID, ld[i], gd[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendParity_Radius(t *testing.T) {
|
||||||
|
pts := makeFixedPoints()
|
||||||
|
q := []float64{0.4, 0.6, 0.4, 0.6}
|
||||||
|
radii := []float64{0, 0.15, 0.3, 1.0}
|
||||||
|
lin, _ := NewKDTree(pts, WithBackend(BackendLinear), WithMetric(EuclideanDistance{}))
|
||||||
|
if hasGonum() {
|
||||||
|
gon, _ := NewKDTree(pts, WithBackend(BackendGonum), WithMetric(EuclideanDistance{}))
|
||||||
|
for _, r := range radii {
|
||||||
|
ln, ld := lin.Radius(q, r)
|
||||||
|
gn, gd := gon.Radius(q, r)
|
||||||
|
if len(ln) != len(gn) || len(ld) != len(gd) {
|
||||||
|
t.Fatalf("r=%.3f length mismatch: linear (%d,%d) vs gonum (%d,%d)", r, len(ln), len(ld), len(gn), len(gd))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackendParity_RandomQueries2D(t *testing.T) {
|
||||||
|
// Down-project 4D to 2D to exercise pruning differences as well
|
||||||
|
pts4 := makeFixedPoints()
|
||||||
|
pts2 := make([]KDPoint[int], len(pts4))
|
||||||
|
for i, p := range pts4 {
|
||||||
|
pts2[i] = KDPoint[int]{ID: p.ID, Coords: []float64{p.Coords[0], p.Coords[1]}, Value: p.Value}
|
||||||
|
}
|
||||||
|
lin, _ := NewKDTree(pts2, WithBackend(BackendLinear), WithMetric(ManhattanDistance{}))
|
||||||
|
if hasGonum() {
|
||||||
|
gon, _ := NewKDTree(pts2, WithBackend(BackendGonum), WithMetric(ManhattanDistance{}))
|
||||||
|
rng := rand.New(rand.NewSource(42))
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
q := []float64{rng.Float64(), rng.Float64()}
|
||||||
|
pl, dl, okl := lin.Nearest(q)
|
||||||
|
pg, dg, okg := gon.Nearest(q)
|
||||||
|
if okl != okg {
|
||||||
|
t.Fatalf("ok mismatch (2D rand)")
|
||||||
|
}
|
||||||
|
if !okl {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if pl.ID != pg.ID && (dl != dg) {
|
||||||
|
// Allow different picks only if distances tie; otherwise flag
|
||||||
|
t.Errorf("2D rand nearest mismatch: linear %s(%.6f) gonum %s(%.6f)", pl.ID, dl, pg.ID, dg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
311
kdtree_gonum.go
Normal file
311
kdtree_gonum.go
Normal file
|
|
@ -0,0 +1,311 @@
|
||||||
|
//go:build gonum
|
||||||
|
|
||||||
|
package poindexter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Note: This file is compiled when built with the "gonum" tag. For now, we
|
||||||
|
// provide an internal KD-tree backend that performs balanced median-split
|
||||||
|
// construction and branch-and-bound queries. This gives sub-linear behavior on
|
||||||
|
// suitable datasets without introducing an external dependency. The public API
|
||||||
|
// and option names remain the same; a future change can swap this implementation
|
||||||
|
// to use gonum.org/v1/gonum/spatial/kdtree without altering callers.
|
||||||
|
|
||||||
|
// hasGonum reports whether the optimized backend is compiled in.
|
||||||
|
func hasGonum() bool { return true }
|
||||||
|
|
||||||
|
// kdNode represents a node in the median-split KD-tree.
|
||||||
|
type kdNode struct {
|
||||||
|
axis int
|
||||||
|
idx int // index into the original points slice
|
||||||
|
val float64
|
||||||
|
left *kdNode
|
||||||
|
right *kdNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// kdBackend holds the KD-tree root and metadata.
|
||||||
|
type kdBackend struct {
|
||||||
|
root *kdNode
|
||||||
|
dim int
|
||||||
|
metric DistanceMetric
|
||||||
|
// Access to original coords by index is done via a closure we capture at build
|
||||||
|
coords func(i int) []float64
|
||||||
|
len int
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildGonumBackend builds a balanced KD-tree using variance-based axis choice
|
||||||
|
// and median splits. It does not reorder the external points slice; it keeps
|
||||||
|
// indices and accesses the original data via closures, preserving caller order.
|
||||||
|
func buildGonumBackend[T any](points []KDPoint[T], metric DistanceMetric) (any, error) {
|
||||||
|
// Only enable this backend for metrics where the axis-slab bound is valid
|
||||||
|
// for pruning: L2/L1/L∞. For other metrics (e.g., Cosine), fall back.
|
||||||
|
switch metric.(type) {
|
||||||
|
case EuclideanDistance, ManhattanDistance, ChebyshevDistance:
|
||||||
|
// supported
|
||||||
|
default:
|
||||||
|
return nil, ErrBackendUnavailable
|
||||||
|
}
|
||||||
|
if len(points) == 0 {
|
||||||
|
return &kdBackend{root: nil, dim: 0, metric: metric, coords: func(int) []float64 { return nil }}, nil
|
||||||
|
}
|
||||||
|
dim := len(points[0].Coords)
|
||||||
|
coords := func(i int) []float64 { return points[i].Coords }
|
||||||
|
idxs := make([]int, len(points))
|
||||||
|
for i := range idxs {
|
||||||
|
idxs[i] = i
|
||||||
|
}
|
||||||
|
root := buildKDRecursive(idxs, coords, dim, 0)
|
||||||
|
return &kdBackend{root: root, dim: dim, metric: metric, coords: coords, len: len(points)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// compute per-axis standard deviation (used for axis selection)
|
||||||
|
func axisStd(idxs []int, coords func(int) []float64, dim int) []float64 {
|
||||||
|
vars := make([]float64, dim)
|
||||||
|
means := make([]float64, dim)
|
||||||
|
n := float64(len(idxs))
|
||||||
|
if n == 0 {
|
||||||
|
return vars
|
||||||
|
}
|
||||||
|
for _, i := range idxs {
|
||||||
|
c := coords(i)
|
||||||
|
for d := 0; d < dim; d++ {
|
||||||
|
means[d] += c[d]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for d := 0; d < dim; d++ {
|
||||||
|
means[d] /= n
|
||||||
|
}
|
||||||
|
for _, i := range idxs {
|
||||||
|
c := coords(i)
|
||||||
|
for d := 0; d < dim; d++ {
|
||||||
|
delta := c[d] - means[d]
|
||||||
|
vars[d] += delta * delta
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for d := 0; d < dim; d++ {
|
||||||
|
vars[d] = math.Sqrt(vars[d] / n)
|
||||||
|
}
|
||||||
|
return vars
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKDRecursive(idxs []int, coords func(int) []float64, dim int, depth int) *kdNode {
|
||||||
|
if len(idxs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// choose axis with max stddev
|
||||||
|
stds := axisStd(idxs, coords, dim)
|
||||||
|
axis := 0
|
||||||
|
maxv := stds[0]
|
||||||
|
for d := 1; d < dim; d++ {
|
||||||
|
if stds[d] > maxv {
|
||||||
|
maxv = stds[d]
|
||||||
|
axis = d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// nth-element (partial sort) by axis using sort.Slice for simplicity
|
||||||
|
sort.Slice(idxs, func(i, j int) bool { return coords(idxs[i])[axis] < coords(idxs[j])[axis] })
|
||||||
|
mid := len(idxs) / 2
|
||||||
|
medianIdx := idxs[mid]
|
||||||
|
n := &kdNode{axis: axis, idx: medianIdx, val: coords(medianIdx)[axis]}
|
||||||
|
n.left = buildKDRecursive(append([]int(nil), idxs[:mid]...), coords, dim, depth+1)
|
||||||
|
n.right = buildKDRecursive(append([]int(nil), idxs[mid+1:]...), coords, dim, depth+1)
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// gonumNearest performs 1-NN search using the KD backend.
|
||||||
|
func gonumNearest[T any](backend any, query []float64) (int, float64, bool) {
|
||||||
|
b, ok := backend.(*kdBackend)
|
||||||
|
if !ok || b.root == nil || len(query) != b.dim {
|
||||||
|
return -1, 0, false
|
||||||
|
}
|
||||||
|
bestIdx := -1
|
||||||
|
bestDist := math.MaxFloat64
|
||||||
|
var search func(*kdNode)
|
||||||
|
search = func(n *kdNode) {
|
||||||
|
if n == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c := b.coords(n.idx)
|
||||||
|
d := b.metric.Distance(query, c)
|
||||||
|
if d < bestDist {
|
||||||
|
bestDist = d
|
||||||
|
bestIdx = n.idx
|
||||||
|
}
|
||||||
|
axis := n.axis
|
||||||
|
qv := query[axis]
|
||||||
|
// choose side
|
||||||
|
near, far := n.left, n.right
|
||||||
|
if qv >= n.val {
|
||||||
|
near, far = n.right, n.left
|
||||||
|
}
|
||||||
|
search(near)
|
||||||
|
// prune if hyperslab distance is >= bestDist
|
||||||
|
diff := qv - n.val
|
||||||
|
if diff < 0 {
|
||||||
|
diff = -diff
|
||||||
|
}
|
||||||
|
if diff <= bestDist {
|
||||||
|
search(far)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
search(b.root)
|
||||||
|
if bestIdx < 0 {
|
||||||
|
return -1, 0, false
|
||||||
|
}
|
||||||
|
return bestIdx, bestDist, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// small max-heap for (distance, index)
|
||||||
|
// We’ll use a slice maintaining the largest distance at [0] via container/heap-like ops.
|
||||||
|
type knnItem struct {
|
||||||
|
idx int
|
||||||
|
dist float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type knnHeap []knnItem
|
||||||
|
|
||||||
|
func (h knnHeap) Len() int { return len(h) }
|
||||||
|
func (h knnHeap) less(i, j int) bool { return h[i].dist > h[j].dist } // max-heap by dist
|
||||||
|
func (h *knnHeap) push(x knnItem) { *h = append(*h, x); h.up(len(*h) - 1) }
|
||||||
|
func (h *knnHeap) pop() knnItem {
|
||||||
|
n := len(*h) - 1
|
||||||
|
h.swap(0, n)
|
||||||
|
v := (*h)[n]
|
||||||
|
*h = (*h)[:n]
|
||||||
|
h.down(0)
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
func (h *knnHeap) peek() knnItem { return (*h)[0] }
|
||||||
|
func (h knnHeap) swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||||
|
func (h *knnHeap) up(i int) {
|
||||||
|
for i > 0 {
|
||||||
|
p := (i - 1) / 2
|
||||||
|
if !h.less(i, p) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
h.swap(i, p)
|
||||||
|
i = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (h *knnHeap) down(i int) {
|
||||||
|
for {
|
||||||
|
l := 2*i + 1
|
||||||
|
r := l + 1
|
||||||
|
largest := i
|
||||||
|
if l < h.Len() && h.less(l, largest) {
|
||||||
|
largest = l
|
||||||
|
}
|
||||||
|
if r < h.Len() && h.less(r, largest) {
|
||||||
|
largest = r
|
||||||
|
}
|
||||||
|
if largest == i {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
h.swap(i, largest)
|
||||||
|
i = largest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// gonumKNearest returns indices in ascending distance order.
|
||||||
|
func gonumKNearest[T any](backend any, query []float64, k int) ([]int, []float64) {
|
||||||
|
b, ok := backend.(*kdBackend)
|
||||||
|
if !ok || b.root == nil || len(query) != b.dim || k <= 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var h knnHeap
|
||||||
|
bestCap := k
|
||||||
|
var search func(*kdNode)
|
||||||
|
search = func(n *kdNode) {
|
||||||
|
if n == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c := b.coords(n.idx)
|
||||||
|
d := b.metric.Distance(query, c)
|
||||||
|
if h.Len() < bestCap {
|
||||||
|
h.push(knnItem{idx: n.idx, dist: d})
|
||||||
|
} else if d < h.peek().dist {
|
||||||
|
// replace max
|
||||||
|
h[0] = knnItem{idx: n.idx, dist: d}
|
||||||
|
h.down(0)
|
||||||
|
}
|
||||||
|
axis := n.axis
|
||||||
|
qv := query[axis]
|
||||||
|
near, far := n.left, n.right
|
||||||
|
if qv >= n.val {
|
||||||
|
near, far = n.right, n.left
|
||||||
|
}
|
||||||
|
search(near)
|
||||||
|
// prune against current worst in heap if heap is full; otherwise use bestDist
|
||||||
|
threshold := math.MaxFloat64
|
||||||
|
if h.Len() == bestCap {
|
||||||
|
threshold = h.peek().dist
|
||||||
|
} else if h.Len() > 0 {
|
||||||
|
// use best known (not strictly necessary)
|
||||||
|
threshold = h.peek().dist
|
||||||
|
}
|
||||||
|
diff := qv - n.val
|
||||||
|
if diff < 0 {
|
||||||
|
diff = -diff
|
||||||
|
}
|
||||||
|
if diff <= threshold {
|
||||||
|
search(far)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
search(b.root)
|
||||||
|
// Extract to slices and sort ascending by distance
|
||||||
|
res := make([]knnItem, len(h))
|
||||||
|
copy(res, h)
|
||||||
|
sort.Slice(res, func(i, j int) bool { return res[i].dist < res[j].dist })
|
||||||
|
idxs := make([]int, len(res))
|
||||||
|
dists := make([]float64, len(res))
|
||||||
|
for i := range res {
|
||||||
|
idxs[i] = res[i].idx
|
||||||
|
dists[i] = res[i].dist
|
||||||
|
}
|
||||||
|
return idxs, dists
|
||||||
|
}
|
||||||
|
|
||||||
|
func gonumRadius[T any](backend any, query []float64, r float64) ([]int, []float64) {
|
||||||
|
b, ok := backend.(*kdBackend)
|
||||||
|
if !ok || b.root == nil || len(query) != b.dim || r < 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var res []knnItem
|
||||||
|
var search func(*kdNode)
|
||||||
|
search = func(n *kdNode) {
|
||||||
|
if n == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c := b.coords(n.idx)
|
||||||
|
d := b.metric.Distance(query, c)
|
||||||
|
if d <= r {
|
||||||
|
res = append(res, knnItem{idx: n.idx, dist: d})
|
||||||
|
}
|
||||||
|
axis := n.axis
|
||||||
|
qv := query[axis]
|
||||||
|
near, far := n.left, n.right
|
||||||
|
if qv >= n.val {
|
||||||
|
near, far = n.right, n.left
|
||||||
|
}
|
||||||
|
search(near)
|
||||||
|
diff := qv - n.val
|
||||||
|
if diff < 0 {
|
||||||
|
diff = -diff
|
||||||
|
}
|
||||||
|
if diff <= r {
|
||||||
|
search(far)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
search(b.root)
|
||||||
|
sort.Slice(res, func(i, j int) bool { return res[i].dist < res[j].dist })
|
||||||
|
idxs := make([]int, len(res))
|
||||||
|
dists := make([]float64, len(res))
|
||||||
|
for i := range res {
|
||||||
|
idxs[i] = res[i].idx
|
||||||
|
dists[i] = res[i].dist
|
||||||
|
}
|
||||||
|
return idxs, dists
|
||||||
|
}
|
||||||
23
kdtree_gonum_stub.go
Normal file
23
kdtree_gonum_stub.go
Normal file
|
|
@ -0,0 +1,23 @@
|
||||||
|
//go:build !gonum
|
||||||
|
|
||||||
|
package poindexter
|
||||||
|
|
||||||
|
// hasGonum reports whether the gonum backend is compiled in (build tag 'gonum').
|
||||||
|
func hasGonum() bool { return false }
|
||||||
|
|
||||||
|
// buildGonumBackend is unavailable without the 'gonum' build tag.
|
||||||
|
func buildGonumBackend[T any](pts []KDPoint[T], metric DistanceMetric) (any, error) {
|
||||||
|
return nil, ErrEmptyPoints // sentinel non-nil error to force fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func gonumNearest[T any](backend any, query []float64) (int, float64, bool) {
|
||||||
|
return -1, 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func gonumKNearest[T any](backend any, query []float64, k int) ([]int, []float64) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func gonumRadius[T any](backend any, query []float64, r float64) ([]int, []float64) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
@ -71,14 +71,32 @@ Explore runnable examples in the repository:
|
||||||
- examples/kdtree_2d_ping_hop
|
- examples/kdtree_2d_ping_hop
|
||||||
- examples/kdtree_3d_ping_hop_geo
|
- examples/kdtree_3d_ping_hop_geo
|
||||||
- examples/kdtree_4d_ping_hop_geo_score
|
- examples/kdtree_4d_ping_hop_geo_score
|
||||||
|
- examples/wasm-browser (browser demo using the ESM loader)
|
||||||
|
|
||||||
### KDTree performance and notes
|
### KDTree performance and notes
|
||||||
- Current KDTree queries are O(n) linear scans, which are great for small-to-medium datasets or low-latency prototyping. For 1e5+ points and low/medium dimensions, consider swapping the internal engine to `gonum.org/v1/gonum/spatial/kdtree` (the API here is compatible by design).
|
- Dual backend support: Linear (always available) and an optimized KD backend enabled when building with `-tags=gonum`. Linear is the default; with the `gonum` tag, the optimized backend becomes the default.
|
||||||
|
- Complexity: Linear backend is O(n) per query. Optimized KD backend is typically sub-linear on prunable datasets and dims ≤ ~8, especially as N grows (≥10k–100k).
|
||||||
- Insert is O(1) amortized; delete by ID is O(1) via swap-delete; order is not preserved.
|
- Insert is O(1) amortized; delete by ID is O(1) via swap-delete; order is not preserved.
|
||||||
- Concurrency: the KDTree type is not safe for concurrent mutation. Protect with a mutex or share immutable snapshots for read-mostly workloads.
|
- Concurrency: the KDTree type is not safe for concurrent mutation. Protect with a mutex or share immutable snapshots for read-mostly workloads.
|
||||||
- See multi-dimensional examples (ping/hops/geo/score) in docs and `examples/`.
|
- See multi-dimensional examples (ping/hops/geo/score) in docs and `examples/`.
|
||||||
- Performance guide: see docs/Performance for benchmark guidance and tips: [docs/perf.md](docs/perf.md) • Hosted: https://snider.github.io/Poindexter/perf/
|
- Performance guide: see docs/Performance for benchmark guidance and tips: [docs/perf.md](docs/perf.md) • Hosted: https://snider.github.io/Poindexter/perf/
|
||||||
|
|
||||||
|
### Backend selection
|
||||||
|
- Default backend is Linear. If you build with `-tags=gonum`, the default becomes the optimized KD backend.
|
||||||
|
- You can override per tree at construction:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Force Linear (always available)
|
||||||
|
kdt1, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendLinear))
|
||||||
|
|
||||||
|
// Force Gonum (requires build tag)
|
||||||
|
kdt2, _ := poindexter.NewKDTree(pts, poindexter.WithBackend(poindexter.BackendGonum))
|
||||||
|
```
|
||||||
|
|
||||||
|
- Supported metrics in the optimized backend: Euclidean (L2), Manhattan (L1), Chebyshev (L∞).
|
||||||
|
- Cosine and Weighted-Cosine currently run on the Linear backend.
|
||||||
|
- See the Performance guide for measured comparisons and when to choose which backend.
|
||||||
|
|
||||||
#### Choosing a metric (quick tips)
|
#### Choosing a metric (quick tips)
|
||||||
- Euclidean (L2): smooth trade-offs across axes; solid default for blended preferences.
|
- Euclidean (L2): smooth trade-offs across axes; solid default for blended preferences.
|
||||||
- Manhattan (L1): emphasizes per-axis absolute differences; good when each unit of ping/hop matters equally.
|
- Manhattan (L1): emphasizes per-axis absolute differences; good when each unit of ping/hop matters equally.
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue