diff --git a/src/core/device.cu b/src/core/device.cu index 9439137..bb7194b 100644 --- a/src/core/device.cu +++ b/src/core/device.cu @@ -799,7 +799,7 @@ mod(const int a, const int b) return r < 0 ? r + b : r; } -static int +static inline int get_neighbor(const int3 offset) { // The number of nodes is n^3 = m = num_processes @@ -810,8 +810,8 @@ get_neighbor(const int3 offset) MPI_Comm_rank(MPI_COMM_WORLD, &pid); MPI_Comm_size(MPI_COMM_WORLD, &num_processes); - const int n = floor(cbrt(num_processes)); - ERRCHK_ALWAYS(ceil(cbrt(num_processes)) == n); + const int n = (int)floor(cbrt((float)num_processes)); + ERRCHK_ALWAYS(ceil(cbrt((float)num_processes)) == n); ERRCHK_ALWAYS(n * n * n == num_processes); return mod(pid + offset.x, n) + offset.y * n + offset.z * n * n;