Poindexter/kdtree_gonum.go

312 lines
7.7 KiB
Go
Raw Permalink Normal View History

//go:build gonum
package poindexter
import (
"math"
"sort"
)
// Note: This file is compiled when built with the "gonum" tag. For now, we
// provide an internal KD-tree backend that performs balanced median-split
// construction and branch-and-bound queries. This gives sub-linear behavior on
// suitable datasets without introducing an external dependency. The public API
// and option names remain the same; a future change can swap this implementation
// to use gonum.org/v1/gonum/spatial/kdtree without altering callers.
// hasGonum tells callers whether the optimized KD-tree backend is part of
// this build. This definition always answers true.
func hasGonum() bool {
	return true
}
// kdNode is one node of the median-split KD-tree. Each node stands for exactly
// one input point and splits the remaining points of its subtree along axis.
type kdNode struct {
	axis        int     // coordinate axis this node splits on
	idx         int     // position of the node's point in the original points slice
	val         float64 // the point's coordinate along axis, i.e. the split value
	left, right *kdNode // subtrees holding the points sorted before / after the median
}
// kdBackend is the handle returned by buildGonumBackend: the KD-tree root plus
// the metadata the query functions need in order to search it.
type kdBackend struct {
// root is the top of the tree; it is nil when the backend was built from an
// empty point set.
root *kdNode
// dim is the coordinate count taken from the first point at build time.
// Queries whose length differs from dim are rejected by the search functions.
dim int
// metric computes the distance between a query and a stored point.
metric DistanceMetric
// coords returns the coordinates of the point at index i of the slice the
// backend was built from. It is a closure captured at build time, so the tree
// stores only indices into that slice.
coords func(i int) []float64
// len is the number of points the tree was built from.
len int
}
// buildGonumBackend constructs the internal KD-tree backend over points and
// returns it as an opaque value for the gonum* query functions.
//
// Pruning in the queries compares the per-axis gap to the metric distance,
// which is only a valid bound for L2, L1 and L∞. The backend is therefore only
// built for EuclideanDistance, ManhattanDistance and ChebyshevDistance; any
// other metric yields ErrBackendUnavailable so the caller can fall back.
//
// The tree works on a private slice of indices and reads coordinates through
// a closure over points, so the caller's slice is never reordered. The
// dimensionality is taken from points[0]; the other points are not checked
// here. An empty points slice produces a backend with a nil root.
func buildGonumBackend[T any](points []KDPoint[T], metric DistanceMetric) (any, error) {
	switch metric.(type) {
	case EuclideanDistance, ManhattanDistance, ChebyshevDistance:
		// Axis-slab pruning is valid for these metrics.
	default:
		return nil, ErrBackendUnavailable
	}

	total := len(points)
	if total == 0 {
		empty := &kdBackend{
			root:   nil,
			dim:    0,
			metric: metric,
			coords: func(int) []float64 { return nil },
		}
		return empty, nil
	}

	lookup := func(i int) []float64 { return points[i].Coords }
	width := len(points[0].Coords)

	// Identity permutation 0..total-1; the builder reorders this, not points.
	order := make([]int, 0, total)
	for i := 0; i < total; i++ {
		order = append(order, i)
	}

	tree := &kdBackend{
		root:   buildKDRecursive(order, lookup, width, 0),
		dim:    width,
		metric: metric,
		coords: lookup,
		len:    total,
	}
	return tree, nil
}
// axisStd returns, for each of the first dim axes, the population standard
// deviation of the points selected by idxs. coords(i) supplies the coordinates
// of point i and must hold at least dim values. The result always has length
// dim; it is all zeros when idxs is empty.
func axisStd(idxs []int, coords func(int) []float64, dim int) []float64 {
	spread := make([]float64, dim)
	if len(idxs) == 0 {
		return spread
	}
	count := float64(len(idxs))

	// First pass: per-axis mean.
	mean := make([]float64, dim)
	for _, id := range idxs {
		pt := coords(id)
		for ax := range mean {
			mean[ax] += pt[ax]
		}
	}
	for ax := range mean {
		mean[ax] /= count
	}

	// Second pass: accumulate squared deviations from the mean.
	for _, id := range idxs {
		pt := coords(id)
		for ax := range spread {
			off := pt[ax] - mean[ax]
			spread[ax] += off * off
		}
	}

	// Turn the sums into standard deviations in place.
	for ax, sumSq := range spread {
		spread[ax] = math.Sqrt(sumSq / count)
	}
	return spread
}
// buildKDRecursive builds the subtree for the points referenced by idxs and
// returns its root, or nil when idxs is empty.
//
// The split axis is the one with the largest standard deviation over the
// subset (ties keep the lowest axis). idxs is sorted in place along that axis,
// the element at len(idxs)/2 becomes the node, and the runs on either side of
// it are handed to the recursive calls as sub-slices of idxs. Those runs are
// disjoint and exclude the median, so sorting one never disturbs the other;
// no per-level copies are made. As a consequence the caller's idxs slice is
// reordered by this call.
//
// coords(i) must return at least dim coordinates for every i in idxs, and dim
// must be at least 1. depth is the current recursion depth; it is passed along
// but does not influence the axis choice.
func buildKDRecursive(idxs []int, coords func(int) []float64, dim int, depth int) *kdNode {
	if len(idxs) == 0 {
		return nil
	}

	// Choose the axis with the largest spread.
	stds := axisStd(idxs, coords, dim)
	axis := 0
	maxv := stds[0]
	for d := 1; d < dim; d++ {
		if stds[d] > maxv {
			maxv = stds[d]
			axis = d
		}
	}

	// A full sort stands in for nth-element selection, for simplicity.
	sort.Slice(idxs, func(i, j int) bool {
		return coords(idxs[i])[axis] < coords(idxs[j])[axis]
	})

	mid := len(idxs) / 2
	medianIdx := idxs[mid]
	return &kdNode{
		axis:  axis,
		idx:   medianIdx,
		val:   coords(medianIdx)[axis],
		left:  buildKDRecursive(idxs[:mid], coords, dim, depth+1),
		right: buildKDRecursive(idxs[mid+1:], coords, dim, depth+1),
	}
}
// gonumNearest runs a single-nearest-neighbour search on the KD backend.
//
// It returns the index of the closest point, its distance to query, and true.
// The result is (-1, 0, false) when backend is not a *kdBackend, when the tree
// is empty, when len(query) differs from the tree's dimensionality, or when no
// visited point has a distance below math.MaxFloat64.
func gonumNearest[T any](backend any, query []float64) (int, float64, bool) {
	tree, isKD := backend.(*kdBackend)
	if !isKD {
		return -1, 0, false
	}
	if tree.root == nil || len(query) != tree.dim {
		return -1, 0, false
	}

	winner, winnerDist := -1, math.MaxFloat64

	var visit func(node *kdNode)
	visit = func(node *kdNode) {
		if node == nil {
			return
		}
		// Only a strictly smaller distance replaces the current winner.
		if dist := tree.metric.Distance(query, tree.coords(node.idx)); dist < winnerDist {
			winner, winnerDist = node.idx, dist
		}

		// Descend first into the half-space that contains the query.
		target := query[node.axis]
		first, second := node.left, node.right
		if target >= node.val {
			first, second = second, first
		}
		visit(first)

		// The other half-space is at least |target - split| away along this
		// axis; it is skipped only when that gap exceeds the best distance.
		if math.Abs(target-node.val) <= winnerDist {
			visit(second)
		}
	}
	visit(tree.root)

	if winner < 0 {
		return -1, 0, false
	}
	return winner, winnerDist, true
}
// Small max-heap of (distance, index) pairs.
// We'll use a slice that keeps the largest distance at [0], maintained with
// container/heap-like sift operations.

