311 lines
7.7 KiB
Go
311 lines
7.7 KiB
Go
//go:build gonum
|
||
|
||
package poindexter
|
||
|
||
import (
|
||
"math"
|
||
"sort"
|
||
)
|
||
|
||
// Note: this file is compiled only when building with the "gonum" tag. For now
// it provides an internal KD-tree backend that performs balanced median-split
// construction and branch-and-bound queries, which gives sub-linear behavior on
// suitable datasets without introducing an external dependency. The public API
// and option names remain the same; a future change can swap this
// implementation for gonum.org/v1/gonum/spatial/kdtree without altering callers.
|
||
|
||
// hasGonum tells callers whether the optimized KD-tree backend was compiled
// into this build. In this file the answer is unconditionally true.
func hasGonum() bool {
	return true
}
|
||
|
||
// kdNode is one vertex of the median-split KD-tree. It does not store the
// point itself, only the point's position in the original slice.
type kdNode struct {
	axis        int     // coordinate axis this node splits on
	idx         int     // index of the node's point in the original points slice
	val         float64 // the point's coordinate along axis (the split value)
	left, right *kdNode // child subtrees; nil where a side is empty
}
|
||
|
||
// kdBackend holds the KD-tree root and metadata.
|
||
type kdBackend struct {
|
||
root *kdNode
|
||
dim int
|
||
metric DistanceMetric
|
||
// Access to original coords by index is done via a closure we capture at build
|
||
coords func(i int) []float64
|
||
len int
|
||
}
|
||
|
||
// buildGonumBackend builds a balanced KD-tree using variance-based axis choice
|
||
// and median splits. It does not reorder the external points slice; it keeps
|
||
// indices and accesses the original data via closures, preserving caller order.
|
||
func buildGonumBackend[T any](points []KDPoint[T], metric DistanceMetric) (any, error) {
|
||
// Only enable this backend for metrics where the axis-slab bound is valid
|
||
// for pruning: L2/L1/L∞. For other metrics (e.g., Cosine), fall back.
|
||
switch metric.(type) {
|
||
case EuclideanDistance, ManhattanDistance, ChebyshevDistance:
|
||
// supported
|
||
default:
|
||
return nil, ErrBackendUnavailable
|
||
}
|
||
if len(points) == 0 {
|
||
return &kdBackend{root: nil, dim: 0, metric: metric, coords: func(int) []float64 { return nil }}, nil
|
||
}
|
||
dim := len(points[0].Coords)
|
||
coords := func(i int) []float64 { return points[i].Coords }
|
||
idxs := make([]int, len(points))
|
||
for i := range idxs {
|
||
idxs[i] = i
|
||
}
|
||
root := buildKDRecursive(idxs, coords, dim, 0)
|
||
return &kdBackend{root: root, dim: dim, metric: metric, coords: coords, len: len(points)}, nil
|
||
}
|
||
|
||
// axisStd returns, for each of the first dim axes, the population standard
// deviation of that coordinate over the points selected by idxs. coords maps a
// point index to its coordinate slice. When idxs is empty the result is a
// slice of dim zeros. The result is used to pick the split axis.
func axisStd(idxs []int, coords func(int) []float64, dim int) []float64 {
	spread := make([]float64, dim)
	if len(idxs) == 0 {
		return spread
	}
	count := float64(len(idxs))

	// Pass 1: per-axis mean.
	mean := make([]float64, dim)
	for _, p := range idxs {
		pt := coords(p)
		for d := range mean {
			mean[d] += pt[d]
		}
	}
	for d := range mean {
		mean[d] /= count
	}

	// Pass 2: accumulate squared deviations from the mean, then normalise.
	for _, p := range idxs {
		pt := coords(p)
		for d := range spread {
			dev := pt[d] - mean[d]
			spread[d] += dev * dev
		}
	}
	for d := range spread {
		spread[d] = math.Sqrt(spread[d] / count)
	}
	return spread
}
|
||
|
||
func buildKDRecursive(idxs []int, coords func(int) []float64, dim int, depth int) *kdNode {
|
||
if len(idxs) == 0 {
|
||
return nil
|
||
}
|
||
// choose axis with max stddev
|
||
stds := axisStd(idxs, coords, dim)
|
||
axis := 0
|
||
maxv := stds[0]
|
||
for d := 1; d < dim; d++ {
|
||
if stds[d] > maxv {
|
||
maxv = stds[d]
|
||
axis = d
|
||
}
|
||
}
|
||
// nth-element (partial sort) by axis using sort.Slice for simplicity
|
||
sort.Slice(idxs, func(i, j int) bool { return coords(idxs[i])[axis] < coords(idxs[j])[axis] })
|
||
mid := len(idxs) / 2
|
||
medianIdx := idxs[mid]
|
||
n := &kdNode{axis: axis, idx: medianIdx, val: coords(medianIdx)[axis]}
|
||
n.left = buildKDRecursive(append([]int(nil), idxs[:mid]...), coords, dim, depth+1)
|
||
n.right = buildKDRecursive(append([]int(nil), idxs[mid+1:]...), coords, dim, depth+1)
|
||
return n
|
||
}
|
||
|
||
// gonumNearest performs 1-NN search using the KD backend.
|
||
func gonumNearest[T any](backend any, query []float64) (int, float64, bool) {
|
||
b, ok := backend.(*kdBackend)
|
||
if !ok || b.root == nil || len(query) != b.dim {
|
||
return -1, 0, false
|
||
}
|
||
bestIdx := -1
|
||
bestDist := math.MaxFloat64
|
||
var search func(*kdNode)
|
||
search = func(n *kdNode) {
|
||
if n == nil {
|
||
return
|
||
}
|
||
c := b.coords(n.idx)
|
||
d := b.metric.Distance(query, c)
|
||
if d < bestDist {
|
||
bestDist = d
|
||
bestIdx = n.idx
|
||
}
|
||
axis := n.axis
|
||
qv := query[axis]
|
||
// choose side
|
||
near, far := n.left, n.right
|
||
if qv >= n.val {
|
||
near, far = n.right, n.left
|
||
}
|
||
search(near)
|
||
// prune if hyperslab distance is >= bestDist
|
||
diff := qv - n.val
|
||
if diff < 0 {
|
||
diff = -diff
|
||
}
|
||
if diff <= bestDist {
|
||
search(far)
|
||
}
|
||
}
|
||
search(b.root)
|
||
if bestIdx < 0 {
|
||
return -1, 0, false
|
||
}
|
||
return bestIdx, bestDist, true
|
||
}
|
||
|
||
// knnItem pairs a point index with that point's distance to the query.
type knnItem struct {
	idx  int
	dist float64
}

