From 899d679518cd943ea9dc4e95424d8ed280574f25 Mon Sep 17 00:00:00 2001
From: Oskar Lappi
Date: Tue, 2 Jun 2020 21:30:53 +0300
Subject: [PATCH] Draft of MPI-based reductions acGridReduceScal, acGridReduceVec

- Calls acDeviceReduceScal/Vec first
- Both functions then perform the same MPI-reduction (MPI_Allreduce)
- Not tested
---
 src/core/device.cc | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 86 insertions(+)

diff --git a/src/core/device.cc b/src/core/device.cc
index f473fc2..9846d2f 100644
--- a/src/core/device.cc
+++ b/src/core/device.cc
@@ -1620,4 +1620,90 @@ acGridPeriodicBoundconds(const Stream stream)
     acSyncCommData(sideyz_data);
     return AC_SUCCESS;
 }
+
+/**
+ * Combines the per-process value *local_result across MPI_COMM_WORLD with
+ * MPI_Allreduce and stores the outcome in *result on every process.
+ * RTYPE_MAX/RTYPE_MIN map to MPI_MAX/MPI_MIN; RTYPE_SUM and the RMS types map
+ * to MPI_SUM. For the RMS types the local value is squared before the sum and
+ * sqrt(sum / world_size) is returned, i.e. every process is weighted equally.
+ * Returns AC_SUCCESS; *local_result itself is not modified.
+ */
+AcResult
+acMPIReduceScal(AcReal* local_result, const ReductionType rtype, AcReal* result)
+{
+    // Start from the null handle so op is never read uninitialized should ERROR() return.
+    MPI_Op op = MPI_OP_NULL;
+    if (rtype == RTYPE_MAX) {
+        op = MPI_MAX;
+    } else if (rtype == RTYPE_MIN) {
+        op = MPI_MIN;
+    } else if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP || rtype == RTYPE_SUM) {
+        op = MPI_SUM;
+    } else {
+        ERROR("Unrecognised rtype");
+    }
+
+#if AC_DOUBLE_PRECISION == 1
+    MPI_Datatype datatype = MPI_DOUBLE;
+#else
+    MPI_Datatype datatype = MPI_FLOAT;
+#endif
+
+    int world_size;
+    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
+
+    const bool is_rms = (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP);
+    // Square into a private copy so the caller's value stays intact. Overflow risk?
+    AcReal send_val = is_rms ? (*local_result) * (*local_result) : *local_result;
+
+    AcReal mpi_res;
+    // sendbuf must address the AcReal itself; &local_result would be an AcReal**.
+    MPI_Allreduce(&send_val, &mpi_res, 1, datatype, op, MPI_COMM_WORLD);
+
+    if (is_rms) {
+        const AcReal inv_n = AcReal(1.) / world_size;
+        mpi_res = sqrt(inv_n * mpi_res);
+    }
+    *result = mpi_res;
+    return AC_SUCCESS;
+}
+
+/**
+ * Grid-wide reduction of one vertex buffer: synchronizes all streams, passes an
+ * MPI barrier, calls acDeviceReduceScal for `device`, then combines the
+ * per-process values with acMPIReduceScal. The final value is written to *result.
+ */
+AcResult
+acGridReduceScal(const Device device, const Stream stream, const ReductionType rtype,
+                 const VertexBufferHandle vtxbuf_handle, AcReal* result)
+{
+    acGridSynchronizeStream(STREAM_ALL);
+    MPI_Barrier(MPI_COMM_WORLD);
+
+    AcReal local_result;
+    acDeviceReduceScal(device, stream, rtype, vtxbuf_handle, &local_result);
+
+    return acMPIReduceScal(&local_result, rtype, result);
+}
+
+/**
+ * Grid-wide reduction over three vertex buffers: same sequence as
+ * acGridReduceScal, but the local value comes from acDeviceReduceVec.
+ * The final value is written to *result.
+ */
+AcResult
+acGridReduceVec(const Device device, const Stream stream, const ReductionType rtype,
+                const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1,
+                const VertexBufferHandle vtxbuf2, AcReal* result)
+{
+    acGridSynchronizeStream(STREAM_ALL);
+    MPI_Barrier(MPI_COMM_WORLD);
+
+    AcReal local_result;
+    acDeviceReduceVec(device, stream, rtype, vtxbuf0, vtxbuf1, vtxbuf2, &local_result);
+
+    return acMPIReduceScal(&local_result, rtype, result);
+}
+
 #endif // AC_MPI_ENABLED