From ba0bfd65b4f59ee1f19caac0b5257e0519238339 Mon Sep 17 00:00:00 2001 From: jpekkila Date: Wed, 24 Jun 2020 16:10:27 +0300 Subject: [PATCH] Merged the new reduction functions manually --- src/core/device.cc | 79 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/core/device.cc b/src/core/device.cc index 8cda677..ed763a0 100644 --- a/src/core/device.cc +++ b/src/core/device.cc @@ -1698,4 +1698,83 @@ acGridPeriodicBoundconds(const Stream stream) acSyncCommData(sideyz_data); return AC_SUCCESS; } + +static AcResult +acMPIReduceScal(const AcReal local_result, const ReductionType rtype, AcReal* result) +{ + + MPI_Op op; + if (rtype == RTYPE_MAX) { + op = MPI_MAX; + } + else if (rtype == RTYPE_MIN) { + op = MPI_MIN; + } + else if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP || rtype == RTYPE_SUM) { + op = MPI_SUM; + } + else { + ERROR("Unrecognised rtype"); + } + +#if AC_DOUBLE_PRECISION == 1 + MPI_Datatype datatype = MPI_DOUBLE; +#else + MPI_Datatype datatype = MPI_FLOAT; +#endif + + int rank; + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + + int world_size; + MPI_Comm_size(MPI_COMM_WORLD, &world_size); + + AcReal mpi_res; + MPI_Reduce(&local_result, &mpi_res, 1, datatype, op, 0, MPI_COMM_WORLD); + if (rank == 0) { + if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP) { + const AcReal inv_n = AcReal(1.) / + (grid.nn.x * grid.decomposition.x * grid.nn.y * + grid.decomposition.y * grid.nn.z * grid.decomposition.z); + mpi_res = sqrt(inv_n * mpi_res); + } + *result = mpi_res; + } + return AC_SUCCESS; +} + +AcResult +acGridReduceScal(const Stream stream, const ReductionType rtype, + const VertexBufferHandle vtxbuf_handle, AcReal* result) +{ + ERRCHK(grid.initialized); + + const Device device = grid.device; + + acGridSynchronizeStream(STREAM_ALL); + // MPI_Barrier(MPI_COMM_WORLD); + + AcReal local_result; + acDeviceReduceScal(device, stream, rtype, vtxbuf_handle, &local_result); + + return acMPIReduceScal(local_result, rtype, result); +} + +AcResult +acGridReduceVec(const Stream stream, const ReductionType rtype, const VertexBufferHandle vtxbuf0, + const VertexBufferHandle vtxbuf1, const VertexBufferHandle vtxbuf2, AcReal* result) +{ + ERRCHK(grid.initialized); + + const Device device = grid.device; + + acGridSynchronizeStream(STREAM_ALL); + // MPI_Barrier(MPI_COMM_WORLD); + + AcReal local_result; + acDeviceReduceVec(device, stream, rtype, vtxbuf0, vtxbuf1, vtxbuf2, &local_result); + + return acMPIReduceScal(local_result, rtype, result); +} + #endif // AC_MPI_ENABLED