// knnHeap is a small hand-rolled binary max-heap of knnItem ordered by dist:
// the item with the largest distance always sits at position 0. The
// operations mirror container/heap but are implemented directly on the slice.
type knnHeap []knnItem

// Len reports how many items the heap holds.
func (h knnHeap) Len() int { return len(h) }

// less reports whether item i must sit above item j, i.e. has the larger dist.
func (h knnHeap) less(i, j int) bool { return h[i].dist > h[j].dist }

// swap exchanges the items at positions i and j.
func (h knnHeap) swap(i, j int) { h[i], h[j] = h[j], h[i] }

// peek returns the item with the largest distance without removing it.
// The heap must not be empty.
func (h *knnHeap) peek() knnItem {
	items := *h
	return items[0]
}

// push adds x and restores heap order.
func (h *knnHeap) push(x knnItem) {
	*h = append(*h, x)
	h.up(h.Len() - 1)
}

// pop removes and returns the item with the largest distance.
// The heap must not be empty.
func (h *knnHeap) pop() knnItem {
	last := h.Len() - 1
	h.swap(0, last)
	top := (*h)[last]
	*h = (*h)[:last]
	h.down(0)
	return top
}

// up sifts the item at position i towards the root while it outranks its parent.
func (h *knnHeap) up(i int) {
	for i > 0 {
		parent := (i - 1) / 2
		if h.less(i, parent) {
			h.swap(i, parent)
			i = parent
			continue
		}
		return
	}
}

