134 lines
3.2 KiB
C++
134 lines
3.2 KiB
C++
// Copyright (c) 2021 Carl Pearson
|
|
|
|
#pragma once
|
|
|
|
#include <algorithm>
|
|
#include <vector>
|
|
|
|
|
|
/// A static two-dimensional k-d tree over integer points.
///
/// The tree is built once from a list of points and then answers
/// orthogonal range-counting queries. Splitting alternates between the two
/// coordinates by depth: even depths (including the root) split on `j`,
/// odd depths split on `i`.
///
/// A KD owns its nodes. It is movable but not copyable.
class KD {
public:
  /// An integer point (i, j).
  struct Point {
    Point() = default;
    Point(int _i, int _j) : i(_i), j(_j) {}

    int i = 0;
    int j = 0;

    /// Strict weak ordering: primarily by `i`, ties broken by `j`.
    static bool by_ij(const Point &a, const Point &b) {
      if (a.i != b.i) {
        return a.i < b.i;
      }
      return a.j < b.j;
    }

    /// Strict weak ordering: primarily by `j`, ties broken by `i`.
    static bool by_ji(const Point &a, const Point &b) {
      if (a.j != b.j) {
        return a.j < b.j;
      }
      return a.i < b.i;
    }
  };

private:
  /// A tree node. Each node owns both of its subtrees.
  struct Node {
    Point location;
    Node *left;
    Node *right;

    Node() : left(nullptr), right(nullptr) {}

    // A node owns its children, so a shallow copy would double-free them.
    Node(const Node &) = delete;
    Node &operator=(const Node &) = delete;

    ~Node() {
      // delete on a null pointer is a no-op, so no guards are needed.
      delete left;
      delete right;
    }
  };

  Node *root_ = nullptr;

  /// Recursively build a subtree from the points in [begin, end).
  ///
  /// All work is done in place within the memory from `begin` to `end`;
  /// the range is reordered.
  ///
  /// @param begin first point of the range
  /// @param end   one past the last point of the range
  /// @param depth depth of the node being built; selects the split axis
  /// @return the new subtree (owned by the caller), or nullptr if the range
  ///         is empty
  static Node *helper(Point *begin, Point *end, int depth) {
    if (begin >= end) {
      return nullptr;
    }

    // Odd depths split on i, even depths split on j.
    const bool iAxis = (depth % 2) != 0;
    bool (*less)(const Point &, const Point &) =
        iAxis ? &Point::by_ij : &Point::by_ji;

    // Only the median partition is needed: everything before `median`
    // compares <= *median and everything after compares >= *median.
    // nth_element gives exactly that without fully sorting the range.
    Point *const median = begin + (end - begin) / 2;
    std::nth_element(begin, median, end, less);

    Node *n = new Node;
    try {
      n->location = *median;
      n->left = helper(begin, median, depth + 1);
      n->right = helper(median + 1, end, depth + 1);
    } catch (...) {
      // Do not leak the partially built subtree if an allocation fails.
      delete n;
      throw;
    }
    return n;
  }

  /// Count the points in the subtree rooted at `n` that satisfy
  /// ilb <= i < iub and jlb <= j < jub.
  ///
  /// @param n     subtree root; may be nullptr (counts as zero points)
  /// @param depth depth of `n`; must match the depth used when building
  static int range_helper(const Node *n, int ilb, int iub, int jlb, int jub,
                          int depth) {
    if (!n) {
      return 0;
    }

    const int i = n->location.i;
    const int j = n->location.j;

    int count = 0;
    if (i >= ilb && i < iub && j >= jlb && j < jub) {
      count += 1;
    }

    // Pick the coordinate and bounds that this level was split on.
    const bool iAxis = (depth % 2) != 0;
    const int coord = iAxis ? i : j;
    const int lb = iAxis ? ilb : jlb;
    const int ub = iAxis ? iub : jub;

    // TODO: optimization if subtree is totally contained

    // The left subtree holds coordinates <= coord, so it can only contain
    // matches when coord itself is not below the lower bound.
    if (coord >= lb && n->left) {
      count += range_helper(n->left, ilb, iub, jlb, jub, depth + 1);
    }
    // The right subtree holds coordinates >= coord, so it can only contain
    // matches when coord is still below the (exclusive) upper bound.
    if (coord < ub && n->right) {
      count += range_helper(n->right, ilb, iub, jlb, jub, depth + 1);
    }
    return count;
  }

public:
  /// Build a tree containing every point in `pointlist` (duplicates kept).
  /// `pointlist` itself is not modified; an empty list yields an empty tree.
  KD(const std::vector<Point> &pointlist) {
    std::vector<Point> ps = pointlist; // ctor uses this as scratch space
    root_ = helper(ps.data(), ps.data() + ps.size(), 0);
  }

  // The tree is uniquely owned: copying would double-free the nodes.
  KD(const KD &) = delete;
  KD &operator=(const KD &) = delete;

  /// Take over the nodes of `other`, leaving it as an empty tree.
  KD(KD &&other) noexcept : root_(other.root_) { other.root_ = nullptr; }

  /// Release the current nodes and take over those of `other`.
  KD &operator=(KD &&other) noexcept {
    if (this != &other) {
      delete root_;
      root_ = other.root_;
      other.root_ = nullptr;
    }
    return *this;
  }

  ~KD() { delete root_; }

  /// Count the points with i in [ilb, iub) and j in [jlb, jub).
  ///
  /// Both ranges are half-open: the lower bound is inclusive and the upper
  /// bound is exclusive.
  ///
  /// @return the number of stored points inside the rectangle; 0 for an
  ///         empty tree
  int range_search(int ilb, int iub, int jlb, int jub) const {
    return range_helper(root_, ilb, iub, jlb, jub, 0);
  }
};
|