Better concurrency and some simplifications (MPI).

This commit is contained in:
jpekkila
2020-01-20 18:45:24 +02:00
parent 765ce9a573
commit 993bfc4533

View File

@@ -1129,6 +1129,7 @@ typedef struct {
int3 dims;
size_t count;
cudaStream_t* streams;
MPI_Request* send_reqs;
MPI_Request* recv_reqs;
} CommData;
@@ -1147,6 +1148,7 @@ acCreateCommData(const Device device, const int3 dims, const size_t count)
data.dims = dims;
data.count = count;
data.streams = (cudaStream_t*)malloc(count * sizeof(cudaStream_t));
data.send_reqs = (MPI_Request*)malloc(count * sizeof(MPI_Request));
data.recv_reqs = (MPI_Request*)malloc(count * sizeof(MPI_Request));
@@ -1162,6 +1164,8 @@ acCreateCommData(const Device device, const int3 dims, const size_t count)
data.dsts[i] = acCreatePackedData(dims);
data.srcs_host[i] = acCreatePackedDataHost(dims);
data.dsts_host[i] = acCreatePackedDataHost(dims);
cudaStreamCreate(&data.streams[i]);
}
return data;
@@ -1177,6 +1181,8 @@ acDestroyCommData(const Device device, CommData* data)
acDestroyPackedData(&data->dsts[i]);
acDestroyPackedDataHost(&data->srcs_host[i]);
acDestroyPackedDataHost(&data->dsts_host[i]);
cudaStreamDestroy(data->streams[i]);
}
free(data->srcs);
@@ -1184,6 +1190,7 @@ acDestroyCommData(const Device device, CommData* data)
free(data->srcs_host);
free(data->dsts_host);
free(data->streams);
free(data->send_reqs);
free(data->recv_reqs);
@@ -1192,41 +1199,36 @@ acDestroyCommData(const Device device, CommData* data)
}
static void
acPackCommData(const Device device, const int3* a0s, const size_t count, CommData* data)
acPackCommData(const Device device, const int3* a0s, CommData* data)
{
cudaSetDevice(device->id);
cudaStream_t streams[count];
for (size_t i = 0; i < count; ++i)
cudaStreamCreate(&streams[i]);
for (size_t i = 0; i < count; ++i)
acKernelPackData(streams[i], device->vba, a0s[i], data->srcs[i]);
for (size_t i = 0; i < count; ++i)
acTransferPackedDataToHost(device, streams[i], data->srcs[i], &data->srcs_host[i]);
for (size_t i = 0; i < count; ++i)
cudaStreamDestroy(streams[i]);
for (size_t i = 0; i < data->count; ++i)
acKernelPackData(data->streams[i], device->vba, a0s[i], data->srcs[i]);
}
static void
acUnpackCommData(const Device device, const int3* b0s, const size_t count, CommData* data)
acTransferCommDataToHost(const Device device, CommData* data)
{
cudaSetDevice(device->id);
for (size_t i = 0; i < data->count; ++i)
acTransferPackedDataToHost(device, data->streams[i], data->srcs[i], &data->srcs_host[i]);
}
static void
acUnpackCommData(const Device device, const int3* b0s, CommData* data)
{
cudaSetDevice(device->id);
cudaStream_t streams[count];
for (size_t i = 0; i < count; ++i)
cudaStreamCreate(&streams[i]);
for (size_t i = 0; i < data->count; ++i)
acKernelUnpackData(data->streams[i], data->dsts[i], b0s[i], device->vba);
}
for (size_t i = 0; i < count; ++i)
acTransferPackedDataToDevice(device, streams[i], data->dsts_host[i], &data->dsts[i]);
for (size_t i = 0; i < count; ++i)
acKernelUnpackData(streams[i], data->dsts[i], b0s[i], device->vba);
for (size_t i = 0; i < count; ++i)
cudaStreamDestroy(streams[i]);
static void
acTransferCommDataToDevice(const Device device, CommData* data)
{
cudaSetDevice(device->id);
for (size_t i = 0; i < data->count; ++i)
acTransferPackedDataToDevice(device, data->streams[i], data->dsts_host[i], &data->dsts[i]);
}
static AcResult
@@ -1277,17 +1279,14 @@ acTransferCommData(const Device device, //
// PackedData src = data->srcs[a_idx];
// PackedData dst = data->dsts[b_idx];
PackedData src = data->srcs_host[a_idx];
// PackedData src = data->srcs_host[a_idx];
PackedData dst = data->dsts_host[b_idx];
const int3 pid3d = getPid3D(pid, decomp);
MPI_Request send_req, recv_req;
MPI_Isend(src.data, count, datatype, getPid(pid3d + neighbor, decomp),
b_idx, MPI_COMM_WORLD, &send_req);
MPI_Request recv_req;
MPI_Irecv(dst.data, count, datatype, getPid(pid3d - neighbor, decomp),
b_idx, MPI_COMM_WORLD, &recv_req);
data->send_reqs[b_idx] = send_req;
data->recv_reqs[b_idx] = recv_req;
}
}
@@ -1296,6 +1295,46 @@ acTransferCommData(const Device device, //
}
}
for (int k = -1; k <= 1; ++k) {
for (int j = -1; j <= 1; ++j) {
for (int i = -1; i <= 1; ++i) {
if (i == 0 && j == 0 && k == 0)
continue;
for (size_t a_idx = 0; a_idx < blockcount; ++a_idx) {
for (size_t b_idx = 0; b_idx < blockcount; ++b_idx) {
const int3 neighbor = (int3){i, j, k};
const int3 a0 = a0s[a_idx];
// const int3 a1 = a0 + dims;
const int3 b0 = a0 - neighbor * nn;
// const int3 b1 = a1 - neighbor * nn;
if (b0s[b_idx].x == b0.x && b0s[b_idx].y == b0.y && b0s[b_idx].z == b0.z) {
const size_t count = dims.x * dims.y * dims.z * NUM_VTXBUF_HANDLES;
// PackedData src = data->srcs[a_idx];
// PackedData dst = data->dsts[b_idx];
PackedData src = data->srcs_host[a_idx];
// PackedData dst = data->dsts_host[b_idx];
const int3 pid3d = getPid3D(pid, decomp);
MPI_Request send_req;
cudaStreamSynchronize(data->streams[a_idx]);
MPI_Isend(src.data, count, datatype, getPid(pid3d + neighbor, decomp),
b_idx, MPI_COMM_WORLD, &send_req);
data->send_reqs[b_idx] = send_req;
}
}
}
}
}
}
return AC_SUCCESS;
}
@@ -1437,16 +1476,78 @@ acDeviceCommunicateHalosMPIAlt(const Device device)
CommData sidexz_data = acCreateCommData(device, sidexz_dims, ARRAY_SIZE(sidexz_a0s));
CommData sideyz_data = acCreateCommData(device, sideyz_dims, ARRAY_SIZE(sideyz_a0s));
// Warmup
for (int i = 0; i < 10; ++i) {
acPackCommData(device, corner_a0s, &corner_data);
acPackCommData(device, edgex_a0s, &edgex_data);
acPackCommData(device, edgey_a0s, &edgey_data);
acPackCommData(device, edgez_a0s, &edgez_data);
acPackCommData(device, sidexy_a0s, &sidexy_data);
acPackCommData(device, sidexz_a0s, &sidexz_data);
acPackCommData(device, sideyz_a0s, &sideyz_data);
acTransferCommDataToHost(device, &corner_data);
acTransferCommDataToHost(device, &edgex_data);
acTransferCommDataToHost(device, &edgey_data);
acTransferCommDataToHost(device, &edgez_data);
acTransferCommDataToHost(device, &sidexy_data);
acTransferCommDataToHost(device, &sidexz_data);
acTransferCommDataToHost(device, &sideyz_data);
acTransferCommData(device, corner_a0s, corner_b0s, &corner_data);
acTransferCommData(device, edgex_a0s, edgex_b0s, &edgex_data);
acTransferCommData(device, edgey_a0s, edgey_b0s, &edgey_data);
acTransferCommData(device, edgez_a0s, edgez_b0s, &edgez_data);
acTransferCommData(device, sidexy_a0s, sidexy_b0s, &sidexy_data);
acTransferCommData(device, sidexz_a0s, sidexz_b0s, &sidexz_data);
acTransferCommData(device, sideyz_a0s, sideyz_b0s, &sideyz_data);
acTransferCommDataWait(corner_data);
acTransferCommDataWait(edgex_data);
acTransferCommDataWait(edgey_data);
acTransferCommDataWait(edgez_data);
acTransferCommDataWait(sidexy_data);
acTransferCommDataWait(sidexz_data);
acTransferCommDataWait(sideyz_data);
acTransferCommDataToDevice(device, &corner_data);
acTransferCommDataToDevice(device, &edgex_data);
acTransferCommDataToDevice(device, &edgey_data);
acTransferCommDataToDevice(device, &edgez_data);
acTransferCommDataToDevice(device, &sidexy_data);
acTransferCommDataToDevice(device, &sidexz_data);
acTransferCommDataToDevice(device, &sideyz_data);
acUnpackCommData(device, corner_b0s, &corner_data);
acUnpackCommData(device, edgex_b0s, &edgex_data);
acUnpackCommData(device, edgey_b0s, &edgey_data);
acUnpackCommData(device, edgez_b0s, &edgez_data);
acUnpackCommData(device, sidexy_b0s, &sidexy_data);
acUnpackCommData(device, sidexz_b0s, &sidexz_data);
acUnpackCommData(device, sideyz_b0s, &sideyz_data);
}
// Communicate
Timer ttot;
cudaDeviceSynchronize();
MPI_Barrier(MPI_COMM_WORLD);
timer_reset(&ttot);
acPackCommData(device, corner_a0s, ARRAY_SIZE(corner_a0s), &corner_data);
acPackCommData(device, edgex_a0s, ARRAY_SIZE(edgex_a0s), &edgex_data);
acPackCommData(device, edgey_a0s, ARRAY_SIZE(edgey_a0s), &edgey_data);
acPackCommData(device, edgez_a0s, ARRAY_SIZE(edgez_a0s), &edgez_data);
acPackCommData(device, sidexy_a0s, ARRAY_SIZE(sidexy_a0s), &sidexy_data);
acPackCommData(device, sidexz_a0s, ARRAY_SIZE(sidexz_a0s), &sidexz_data);
acPackCommData(device, sideyz_a0s, ARRAY_SIZE(sideyz_a0s), &sideyz_data);
acPackCommData(device, corner_a0s, &corner_data);
acPackCommData(device, edgex_a0s, &edgex_data);
acPackCommData(device, edgey_a0s, &edgey_data);
acPackCommData(device, edgez_a0s, &edgez_data);
acPackCommData(device, sidexy_a0s, &sidexy_data);
acPackCommData(device, sidexz_a0s, &sidexz_data);
acPackCommData(device, sideyz_a0s, &sideyz_data);
acTransferCommDataToHost(device, &corner_data);
acTransferCommDataToHost(device, &edgex_data);
acTransferCommDataToHost(device, &edgey_data);
acTransferCommDataToHost(device, &edgez_data);
acTransferCommDataToHost(device, &sidexy_data);
acTransferCommDataToHost(device, &sidexz_data);
acTransferCommDataToHost(device, &sideyz_data);
acTransferCommData(device, corner_a0s, corner_b0s, &corner_data);
acTransferCommData(device, edgex_a0s, edgex_b0s, &edgex_data);
@@ -1464,17 +1565,82 @@ acDeviceCommunicateHalosMPIAlt(const Device device)
acTransferCommDataWait(sidexz_data);
acTransferCommDataWait(sideyz_data);
acUnpackCommData(device, corner_b0s, ARRAY_SIZE(corner_b0s), &corner_data);
acUnpackCommData(device, edgex_b0s, ARRAY_SIZE(edgex_b0s), &edgex_data);
acUnpackCommData(device, edgey_b0s, ARRAY_SIZE(edgey_b0s), &edgey_data);
acUnpackCommData(device, edgez_b0s, ARRAY_SIZE(edgez_b0s), &edgez_data);
acUnpackCommData(device, sidexy_b0s, ARRAY_SIZE(sidexy_b0s), &sidexy_data);
acUnpackCommData(device, sidexz_b0s, ARRAY_SIZE(sidexz_b0s), &sidexz_data);
acUnpackCommData(device, sideyz_b0s, ARRAY_SIZE(sideyz_b0s), &sideyz_data);
acTransferCommDataToDevice(device, &corner_data);
acTransferCommDataToDevice(device, &edgex_data);
acTransferCommDataToDevice(device, &edgey_data);
acTransferCommDataToDevice(device, &edgez_data);
acTransferCommDataToDevice(device, &sidexy_data);
acTransferCommDataToDevice(device, &sidexz_data);
acTransferCommDataToDevice(device, &sideyz_data);
printf("Total: ");
timer_diff_print(ttot);
acUnpackCommData(device, corner_b0s, &corner_data);
acUnpackCommData(device, edgex_b0s, &edgex_data);
acUnpackCommData(device, edgey_b0s, &edgey_data);
acUnpackCommData(device, edgez_b0s, &edgez_data);
acUnpackCommData(device, sidexy_b0s, &sidexy_data);
acUnpackCommData(device, sidexz_b0s, &sidexz_data);
acUnpackCommData(device, sideyz_b0s, &sideyz_data);
/*
acPackCommData(device, corner_a0s, &corner_data);
acPackCommData(device, edgex_a0s, &edgex_data);
acPackCommData(device, edgey_a0s, &edgey_data);
acPackCommData(device, edgez_a0s, &edgez_data);
acPackCommData(device, sidexy_a0s, &sidexy_data);
acPackCommData(device, sidexz_a0s, &sidexz_data);
acPackCommData(device, sideyz_a0s, &sideyz_data);
acTransferCommDataToHost(device, &corner_data);
acTransferCommDataToHost(device, &edgex_data);
acTransferCommDataToHost(device, &edgey_data);
acTransferCommDataToHost(device, &edgez_data);
acTransferCommDataToHost(device, &sidexy_data);
acTransferCommDataToHost(device, &sidexz_data);
acTransferCommDataToHost(device, &sideyz_data);
acTransferCommData(device, corner_a0s, corner_b0s, &corner_data);
acTransferCommData(device, edgex_a0s, edgex_b0s, &edgex_data);
acTransferCommData(device, edgey_a0s, edgey_b0s, &edgey_data);
acTransferCommData(device, edgez_a0s, edgez_b0s, &edgez_data);
acTransferCommData(device, sidexy_a0s, sidexy_b0s, &sidexy_data);
acTransferCommData(device, sidexz_a0s, sidexz_b0s, &sidexz_data);
acTransferCommData(device, sideyz_a0s, sideyz_b0s, &sideyz_data);
acTransferCommDataWait(corner_data);
acTransferCommDataWait(edgex_data);
acTransferCommDataWait(edgey_data);
acTransferCommDataWait(edgez_data);
acTransferCommDataWait(sidexy_data);
acTransferCommDataWait(sidexz_data);
acTransferCommDataWait(sideyz_data);
acTransferCommDataToDevice(device, &corner_data);
acTransferCommDataToDevice(device, &edgex_data);
acTransferCommDataToDevice(device, &edgey_data);
acTransferCommDataToDevice(device, &edgez_data);
acTransferCommDataToDevice(device, &sidexy_data);
acTransferCommDataToDevice(device, &sidexz_data);
acTransferCommDataToDevice(device, &sideyz_data);
acUnpackCommData(device, corner_b0s, &corner_data);
acUnpackCommData(device, edgex_b0s, &edgex_data);
acUnpackCommData(device, edgey_b0s, &edgey_data);
acUnpackCommData(device, edgez_b0s, &edgez_data);
acUnpackCommData(device, sidexy_b0s, &sidexy_data);
acUnpackCommData(device, sidexz_b0s, &sidexz_data);
acUnpackCommData(device, sideyz_b0s, &sideyz_data);
*/
cudaDeviceSynchronize();
MPI_Barrier(MPI_COMM_WORLD);
int pid;
MPI_Comm_rank(MPI_COMM_WORLD, &pid);
if (!pid) {
printf("---------------------------Total: ");
timer_diff_print(ttot);
}
// Dealloc
acDestroyCommData(device, &corner_data);
acDestroyCommData(device, &edgex_data);
acDestroyCommData(device, &edgey_data);