Create kd.hpp
This commit is contained in:
133
kd.hpp
Normal file
133
kd.hpp
Normal file
@@ -0,0 +1,133 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
class KD {
|
||||
|
||||
struct Point {
|
||||
int i;
|
||||
int j;
|
||||
static bool by_ij(const Point &a, const Point &b) {
|
||||
if (a.i < b.i) {
|
||||
return true;
|
||||
} else if (a.i > b.i) {
|
||||
return false;
|
||||
} else {
|
||||
return a.j < b.j;
|
||||
}
|
||||
}
|
||||
|
||||
static bool by_ji(const Point &a, const Point &b) {
|
||||
if (a.j < b.j) {
|
||||
return true;
|
||||
} else if (a.j > b.j) {
|
||||
return false;
|
||||
} else {
|
||||
return a.i < b.i;
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct Node {
|
||||
Point location;
|
||||
Node *left;
|
||||
Node *right;
|
||||
Node() : left(nullptr), right(nullptr) {}
|
||||
~Node() {
|
||||
if (left) delete left;
|
||||
if (right) delete right;
|
||||
}
|
||||
};
|
||||
|
||||
Node *root_;
|
||||
|
||||
Node *helper(Point *begin, Point *end, int depth) {
|
||||
// all work is done within the memory from begin to end
|
||||
|
||||
if (begin >= end) {return nullptr;}
|
||||
|
||||
const int iAxis = depth % 2;
|
||||
|
||||
|
||||
if (iAxis) {
|
||||
std::sort(begin, end, Point::by_ij);
|
||||
} else {
|
||||
std::sort(begin, end, Point::by_ji);
|
||||
}
|
||||
|
||||
// split across median
|
||||
int mi = (end-begin) / 2;
|
||||
|
||||
Node *n = new Node;
|
||||
|
||||
// split points across median
|
||||
n->location = begin[mi];
|
||||
n->left = helper(begin, begin+mi, depth+1);
|
||||
n->right = helper(begin+mi+1, end, depth+1);
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
|
||||
|
||||
// find all points in [ilb...iub) and [jlb...jub)
|
||||
int range_helper(Node *n, int ilb, int iub, int jlb ,int jub, int depth) const {
|
||||
|
||||
int i = n->location.i;
|
||||
int j = n->location.j;
|
||||
|
||||
int count = 0;
|
||||
// node location is in the range
|
||||
if (i >= ilb && i < iub && j >= jlb && j < jub) {
|
||||
count += 1;
|
||||
}
|
||||
|
||||
|
||||
int iAxis = depth % 2;
|
||||
|
||||
if (iAxis) {
|
||||
// TODO: optimization if subtree is totally contained
|
||||
if (i >= ilb && n->left) {
|
||||
count += range_helper(n->left, ilb, iub, jlb, jub, depth+1);
|
||||
}
|
||||
if (i < iub && n->right) {
|
||||
count += range_helper(n->right, ilb, iub, jlb, jub, depth+1);
|
||||
}
|
||||
} else {
|
||||
// TODO: optimization if subtree is totally contained
|
||||
if (j >= jlb && n->left) {
|
||||
count += range_helper(n->left, ilb, iub, jlb, jub, depth+1);
|
||||
}
|
||||
if (j < jub && n->right) {
|
||||
count += range_helper(n->right, ilb, iub, jlb, jub, depth+1);
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
public:
|
||||
KD(std::vector<Entry> pointlist) {
|
||||
std::vector<Point> ps;
|
||||
for (size_t i = 0; i < pointlist.size(); ++i) {
|
||||
Point p;
|
||||
p.i = pointlist[i].i;
|
||||
p.j = pointlist[i].j;
|
||||
ps.push_back(p);
|
||||
}
|
||||
root_ = helper(&ps[0], &ps[ps.size()], 0);
|
||||
}
|
||||
~KD() {
|
||||
delete root_;
|
||||
}
|
||||
|
||||
// find all points in [ilb...iub) and [jlb...jub)
|
||||
int range_search(int ilb, int iub, int jlb ,int jub) const {
|
||||
|
||||
return range_helper(root_, ilb, iub, jlb , jub, 0);
|
||||
|
||||
|
||||
}
|
||||
|
||||
};
|
Reference in New Issue
Block a user