// knnItem pairs a point index with that point's distance to the query.
type knnItem struct {
	idx  int
	dist float64
}

// knnHeap is a binary max-heap of knnItem ordered by dist, stored in a slice.
type knnHeap []knnItem

// Len returns the number of items currently in the heap.
func (h knnHeap) Len() int {
	return len(h)
}

// less reports whether item i must sit above item j, i.e. whether it has the
// larger distance (max-heap order).
func (h knnHeap) less(i, j int) bool {
	return h[j].dist < h[i].dist
}

// push appends x and restores the heap order.
func (h *knnHeap) push(x knnItem) {
	*h = append(*h, x)
	h.up(h.Len() - 1)
}

// pop removes and returns the item with the largest distance. The heap must
// not be empty.
func (h *knnHeap) pop() knnItem {
	last := h.Len() - 1
	h.swap(0, last)
	top := (*h)[last]
	*h = (*h)[:last]
	h.down(0)
	return top
}

// peek returns the item with the largest distance without removing it. The
// heap must not be empty.
func (h *knnHeap) peek() knnItem {
	return (*h)[0]
}

// swap exchanges the items at positions i and j.
func (h knnHeap) swap(i, j int) {
	h[i], h[j] = h[j], h[i]
}

// up sifts the item at position i towards the root until its parent is no
// smaller than it.
func (h *knnHeap) up(i int) {
	for child := i; child > 0; {
		parent := (child - 1) / 2
		if !h.less(child, parent) {
			return
		}
		h.swap(child, parent)
		child = parent
	}
}

