diff --git a/kd.hpp b/kd.hpp new file mode 100644 index 0000000..dc0733c --- /dev/null +++ b/kd.hpp @@ -0,0 +1,133 @@ +#pragma once + +#include +#include + +class KD { + + struct Point { + int i; + int j; + static bool by_ij(const Point &a, const Point &b) { + if (a.i < b.i) { + return true; + } else if (a.i > b.i) { + return false; + } else { + return a.j < b.j; + } + } + + static bool by_ji(const Point &a, const Point &b) { + if (a.j < b.j) { + return true; + } else if (a.j > b.j) { + return false; + } else { + return a.i < b.i; + } + } + + }; + + struct Node { + Point location; + Node *left; + Node *right; + Node() : left(nullptr), right(nullptr) {} + ~Node() { + if (left) delete left; + if (right) delete right; + } + }; + + Node *root_; + + Node *helper(Point *begin, Point *end, int depth) { + // all work is done within the memory from begin to end + + if (begin >= end) {return nullptr;} + + const int iAxis = depth % 2; + + + if (iAxis) { + std::sort(begin, end, Point::by_ij); + } else { + std::sort(begin, end, Point::by_ji); + } + + // split across median + int mi = (end-begin) / 2; + + Node *n = new Node; + + // split points across median + n->location = begin[mi]; + n->left = helper(begin, begin+mi, depth+1); + n->right = helper(begin+mi+1, end, depth+1); + + return n; + } + + + + // find all points in [ilb...iub) and [jlb...jub) + int range_helper(Node *n, int ilb, int iub, int jlb ,int jub, int depth) const { + + int i = n->location.i; + int j = n->location.j; + + int count = 0; + // node location is in the range + if (i >= ilb && i < iub && j >= jlb && j < jub) { + count += 1; + } + + + int iAxis = depth % 2; + + if (iAxis) { + // TODO: optimization if subtree is totally contained + if (i >= ilb && n->left) { + count += range_helper(n->left, ilb, iub, jlb, jub, depth+1); + } + if (i < iub && n->right) { + count += range_helper(n->right, ilb, iub, jlb, jub, depth+1); + } + } else { + // TODO: optimization if subtree is totally contained + if (j >= jlb && n->left) { + count += range_helper(n->left, ilb, iub, jlb, jub, depth+1); + } + if (j < jub && n->right) { + count += range_helper(n->right, ilb, iub, jlb, jub, depth+1); + } + } + return count; + } + +public: + KD(std::vector pointlist) { + std::vector ps; + for (size_t i = 0; i < pointlist.size(); ++i) { + Point p; + p.i = pointlist[i].i; + p.j = pointlist[i].j; + ps.push_back(p); + } + root_ = helper(&ps[0], &ps[ps.size()], 0); + } + ~KD() { + delete root_; + } + + // find all points in [ilb...iub) and [jlb...jub) + int range_search(int ilb, int iub, int jlb ,int jub) const { + + return range_helper(root_, ilb, iub, jlb , jub, 0); + + + } + +};