Added WIP version of the new bidirectional comm scheme
This commit is contained in:
@@ -1166,6 +1166,130 @@ acTransferCommDataWait(const CommData data)
|
||||
// NOP
|
||||
}
|
||||
|
||||
#elif AC_MPI_BIDIRECTIONAL_SCHEME
|
||||
|
||||
// Component-wise modulo of a 3-vector: applies the scalar mod() helper
// (defined elsewhere in this file; presumably a positive/wrapping modulo —
// confirm) to each axis independently.
static int3
mod(const int3 a, const int3 n)
{
    int3 wrapped;
    wrapped.x = mod(a.x, n.x);
    wrapped.y = mod(a.y, n.y);
    wrapped.z = mod(a.z, n.z);
    return wrapped;
}
|
||||
|
||||
// Linear search for the block whose origin equals `target`.
// Returns the matching index, or `count` when no block matches
// (callers treat any value >= count as "not found").
static size_t
findBlockIdx(const int3* origins, const size_t count, const int3 target)
{
    for (size_t i = 0; i < count; ++i)
        if (origins[i].x == target.x && origins[i].y == target.y && origins[i].z == target.z)
            return i;
    return count;
}

// Posts the non-blocking MPI transfers for one communication batch using the
// bidirectional scheme: for every destination block b0 in the local boundary
// zone, the periodically-wrapped source block a0 inside the computational
// domain is located, the owning neighbor process is derived from the relative
// position of a0 and b0, and a matching MPI_Irecv/MPI_Isend pair is posted.
//
// device: device whose packed buffers are exchanged (selects the CUDA device)
// a0s:    src block origins inside the computational domain
// b0s:    dst block origins inside the boundary zone
// data:   packed buffers + per-block request arrays; completion must be
//         awaited with acTransferCommDataWait
//
// NOTE(review): WIP — the cudaDeviceSynchronize, ERRCHK_ALWAYS checks and
// printfs below are debug aids explicitly marked for removal once validated.
static AcResult
acTransferCommData(const Device device, //
                   const int3* a0s, // Src idx inside comp. domain
                   const int3* b0s, // Dst idx inside bound zone
                   CommData* data)
{
    cudaSetDevice(device->id);

    // Match the MPI datatype to the build-time precision of AcReal
    MPI_Datatype datatype = MPI_FLOAT;
    if (sizeof(AcReal) == 8)
        datatype = MPI_DOUBLE;

    int nprocs, pid;
    MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
    MPI_Comm_rank(MPI_COMM_WORLD, &pid);
    const uint3_64 decomp = decompose(nprocs);

    // Computational-domain extent of this process's subgrid
    const int3 nn = (int3){
        device->local_config.int_params[AC_nx],
        device->local_config.int_params[AC_ny],
        device->local_config.int_params[AC_nz],
    };

    // Loop invariants hoisted out of the per-block loop
    const int3 nghost       = (int3){NGHOST, NGHOST, NGHOST};
    const int3 dims         = data->dims;
    const size_t num_blocks = data->count;
    const size_t count      = dims.x * dims.y * dims.z * NUM_VTXBUF_HANDLES;
    const int3 pid3d        = getPid3D(pid, decomp);

    cudaDeviceSynchronize(); // TODO debug REMOVE
    for (size_t b0_idx = 0; b0_idx < num_blocks; ++b0_idx) {
        const int3 b0 = b0s[b0_idx];
        // Source block: b0 wrapped periodically back into the comp. domain
        const int3 a0 = mod(((b0 - nghost) + nn), nn) + nghost;

        const size_t a0_idx = findBlockIdx(a0s, num_blocks, a0);
        ERRCHK_ALWAYS(a0_idx < num_blocks); // TODO debug REMOVE

        // Per-axis direction of the neighbor process that exchanges with b0
        const int3 neighbor = (int3){
            a0.x < b0.x ? -1 : a0.x > b0.x ? 1 : 0,
            a0.y < b0.y ? -1 : a0.y > b0.y ? 1 : 0,
            a0.z < b0.z ? -1 : a0.z > b0.z ? 1 : 0,
        };

        // b1: the mirrored boundary-zone block on this process that receives
        // the neighbor's data (a0 shifted one ghost-zone width toward the
        // neighbor on each participating axis)
        const int3 b1 = (int3){
            neighbor.x < 0 ? a0.x - nghost.x : neighbor.x > 0 ? a0.x + nghost.x : a0.x,
            neighbor.y < 0 ? a0.y - nghost.y : neighbor.y > 0 ? a0.y + nghost.y : a0.y,
            neighbor.z < 0 ? a0.z - nghost.z : neighbor.z > 0 ? a0.z + nghost.z : a0.z,
        };

        const size_t b1_idx = findBlockIdx(b0s, num_blocks, b1);
        ERRCHK_ALWAYS(b1_idx < num_blocks); // TODO debug REMOVE

        const int npid = getPid(pid3d + neighbor, decomp);

        PackedData* src = &data->srcs[a0_idx];
        PackedData* dst = &data->dsts[b1_idx];

        // NOTE(review): tags are block indices — this assumes the neighbor
        // derives the mirrored indices identically so that its send tag
        // (b0_idx on its side) equals our recv tag b1_idx; confirm.
        MPI_Irecv(dst->data, count, datatype, npid, b1_idx, MPI_COMM_WORLD,
                  &data->recv_reqs[b1_idx]);
        MPI_Isend(src->data, count, datatype, npid, b0_idx, MPI_COMM_WORLD,
                  &data->send_reqs[b0_idx]);

        printf("a0 -> b0: (%d, %d, %d) -> (%d, %d, %d)\n", a0.x, a0.y, a0.z, b0.x, b0.y, b0.z);
        printf("b1: (%d, %d, %d)\n", b1.x, b1.y, b1.z);
        printf("neighbor: (%d, %d, %d)\n", neighbor.x, neighbor.y, neighbor.z);
    }

    return AC_SUCCESS;
}
|
||||
|
||||
// Blocks until every send and receive request posted by acTransferCommData
// for this CommData batch has completed.
static void
acTransferCommDataWait(const CommData data)
{
    // MPI_Waitall progresses all requests at once instead of completing them
    // serially in array order; completion semantics are equivalent to the
    // per-request MPI_Wait loop it replaces, with better overlap.
    MPI_Waitall((int)data.count, data.send_reqs, MPI_STATUSES_IGNORE);
    MPI_Waitall((int)data.count, data.recv_reqs, MPI_STATUSES_IGNORE);
}
|
||||
|
||||
#elif AC_MPI_RT_PINNING
|
||||
static AcResult
|
||||
acTransferCommData(const Device device, //
|
||||
@@ -1916,6 +2040,7 @@ acGridPeriodicBoundconds(const Stream stream)
|
||||
(int3){NGHOST, nn.y, nn.z}, //
|
||||
(int3){nn.x, nn.y, nn.z},
|
||||
};
|
||||
/*
|
||||
const int3 corner_b0s[] = {
|
||||
(int3){0, 0, 0},
|
||||
(int3){NGHOST + nn.x, 0, 0},
|
||||
@@ -1927,6 +2052,18 @@ acGridPeriodicBoundconds(const Stream stream)
|
||||
(int3){0, NGHOST + nn.y, NGHOST + nn.z},
|
||||
(int3){NGHOST + nn.x, NGHOST + nn.y, NGHOST + nn.z},
|
||||
};
|
||||
*/
|
||||
const int3 corner_b0s[] = {
|
||||
(int3){0, 0, 0},
|
||||
(int3){NGHOST + nn.x, 0, 0},
|
||||
(int3){0, NGHOST + nn.y, 0},
|
||||
(int3){0, 0, NGHOST + nn.z},
|
||||
|
||||
(int3){NGHOST + nn.x, NGHOST + nn.y, 0},
|
||||
(int3){NGHOST + nn.x, 0, NGHOST + nn.z},
|
||||
(int3){0, NGHOST + nn.y, NGHOST + nn.z},
|
||||
(int3){NGHOST + nn.x, NGHOST + nn.y, NGHOST + nn.z},
|
||||
};
|
||||
|
||||
// Edges X
|
||||
const int3 edgex_a0s[] = {
|
||||
|
Reference in New Issue
Block a user