// down sifts the item at position i towards the leaves, each time swapping it
// with its larger child, until neither child is larger.
func (h *knnHeap) down(i int) {
	size := h.Len()
	for {
		pick := i
		if left := 2*i + 1; left < size && h.less(left, pick) {
			pick = left
		}
		if right := 2*i + 2; right < size && h.less(right, pick) {
			pick = right
		}
		if pick == i {
			return
		}
		h.swap(i, pick)
		i = pick
	}
}
// gonumKNearest returns the indices of up to k points closest to query,
// together with their distances, both in ascending distance order.
//
// It returns (nil, nil) when backend is not a *kdBackend, when the tree is
// empty, when len(query) differs from the tree's dimensionality, or when
// k <= 0. If the tree holds fewer than k points, all of them are returned.
//
// Candidates are kept in a max-heap of at most k items, so the current worst
// candidate is always at the top and can be replaced in place.
func gonumKNearest[T any](backend any, query []float64, k int) ([]int, []float64) {
	b, ok := backend.(*kdBackend)
	if !ok || b.root == nil || len(query) != b.dim || k <= 0 {
		return nil, nil
	}

	var h knnHeap
	var search func(*kdNode)
	search = func(n *kdNode) {
		if n == nil {
			return
		}
		d := b.metric.Distance(query, b.coords(n.idx))
		if h.Len() < k {
			h.push(knnItem{idx: n.idx, dist: d})
		} else if d < h.peek().dist {
			// Heap is full: overwrite the current worst and re-sift.
			h[0] = knnItem{idx: n.idx, dist: d}
			h.down(0)
		}

		// Descend first into the half-space that contains the query.
		qv := query[n.axis]
		near, far := n.left, n.right
		if qv >= n.val {
			near, far = n.right, n.left
		}
		search(near)

		// While fewer than k candidates are held, every remaining point still
		// belongs in the result, so the far side must be visited no matter how
		// distant it is. Only a full heap gives a bound to prune against: the
		// far side is skipped when its axis gap exceeds the worst kept distance.
		if h.Len() < k || math.Abs(qv-n.val) <= h.peek().dist {
			search(far)
		}
	}
	search(b.root)

	// The heap is no longer needed as a heap; sort it ascending by distance.
	sort.Slice(h, func(i, j int) bool { return h[i].dist < h[j].dist })
	idxs := make([]int, len(h))
	dists := make([]float64, len(h))
	for i, it := range h {
		idxs[i] = it.idx
		dists[i] = it.dist
	}
	return idxs, dists
}
// gonumRadius returns the indices of all points whose distance to query is at
// most r, together with those distances, both in ascending distance order.
//
// It returns (nil, nil) when backend is not a *kdBackend, when the tree is
// empty, when len(query) differs from the tree's dimensionality, or when
// r < 0. When the search runs but nothing lies within r, the returned slices
// are empty but non-nil.
func gonumRadius[T any](backend any, query []float64, r float64) ([]int, []float64) {
	tree, isKD := backend.(*kdBackend)
	if !isKD {
		return nil, nil
	}
	if tree.root == nil || len(query) != tree.dim || r < 0 {
		return nil, nil
	}

	var hits []knnItem
	var visit func(node *kdNode)
	visit = func(node *kdNode) {
		if node == nil {
			return
		}
		if dist := tree.metric.Distance(query, tree.coords(node.idx)); dist <= r {
			hits = append(hits, knnItem{idx: node.idx, dist: dist})
		}

		// Pick the half-space containing the query and, in the same step, the
		// non-negative gap between the query and the splitting plane.
		target := query[node.axis]
		first, second, gap := node.left, node.right, node.val-target
		if target >= node.val {
			first, second, gap = node.right, node.left, target-node.val
		}
		visit(first)

		// The opposite half-space can only contain hits if the plane itself
		// is within r of the query along this axis.
		if gap <= r {
			visit(second)
		}
	}
	visit(tree.root)

	sort.Slice(hits, func(i, j int) bool { return hits[i].dist < hits[j].dist })

	idxs := make([]int, len(hits))
	dists := make([]float64, len(hits))
	for i, hit := range hits {
		idxs[i], dists[i] = hit.idx, hit.dist
	}
	return idxs, dists
}