Draft of MPI-based reductions acGridReduceScal, acGridReduceVec
- Calls acDeviceReduceScal/Vec first - Both functions then perform the same MPI-reduction (MPI_Allreduce) - Not tested
This commit is contained in:
@@ -1620,4 +1620,74 @@ acGridPeriodicBoundconds(const Stream stream)
|
|||||||
acSyncCommData(sideyz_data);
|
acSyncCommData(sideyz_data);
|
||||||
return AC_SUCCESS;
|
return AC_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AcResult
|
||||||
|
acMPIReduceScal(AcReal* local_result, const ReductionType rtype, AcReal* result)
|
||||||
|
{
|
||||||
|
|
||||||
|
MPI_Op op;
|
||||||
|
if (rtype == RTYPE_MAX) {
|
||||||
|
op = MPI_MAX;
|
||||||
|
} else if (rtype == RTYPE_MIN) {
|
||||||
|
op = MPI_MIN;
|
||||||
|
} else if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP || rtype == RTYPE_SUM) {
|
||||||
|
op = MPI_SUM;
|
||||||
|
} else {
|
||||||
|
ERROR("Unrecognised rtype");
|
||||||
|
}
|
||||||
|
|
||||||
|
#if AC_DOUBLE_PRECISION == 1
|
||||||
|
MPI_Datatype datatype = MPI_DOUBLE;
|
||||||
|
#else
|
||||||
|
MPI_Datatype datatype = MPI_FLOAT;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int world_size;
|
||||||
|
MPI_Comm_size(MPI_COMM_WORLD, &world_size);
|
||||||
|
|
||||||
|
if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP) {
|
||||||
|
//Overflow risk?
|
||||||
|
*local_result = *local_result*(*local_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
AcReal mpi_res;
|
||||||
|
MPI_Allreduce(&local_result, &mpi_res, 1, datatype, op, MPI_COMM_WORLD);
|
||||||
|
|
||||||
|
if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP) {
|
||||||
|
const AcReal inv_n = AcReal(1.) / world_size;
|
||||||
|
mpi_res = sqrt(inv_n * mpi_res);
|
||||||
|
}
|
||||||
|
*result = mpi_res;
|
||||||
|
return AC_SUCCESS;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
AcResult
|
||||||
|
acGridReduceScal(const Device device, const Stream stream, const ReductionType rtype,
|
||||||
|
const VertexBufferHandle vtxbuf_handle, AcReal* result)
|
||||||
|
{
|
||||||
|
acGridSynchronizeStream(STREAM_ALL);
|
||||||
|
MPI_Barrier(MPI_COMM_WORLD);
|
||||||
|
|
||||||
|
AcReal local_result;
|
||||||
|
acDeviceReduceScal(device, stream, rtype, vtxbuf_handle, &local_result);
|
||||||
|
|
||||||
|
return acMPIReduceScal(&local_result,rtype,result);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
AcResult
|
||||||
|
acGridReduceVec(const Device device, const Stream stream, const ReductionType rtype,
|
||||||
|
const VertexBufferHandle vtxbuf0, const VertexBufferHandle vtxbuf1,
|
||||||
|
const VertexBufferHandle vtxbuf2, AcReal* result)
|
||||||
|
{
|
||||||
|
acGridSynchronizeStream(STREAM_ALL);
|
||||||
|
MPI_Barrier(MPI_COMM_WORLD);
|
||||||
|
|
||||||
|
AcReal local_result;
|
||||||
|
acDeviceReduceVec(device, stream, rtype, vtxbuf0, vtxbuf1, vtxbuf2, &local_result);
|
||||||
|
|
||||||
|
return acMPIReduceScal(&local_result,rtype,result);
|
||||||
|
}
|
||||||
|
|
||||||
#endif // AC_MPI_ENABLED
|
#endif // AC_MPI_ENABLED
|
||||||
|
Reference in New Issue
Block a user