diff --git a/src/core/device.cc b/src/core/device.cc index 8ce057f..6ac849e 100644 --- a/src/core/device.cc +++ b/src/core/device.cc @@ -468,29 +468,62 @@ mod(const int a, const int b) const int r = a % b; return r < 0 ? r + b : r; } +#include -static int -getPid(const int3 pid, const int3 decomposition) +int +getPid(int3 pid, const int3 decomposition) { - return mod(pid.x, decomposition.x) + // - mod(pid.y, decomposition.y) * decomposition.x + // - mod(pid.z, decomposition.z) * decomposition.x * decomposition.y; + /* + return mod(pid.x, decomposition.x) + // + mod(pid.y, decomposition.y) * decomposition.x + // + mod(pid.z, decomposition.z) * decomposition.x * decomposition.y; + + */ + pid.x = mod(pid.x, decomposition.x); + pid.y = mod(pid.y, decomposition.y); + pid.z = mod(pid.z, decomposition.z); + + uint64_t i = 0; + for (int bit = 0; bit <= 21; ++bit) { + const uint64_t mask = 0x1l << bit; + i |= (((uint64_t)pid.x & mask) << 0) << 2 * bit; + i |= (((uint64_t)pid.y & mask) << 1) << 2 * bit; + i |= (((uint64_t)pid.z & mask) << 2) << 2 * bit; + } + return (int)i; + } -static int3 +int3 getPid3D(const int pid, const int3 decomposition) { - const int3 pid3d = (int3){ - mod(pid, decomposition.x), - mod(pid / decomposition.x, decomposition.y), - (pid / (decomposition.x * decomposition.y)), - }; + /* + const int3 pid3d = (int3){ + mod(pid, decomposition.x), + mod(pid / decomposition.x, decomposition.y), + (pid / (decomposition.x * decomposition.y)), + }; + + ERRCHK_ALWAYS(getPid(pid3d, decomposition) == pid); + return pid3d; + */ + uint64_t i, j, k; + i = j = k = 0; + for (int bit = 0; bit <= 21; ++bit) { + const uint64_t mask = 0x1l << 3 * bit; + i |= (((uint64_t)pid & (mask << 0)) >> 2 * bit) >> 0; + j |= (((uint64_t)pid & (mask << 1)) >> 2 * bit) >> 1; + k |= (((uint64_t)pid & (mask << 2)) >> 2 * bit) >> 2; + } + const int3 pid3d = (int3){i, j, k}; + ERRCHK(getPid(pid3d, decomposition) == pid); return pid3d; + } /** Note: assumes that contiguous pids are on the same node and there is one process per GPU. I.e. * pids are linearly mapped i + j * dx + k * dx * dy. */ -static bool +static inline bool onTheSameNode(const int pid_a, const int pid_b) { int devices_per_node = -1;