// down sifts the item at position i towards the leaves, each time swapping it
// with its higher-ranked child (left child considered first).
func (h *knnHeap) down(i int) {
	size := h.Len()
	for {
		top := i
		for child := 2*i + 1; child <= 2*i+2 && child < size; child++ {
			if h.less(child, top) {
				top = child
			}
		}
		if top == i {
			return
		}
		h.swap(i, top)
		i = top
	}
}
|
||
|
||
// gonumKNearest returns indices in ascending distance order.
|
||
func gonumKNearest[T any](backend any, query []float64, k int) ([]int, []float64) {
|
||
b, ok := backend.(*kdBackend)
|
||
if !ok || b.root == nil || len(query) != b.dim || k <= 0 {
|
||
return nil, nil
|
||
}
|
||
var h knnHeap
|
||
bestCap := k
|
||
var search func(*kdNode)
|
||
search = func(n *kdNode) {
|
||
if n == nil {
|
||
return
|
||
}
|
||
c := b.coords(n.idx)
|
||
d := b.metric.Distance(query, c)
|
||
if h.Len() < bestCap {
|
||
h.push(knnItem{idx: n.idx, dist: d})
|
||
} else if d < h.peek().dist {
|
||
// replace max
|
||
h[0] = knnItem{idx: n.idx, dist: d}
|
||
h.down(0)
|
||
}
|
||
axis := n.axis
|
||
qv := query[axis]
|
||
near, far := n.left, n.right
|
||
if qv >= n.val {
|
||
near, far = n.right, n.left
|
||
}
|
||
search(near)
|
||
// prune against current worst in heap if heap is full; otherwise use bestDist
|
||
threshold := math.MaxFloat64
|
||
if h.Len() == bestCap {
|
||
threshold = h.peek().dist
|
||
} else if h.Len() > 0 {
|
||
// use best known (not strictly necessary)
|
||
threshold = h.peek().dist
|
||
}
|
||
diff := qv - n.val
|
||
if diff < 0 {
|
||
diff = -diff
|
||
}
|
||
if diff <= threshold {
|
||
search(far)
|
||
}
|
||
}
|
||
search(b.root)
|
||
// Extract to slices and sort ascending by distance
|
||
res := make([]knnItem, len(h))
|
||
copy(res, h)
|
||
sort.Slice(res, func(i, j int) bool { return res[i].dist < res[j].dist })
|
||
idxs := make([]int, len(res))
|
||
dists := make([]float64, len(res))
|
||
for i := range res {
|
||
idxs[i] = res[i].idx
|
||
dists[i] = res[i].dist
|
||
}
|
||
return idxs, dists
|
||
}
|
||
|
||
func gonumRadius[T any](backend any, query []float64, r float64) ([]int, []float64) {
|
||
b, ok := backend.(*kdBackend)
|
||
if !ok || b.root == nil || len(query) != b.dim || r < 0 {
|
||
return nil, nil
|
||
}
|
||
var res []knnItem
|
||
var search func(*kdNode)
|
||
search = func(n *kdNode) {
|
||
if n == nil {
|
||
return
|
||
}
|
||
c := b.coords(n.idx)
|
||
d := b.metric.Distance(query, c)
|
||
if d <= r {
|
||
res = append(res, knnItem{idx: n.idx, dist: d})
|
||
}
|
||
axis := n.axis
|
||
qv := query[axis]
|
||
near, far := n.left, n.right
|
||
if qv >= n.val {
|
||
near, far = n.right, n.left
|
||
}
|
||
search(near)
|
||
diff := qv - n.val
|
||
if diff < 0 {
|
||
diff = -diff
|
||
}
|
||
if diff <= r {
|
||
search(far)
|
||
}
|
||
}
|
||
search(b.root)
|
||
sort.Slice(res, func(i, j int) bool { return res[i].dist < res[j].dist })
|
||
idxs := make([]int, len(res))
|
||
dists := make([]float64, len(res))
|
||
for i := range res {
|
||
idxs[i] = res[i].idx
|
||
dists[i] = res[i].dist
|
||
}
|
||
return idxs, dists
|
||
}
|