diff --git a/src/core/device.cc b/src/core/device.cc
index a6ec793..52f20ed 100644
--- a/src/core/device.cc
+++ b/src/core/device.cc
@@ -433,6 +433,26 @@ acDevicePeriodicBoundconds(const Device device, const Stream stream, const int3
     return AC_SUCCESS;
 }
 
+AcResult
+acDeviceGeneralBoundcondStep(const Device device, const Stream stream,
+                             const VertexBufferHandle vtxbuf_handle, const int3 start,
+                             const int3 end)
+{
+    cudaSetDevice(device->id);
+    return acKernelGeneralBoundconds(device->streams[stream], start, end,
+                                     device->vba.in[vtxbuf_handle]);
+}
+
+AcResult
+acDeviceGeneralBoundconds(const Device device, const Stream stream, const int3 start,
+                          const int3 end)
+{
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+        acDeviceGeneralBoundcondStep(device, stream, (VertexBufferHandle)i, start, end);
+    }
+    return AC_SUCCESS;
+}
+
 AcResult
 acDeviceReduceScal(const Device device, const Stream stream, const ReductionType rtype,
                    const VertexBufferHandle vtxbuf_handle, AcReal* result)
@@ -1652,6 +1672,13 @@ acGridIntegrate(const Stream stream, const AcReal dt)
     acGridLoadScalarUniform(stream, AC_dt, dt);
     acDeviceSynchronizeStream(device, stream);
 
+    // Determine this process's position in the MPI decomposition
+    int nprocs, pid;
+    MPI_Comm_size(MPI_COMM_WORLD, &nprocs);
+    MPI_Comm_rank(MPI_COMM_WORLD, &pid);
+    const uint3_64 decomposition = decompose(nprocs);
+    const int3 pid3d             = getPid3D(pid, decomposition);
+
     // Corners
 #if MPI_INCL_CORNERS
     // Do not rm: required for corners
@@ -1805,7 +1832,20 @@ acGridIntegrate(const Stream stream, const AcReal dt)
         acSyncCommData(sidexz_data);
         acSyncCommData(sideyz_data);
 #endif // MPI_COMM_ENABLED
+
+        // Set the outer boundaries after the substep computation.
+        const int3 m1 = (int3){0, 0, 0};
+        const int3 m2 = nn;
+        // Apply general boundary conditions only if this process owns part of
+        // the outer domain boundary.
+        if ((pid3d.x == 0) || (pid3d.x == (int)decomposition.x - 1) ||
+            (pid3d.y == 0) || (pid3d.y == (int)decomposition.y - 1) ||
+            (pid3d.z == 0) || (pid3d.z == (int)decomposition.z - 1)) {
+            acDeviceGeneralBoundconds(device, stream, m1, m2);
+        }
+        acGridSynchronizeStream(stream);
+
 #if MPI_COMPUTE_ENABLED
         { // Front
             const int3 m1 = (int3){NGHOST, NGHOST, NGHOST};
             const int3 m2 = nn;
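
For reviewers, a minimal, self-contained sketch of the rank-position test that the acGridIntegrate hunk adds. The PidCoord/Decomp structs and the row-major pidTo3D() mapping below are hypothetical stand-ins for Astaroth's int3/uint3_64 and getPid3D()/decompose() (the real mapping may order ranks differently); only the boundary predicate itself mirrors the patch.

// Stand-alone sketch of the boundary-ownership test (assumptions noted above).
#include <cstdint>
#include <cstdio>

struct PidCoord { int x, y, z; };    // stand-in for int3
struct Decomp { uint64_t x, y, z; }; // stand-in for uint3_64

// Row-major stand-in for getPid3D(): linear MPI rank -> 3D process coordinates.
static PidCoord
pidTo3D(const int pid, const Decomp d)
{
    return PidCoord{(int)((uint64_t)pid % d.x),
                    (int)(((uint64_t)pid / d.x) % d.y),
                    (int)((uint64_t)pid / (d.x * d.y))};
}

// The predicate from the patch: true iff the rank touches an outer face.
static bool
isBoundaryRank(const PidCoord p, const Decomp d)
{
    return (p.x == 0) || (p.x == (int)d.x - 1) ||
           (p.y == 0) || (p.y == (int)d.y - 1) ||
           (p.z == 0) || (p.z == (int)d.z - 1);
}

int
main()
{
    const Decomp d = {3, 3, 3}; // 27 ranks arranged 3 x 3 x 3
    for (int pid = 0; pid < 27; ++pid) {
        const PidCoord p = pidTo3D(pid, d);
        printf("rank %2d -> (%d,%d,%d) %s\n", pid, p.x, p.y, p.z,
               isBoundaryRank(p, d) ? "boundary" : "interior");
    }
    return 0;
}

With a 3 x 3 x 3 decomposition, every rank except the central one at (1,1,1) owns part of an outer face, so for small process counts most or all ranks take the acDeviceGeneralBoundconds branch.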