Merged the new reduction functions manually
This commit is contained in:
@@ -1698,4 +1698,83 @@ acGridPeriodicBoundconds(const Stream stream)
|
||||
acSyncCommData(sideyz_data);
|
||||
return AC_SUCCESS;
|
||||
}
|
||||
|
||||
/**
 * Combines the per-process partial results of a reduction into one value.
 *
 * Every process passes its `local_result` to an MPI_Reduce over
 * MPI_COMM_WORLD rooted at rank 0. RTYPE_MAX maps to MPI_MAX, RTYPE_MIN to
 * MPI_MIN, and RTYPE_SUM, RTYPE_RMS and RTYPE_RMS_EXP map to MPI_SUM.
 * For the two RMS types rank 0 additionally multiplies the reduced sum by the
 * reciprocal of the product of the grid.nn and grid.decomposition components
 * (x, y and z; presumably the global point count) and takes the square root.
 *
 * @param local_result  Partial result computed by the calling process.
 * @param rtype         Reduction type; selects the MPI operation and whether
 *                      the RMS normalisation is applied.
 * @param result        Written on rank 0 only. Other ranks leave *result
 *                      untouched.
 * @return AC_SUCCESS.
 */
static AcResult
acMPIReduceScal(const AcReal local_result, const ReductionType rtype, AcReal* result)
{
    const int is_rms = (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP);

    // `op` starts with a definite value so that no path can hand an
    // indeterminate MPI_Op to MPI_Reduce.
    // NOTE(review): ERROR is defined elsewhere; if it does not terminate the
    // program, an unrecognised rtype is reduced with MPI_SUM after the error
    // has been reported - confirm whether an early failure return is preferred.
    MPI_Op op = MPI_SUM;
    if (rtype == RTYPE_MAX) {
        op = MPI_MAX;
    }
    else if (rtype == RTYPE_MIN) {
        op = MPI_MIN;
    }
    else if (!is_rms && rtype != RTYPE_SUM) {
        ERROR("Unrecognised rtype");
    }

#if AC_DOUBLE_PRECISION == 1
    MPI_Datatype datatype = MPI_DOUBLE;
#else
    MPI_Datatype datatype = MPI_FLOAT;
#endif

    // Initialised to a non-root value so the rank test below never reads an
    // indeterminate value.
    int rank = -1;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    // Only read on rank 0, where MPI_Reduce stores the reduced value.
    AcReal mpi_res;
    MPI_Reduce(&local_result, &mpi_res, 1, datatype, op, 0, MPI_COMM_WORLD);

    if (rank == 0) {
        if (is_rms) {
            // The first operand is widened so the whole product is evaluated
            // in at least 64 bits and cannot overflow a narrower integer type.
            const AcReal inv_n = AcReal(1.) /
                                 ((unsigned long long)grid.nn.x * grid.decomposition.x *
                                  grid.nn.y * grid.decomposition.y *
                                  grid.nn.z * grid.decomposition.z);
            mpi_res = sqrt(inv_n * mpi_res);
        }
        *result = mpi_res;
    }
    return AC_SUCCESS;
}
|
||||
|
||||
/**
 * Reduces a single vertex buffer across the grid.
 *
 * Checks grid.initialized with ERRCHK, calls acGridSynchronizeStream with
 * STREAM_ALL, runs acDeviceReduceScal on grid.device to obtain this process's
 * partial result, and forwards that value to acMPIReduceScal.
 *
 * @param stream         Stream passed to acDeviceReduceScal.
 * @param rtype          Reduction type, used by both the device and MPI steps.
 * @param vtxbuf_handle  Vertex buffer to reduce.
 * @param result         Output pointer, forwarded to acMPIReduceScal.
 * @return The value returned by acMPIReduceScal.
 */
AcResult
acGridReduceScal(const Stream stream, const ReductionType rtype,
                 const VertexBufferHandle vtxbuf_handle, AcReal* result)
{
    ERRCHK(grid.initialized);

    // Synchronisation is requested for STREAM_ALL, not just `stream`.
    acGridSynchronizeStream(STREAM_ALL);

    // Device-side partial result for this process.
    AcReal partial;
    acDeviceReduceScal(grid.device, stream, rtype, vtxbuf_handle, &partial);

    // Combine the partial results of all processes.
    const AcResult status = acMPIReduceScal(partial, rtype, result);
    return status;
}
|
||||
|
||||
/**
 * Reduces three vertex buffers, taken together, across the grid.
 *
 * Checks grid.initialized with ERRCHK, calls acGridSynchronizeStream with
 * STREAM_ALL, runs acDeviceReduceVec on grid.device to obtain this process's
 * partial result, and forwards that value to acMPIReduceScal.
 *
 * @param stream   Stream passed to acDeviceReduceVec.
 * @param rtype    Reduction type, used by both the device and MPI steps.
 * @param vtxbuf0  First vertex buffer passed to acDeviceReduceVec.
 * @param vtxbuf1  Second vertex buffer passed to acDeviceReduceVec.
 * @param vtxbuf2  Third vertex buffer passed to acDeviceReduceVec.
 * @param result   Output pointer, forwarded to acMPIReduceScal.
 * @return The value returned by acMPIReduceScal.
 */
AcResult
acGridReduceVec(const Stream stream, const ReductionType rtype, const VertexBufferHandle vtxbuf0,
                const VertexBufferHandle vtxbuf1, const VertexBufferHandle vtxbuf2, AcReal* result)
{
    ERRCHK(grid.initialized);

    // Synchronisation is requested for STREAM_ALL, not just `stream`.
    acGridSynchronizeStream(STREAM_ALL);

    // Device-side partial result for this process.
    AcReal partial;
    acDeviceReduceVec(grid.device, stream, rtype, vtxbuf0, vtxbuf1, vtxbuf2, &partial);

    // Combine the partial results of all processes.
    const AcResult status = acMPIReduceScal(partial, rtype, result);
    return status;
}
|
||||
|
||||
#endif // AC_MPI_ENABLED
|
||||
|
Reference in New Issue
Block a user