From 22e01b7f1d1f98fa383cd4923e6ca3c1079c00f5 Mon Sep 17 00:00:00 2001 From: jpekkila Date: Sun, 19 Apr 2020 23:23:23 +0300 Subject: [PATCH] Rewrote partitioning code --- src/core/device.cc | 178 ++++++++++++++++----------------------------- 1 file changed, 64 insertions(+), 114 deletions(-) diff --git a/src/core/device.cc b/src/core/device.cc index 273834d..4a81246 100644 --- a/src/core/device.cc +++ b/src/core/device.cc @@ -462,18 +462,31 @@ acDeviceReduceVec(const Device device, const Stream stream, const ReductionType #if AC_MPI_ENABLED #include -static int -mod(const int a, const int b) -{ - const int r = a % b; - return r < 0 ? r + b : r; -} #include typedef struct { uint64_t x, y, z; } uint3_64; +static uint3_64 +operator+(const uint3_64& a, const uint3_64& b) +{ + return (uint3_64){a.x + b.x, a.y + b.y, a.z + b.z}; +} + +static int3 +make_int3(const uint3_64 a) +{ + return (int3){(int)a.x, (int)a.y, (int)a.z}; +} + +static uint64_t +mod(const int a, const int b) +{ + const int r = a % b; + return r < 0 ? r + b : r; +} + static uint3_64 morton3D(const uint64_t pid) { @@ -502,118 +515,55 @@ morton1D(const uint3_64 pid) return i; } -int -getPid(int3 pid, const int3 decomposition) +static uint3_64 +decompose(const uint64_t target) { - /* - return mod(pid.x, decomposition.x) + // - mod(pid.y, decomposition.y) * decomposition.x + // - mod(pid.z, decomposition.z) * decomposition.x * decomposition.y; + // Split `target` processes over x, y, z: feed (target - 1) through morton3D and add + // one per axis; ERRCHK_ALWAYS below requires p.x * p.y * p.z == target. 
+ uint3_64 p = morton3D(target - 1) + (uint3_64){1, 1, 1}; - */ - pid.x = mod(pid.x, decomposition.x); - pid.y = mod(pid.y, decomposition.y); - pid.z = mod(pid.z, decomposition.z); - - uint64_t i = 0; - for (int bit = 0; bit <= 21; ++bit) { - const uint64_t mask = 0x1l << bit; - i |= (((uint64_t)pid.x & mask) << 0) << 2 * bit; - i |= (((uint64_t)pid.y & mask) << 1) << 2 * bit; - i |= (((uint64_t)pid.z & mask) << 2) << 2 * bit; - } - return (int)i; - + ERRCHK_ALWAYS(p.x * p.y * p.z == target); + return p; } -int3 -getPid3D(const int pid, const int3 decomposition) +static uint3_64 +wrap(const int3 i, const uint3_64 n) { - /* - const int3 pid3d = (int3){ - mod(pid, decomposition.x), - mod(pid / decomposition.x, decomposition.y), - (pid / (decomposition.x * decomposition.y)), - }; - - ERRCHK_ALWAYS(getPid(pid3d, decomposition) == pid); - return pid3d; - */ - uint64_t i, j, k; - i = j = k = 0; - for (int bit = 0; bit <= 21; ++bit) { - const uint64_t mask = 0x1l << 3 * bit; - i |= (((uint64_t)pid & (mask << 0)) >> 2 * bit) >> 0; - j |= (((uint64_t)pid & (mask << 1)) >> 2 * bit) >> 1; - k |= (((uint64_t)pid & (mask << 2)) >> 2 * bit) >> 2; - } - const int3 pid3d = (int3){i, j, k}; - ERRCHK(getPid(pid3d, decomposition) == pid); - return pid3d; - + return (uint3_64){ + mod(i.x, n.x), + mod(i.y, n.y), + mod(i.z, n.z), + }; } -/** Note: assumes that contiguous pids are on the same node and there is one process per GPU. I.e. - * pids are linearly mapped i + j * dx + k * dx * dy. */ +static int +getPid(const int3 pid_raw, const uint3_64 decomp) +{ + const uint3_64 pid = wrap(pid_raw, decomp); + return (int)morton1D(pid); +} + +static int3 +getPid3D(const uint64_t pid, const uint3_64 decomp) +{ + const uint3_64 pid3D = morton3D(pid); + ERRCHK_ALWAYS(getPid(make_int3(pid3D), decomp) == (int)pid); + return (int3){(int)pid3D.x, (int)pid3D.y, (int)pid3D.z}; +} + +/** Assumes that contiguous pids are on the same node and there is one process per GPU. 
*/ static inline bool -onTheSameNode(const int pid_a, const int pid_b) +onTheSameNode(const uint64_t pid_a, const uint64_t pid_b) { int devices_per_node = -1; cudaGetDeviceCount(&devices_per_node); - const int node_a = pid_a / devices_per_node; - const int node_b = pid_b / devices_per_node; + const uint64_t node_a = pid_a / devices_per_node; + const uint64_t node_b = pid_b / devices_per_node; return node_a == node_b; } -static int3 -decompose(const int target) -{ - // This is just so beautifully elegant. Complex and efficient decomposition - // in just one line of code. - uint3_64 p = morton3D(target - 1); - p = (uint3_64){p.x + 1, p.y + 1, p.z + 1}; - - if (p.x * p.y * p.z != target) { - fprintf(stderr, "Invalid number of processes! Cannot decompose the problem domain!\n"); - fprintf(stderr, "Target nprocs: %d. Found: %d\n", target, p.x * p.y * p.z); - ERROR("Invalid nprocs"); - return (int3){-1, -1, -1}; - } - - return (int3){p.x, p.y, p.z}; - /* - if (target == 16) - return (int3){4, 2, 2}; - if (target == 32) - return (int3){4, 4, 2}; - if (target == 128) - return (int3){8, 4, 4}; - if (target == 256) - return (int3){8, 8, 4}; - - int decomposition[] = {1, 1, 1}; - - int axis = 0; - while (decomposition[0] * decomposition[1] * decomposition[2] < target) { - ++decomposition[axis]; - axis = (axis + 1) % 3; - } - - const int found = decomposition[0] * decomposition[1] * decomposition[2]; - if (found != target) { - fprintf(stderr, "Invalid number of processes! Cannot decompose the problem domain!\n"); - fprintf(stderr, "Target nprocs: %d. 
Next allowed: %d\n", target, found); - ERROR("Invalid nprocs"); - return (int3){-1, -1, -1}; - } - else { - return (int3){decomposition[0], decomposition[1], decomposition[2]}; - } - */ -} - static PackedData acCreatePackedData(const int3 dims) { @@ -746,7 +696,7 @@ acUnpinPackedData(const Device device, const cudaStream_t stream, PackedData* dd // TODO: do with packed data static AcResult -acDeviceDistributeMeshMPI(const AcMesh src, const int3 decomposition, AcMesh* dst) +acDeviceDistributeMeshMPI(const AcMesh src, const uint3_64 decomposition, AcMesh* dst) { MPI_Barrier(MPI_COMM_WORLD); printf("Distributing mesh...\n"); @@ -822,7 +772,7 @@ acDeviceDistributeMeshMPI(const AcMesh src, const int3 decomposition, AcMesh* ds // TODO: do with packed data static AcResult -acDeviceGatherMeshMPI(const AcMesh src, const int3 decomposition, AcMesh* dst) +acDeviceGatherMeshMPI(const AcMesh src, const uint3_64 decomposition, AcMesh* dst) { MPI_Barrier(MPI_COMM_WORLD); printf("Gathering mesh...\n"); @@ -1038,7 +988,7 @@ acTransferCommDataToDevice(const Device device, CommData* data) #endif #if AC_MPI_RT_PINNING -static void +static inline void acPinCommData(const Device device, CommData* data) { cudaSetDevice(device->id); @@ -1077,7 +1027,7 @@ acTransferCommData(const Device device, // int nprocs, pid; MPI_Comm_size(MPI_COMM_WORLD, &nprocs); MPI_Comm_rank(MPI_COMM_WORLD, &pid); - const int3 decomp = decompose(nprocs); + const uint3_64 decomp = decompose(nprocs); const int3 nn = (int3){ device->local_config.int_params[AC_nx], @@ -1230,7 +1180,7 @@ acTransferCommData(const Device device, // int nprocs, pid; MPI_Comm_size(MPI_COMM_WORLD, &nprocs); MPI_Comm_rank(MPI_COMM_WORLD, &pid); - const int3 decomp = decompose(nprocs); + const uint3_64 decomp = decompose(nprocs); const int3 nn = (int3){ device->local_config.int_params[AC_nx], @@ -1363,7 +1313,7 @@ acTransferCommData(const Device device, // int nprocs, pid; MPI_Comm_size(MPI_COMM_WORLD, &nprocs); MPI_Comm_rank(MPI_COMM_WORLD, 
&pid); - const int3 decomp = decompose(nprocs); + const uint3_64 decomp = decompose(nprocs); const int3 nn = (int3){ device->local_config.int_params[AC_nx], @@ -1467,7 +1417,7 @@ acTransferCommDataWait(const CommData data) typedef struct { Device device; AcMesh submesh; - int3 decomposition; + uint3_64 decomposition; bool initialized; int3 nn; @@ -1508,11 +1458,11 @@ acGridInit(const AcMeshInfo info) printf("Processor %s. Process %d of %d.\n", processor_name, pid, nprocs); // Decompose - AcMeshInfo submesh_info = info; - const int3 decomposition = decompose(nprocs); - const int3 pid3d = getPid3D(pid, decomposition); + AcMeshInfo submesh_info = info; + const uint3_64 decomposition = decompose(nprocs); + const int3 pid3d = getPid3D(pid, decomposition); - printf("Decomposition: %d, %d, %d\n", decomposition.x, decomposition.y, decomposition.z); + printf("Decomposition: %lu, %lu, %lu\n", decomposition.x, decomposition.y, decomposition.z); printf("Process %d: (%d, %d, %d)\n", pid, pid3d.x, pid3d.y, pid3d.z); ERRCHK_ALWAYS(info.int_params[AC_nx] % decomposition.x == 0); ERRCHK_ALWAYS(info.int_params[AC_ny] % decomposition.y == 0); @@ -1650,7 +1600,7 @@ acGridQuit(void) acDestroyCommData(grid.device, &grid.sideyz_data); grid.initialized = false; - grid.decomposition = (int3){-1, -1, -1}; + grid.decomposition = (uint3_64){0, 0, 0}; acMeshDestroy(&grid.submesh); acDeviceDestroy(grid.device);