diff --git a/src/core/device.cc b/src/core/device.cc
index c772449..f60554d 100644
--- a/src/core/device.cc
+++ b/src/core/device.cc
@@ -748,98 +748,181 @@ acDeviceBoundStepMPI(const Device device)
     const int mz = device->local_config.int_params[AC_mz];
     const size_t count = mx * my * NGHOST;
 
-    for (int isubstep = 0; isubstep < 3; ++isubstep) {
-        acDeviceSynchronizeStream(device, STREAM_ALL);
-        // Local boundconds
-        for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
-            // Front plate local
-            {
-                const int3 start = (int3){0, 0, NGHOST};
-                const int3 end = (int3){mx, my, 2 * NGHOST};
-                acDevicePeriodicBoundcondStep(device, (Stream)i, (VertexBufferHandle)i, start, end);
-            }
-            // Back plate local
-            {
-                const int3 start = (int3){0, 0, mz - 2 * NGHOST};
-                const int3 end = (int3){mx, my, mz - NGHOST};
-                acDevicePeriodicBoundcondStep(device, (Stream)i, (VertexBufferHandle)i, start, end);
-            }
+    acDeviceSynchronizeStream(device, STREAM_ALL);
+    // Local boundconds
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+        // Front plate local
+        {
+            const int3 start = (int3){0, 0, NGHOST};
+            const int3 end = (int3){mx, my, 2 * NGHOST};
+            acDevicePeriodicBoundcondStep(device, (Stream)i, (VertexBufferHandle)i, start, end);
         }
-#define INNER_BOUNDCOND_STREAM ((Stream)(NUM_STREAMS - 1))
-        // Inner boundconds (while waiting)
-        for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
-
-            const int3 start = (int3){0, 0, 2 * NGHOST};
-            const int3 end = (int3){mx, my, mz - 2 * NGHOST};
-            acDevicePeriodicBoundcondStep(device, INNER_BOUNDCOND_STREAM, (VertexBufferHandle)i,
-                                          start, end);
+        // Back plate local
+        {
+            const int3 start = (int3){0, 0, mz - 2 * NGHOST};
+            const int3 end = (int3){mx, my, mz - NGHOST};
+            acDevicePeriodicBoundcondStep(device, (Stream)i, (VertexBufferHandle)i, start, end);
         }
-
-        // MPI
-        MPI_Request recv_requests[2 * NUM_VTXBUF_HANDLES];
-        MPI_Datatype datatype = MPI_FLOAT;
-        if (sizeof(AcReal) == 8)
-            datatype = MPI_DOUBLE;
-
-        int pid, num_processes;
-        MPI_Comm_rank(MPI_COMM_WORLD, &pid);
-        MPI_Comm_size(MPI_COMM_WORLD, &num_processes);
-
-        for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
-            { // Recv neighbor's front
-                // ...|ooooxxx|... -> xxx|ooooooo|...
-                const size_t dst_idx = acVertexBufferIdx(0, 0, 0, device->local_config);
-                const int recv_pid = (pid + num_processes - 1) % num_processes;
-
-                MPI_Irecv(&device->vba.in[i][dst_idx], count, datatype, recv_pid, i, MPI_COMM_WORLD,
-                          &recv_requests[i]);
-            }
-            { // Recv neighbor's back
-                // ...|ooooooo|xxx <- ...|xxxoooo|...
-                const size_t dst_idx = acVertexBufferIdx(0, 0, mz - NGHOST, device->local_config);
-                const int recv_pid = (pid + 1) % num_processes;
-
-                MPI_Irecv(&device->vba.in[i][dst_idx], count, datatype, recv_pid,
-                          NUM_VTXBUF_HANDLES + i, MPI_COMM_WORLD,
-                          &recv_requests[i + NUM_VTXBUF_HANDLES]);
-            }
-        }
-
-        for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
-            acDeviceSynchronizeStream(device, (Stream)i);
-            {
-                // Send front
-                // ...|ooooxxx|... -> xxx|ooooooo|...
-                const size_t src_idx = acVertexBufferIdx(0, 0, mz - 2 * NGHOST,
-                                                         device->local_config);
-                const int send_pid = (pid + 1) % num_processes;
-
-                MPI_Request request;
-                MPI_Isend(&device->vba.in[i][src_idx], count, datatype, send_pid, i, MPI_COMM_WORLD,
-                          &request);
-            }
-            { // Send back
-                // ...|ooooooo|xxx <- ...|xxxoooo|...
-                const size_t src_idx = acVertexBufferIdx(0, 0, NGHOST, device->local_config);
-                const int send_pid = (pid + num_processes - 1) % num_processes;
-
-                MPI_Request request;
-                MPI_Isend(&device->vba.in[i][src_idx], count, datatype, send_pid,
-                          i + NUM_VTXBUF_HANDLES, MPI_COMM_WORLD, &request);
-            }
-        }
-        for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
-            MPI_Status status;
-            MPI_Wait(&recv_requests[i], &status);
-            MPI_Wait(&recv_requests[i + NUM_VTXBUF_HANDLES], &status);
-        }
-        MPI_Barrier(MPI_COMM_WORLD);
-        acDeviceSwapBuffers(device);
-        MPI_Barrier(MPI_COMM_WORLD);
     }
+#define INNER_BOUNDCOND_STREAM ((Stream)(NUM_STREAMS - 1))
+    // Inner boundconds (while waiting)
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+
+        const int3 start = (int3){0, 0, 2 * NGHOST};
+        const int3 end = (int3){mx, my, mz - 2 * NGHOST};
+        acDevicePeriodicBoundcondStep(device, INNER_BOUNDCOND_STREAM, (VertexBufferHandle)i, start,
+                                      end);
+    }
+
+    // MPI
+    MPI_Request recv_requests[2 * NUM_VTXBUF_HANDLES];
+    MPI_Datatype datatype = MPI_FLOAT;
+    if (sizeof(AcReal) == 8)
+        datatype = MPI_DOUBLE;
+
+    int pid, num_processes;
+    MPI_Comm_rank(MPI_COMM_WORLD, &pid);
+    MPI_Comm_size(MPI_COMM_WORLD, &num_processes);
+
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+        { // Recv neighbor's front
+            // ...|ooooxxx|... -> xxx|ooooooo|...
+            const size_t dst_idx = acVertexBufferIdx(0, 0, 0, device->local_config);
+            const int recv_pid = (pid + num_processes - 1) % num_processes;
+
+            MPI_Irecv(&device->vba.in[i][dst_idx], count, datatype, recv_pid, i, MPI_COMM_WORLD,
+                      &recv_requests[i]);
+        }
+        { // Recv neighbor's back
+            // ...|ooooooo|xxx <- ...|xxxoooo|...
+            const size_t dst_idx = acVertexBufferIdx(0, 0, mz - NGHOST, device->local_config);
+            const int recv_pid = (pid + 1) % num_processes;
+
+            MPI_Irecv(&device->vba.in[i][dst_idx], count, datatype, recv_pid,
+                      NUM_VTXBUF_HANDLES + i, MPI_COMM_WORLD,
+                      &recv_requests[i + NUM_VTXBUF_HANDLES]);
+        }
+    }
+
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+        acDeviceSynchronizeStream(device, (Stream)i);
+        {
+            // Send front
+            // ...|ooooxxx|... -> xxx|ooooooo|...
+            const size_t src_idx = acVertexBufferIdx(0, 0, mz - 2 * NGHOST, device->local_config);
+            const int send_pid = (pid + 1) % num_processes;
+
+            MPI_Request request;
+            MPI_Isend(&device->vba.in[i][src_idx], count, datatype, send_pid, i, MPI_COMM_WORLD,
+                      &request);
+        }
+        { // Send back
+            // ...|ooooooo|xxx <- ...|xxxoooo|...
+            const size_t src_idx = acVertexBufferIdx(0, 0, NGHOST, device->local_config);
+            const int send_pid = (pid + num_processes - 1) % num_processes;
+
+            MPI_Request request;
+            MPI_Isend(&device->vba.in[i][src_idx], count, datatype, send_pid,
+                      i + NUM_VTXBUF_HANDLES, MPI_COMM_WORLD, &request);
+        }
+    }
+    MPI_Waitall(2 * NUM_VTXBUF_HANDLES, recv_requests, MPI_STATUSES_IGNORE);
+    return AC_SUCCESS;
+}
+
+/*
+// 1D decomp
+static AcResult
+acDeviceBoundStepMPI(const Device device)
+{
+    const int mx = device->local_config.int_params[AC_mx];
+    const int my = device->local_config.int_params[AC_my];
+    const int mz = device->local_config.int_params[AC_mz];
+    const size_t count = mx * my * NGHOST;
+
+    acDeviceSynchronizeStream(device, STREAM_ALL);
+    // Local boundconds
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+        // Front plate local
+        {
+            const int3 start = (int3){0, 0, NGHOST};
+            const int3 end = (int3){mx, my, 2 * NGHOST};
+            acDevicePeriodicBoundcondStep(device, (Stream)i, (VertexBufferHandle)i, start, end);
+        }
+        // Back plate local
+        {
+            const int3 start = (int3){0, 0, mz - 2 * NGHOST};
+            const int3 end = (int3){mx, my, mz - NGHOST};
+            acDevicePeriodicBoundcondStep(device, (Stream)i, (VertexBufferHandle)i, start, end);
+        }
+    }
+#define INNER_BOUNDCOND_STREAM ((Stream)(NUM_STREAMS - 1))
+    // Inner boundconds (while waiting)
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+
+        const int3 start = (int3){0, 0, 2 * NGHOST};
+        const int3 end = (int3){mx, my, mz - 2 * NGHOST};
+        acDevicePeriodicBoundcondStep(device, INNER_BOUNDCOND_STREAM, (VertexBufferHandle)i, start,
+                                      end);
+    }
+
+    // MPI
+    MPI_Request recv_requests[2 * NUM_VTXBUF_HANDLES];
+    MPI_Datatype datatype = MPI_FLOAT;
+    if (sizeof(AcReal) == 8)
+        datatype = MPI_DOUBLE;
+
+    int pid, num_processes;
+    MPI_Comm_rank(MPI_COMM_WORLD, &pid);
+    MPI_Comm_size(MPI_COMM_WORLD, &num_processes);
+
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+        { // Recv neighbor's front
+            // ...|ooooxxx|... -> xxx|ooooooo|...
+            const size_t dst_idx = acVertexBufferIdx(0, 0, 0, device->local_config);
+            const int recv_pid = (pid + num_processes - 1) % num_processes;
+
+            MPI_Irecv(&device->vba.in[i][dst_idx], count, datatype, recv_pid, i, MPI_COMM_WORLD,
+                      &recv_requests[i]);
+        }
+        { // Recv neighbor's back
+            // ...|ooooooo|xxx <- ...|xxxoooo|...
+            const size_t dst_idx = acVertexBufferIdx(0, 0, mz - NGHOST, device->local_config);
+            const int recv_pid = (pid + 1) % num_processes;
+
+            MPI_Irecv(&device->vba.in[i][dst_idx], count, datatype, recv_pid,
+                      NUM_VTXBUF_HANDLES + i, MPI_COMM_WORLD,
+                      &recv_requests[i + NUM_VTXBUF_HANDLES]);
+        }
+    }
+
+    for (int i = 0; i < NUM_VTXBUF_HANDLES; ++i) {
+        acDeviceSynchronizeStream(device, (Stream)i);
+        {
+            // Send front
+            // ...|ooooxxx|... -> xxx|ooooooo|...
+            const size_t src_idx = acVertexBufferIdx(0, 0, mz - 2 * NGHOST, device->local_config);
+            const int send_pid = (pid + 1) % num_processes;
+
+            MPI_Request request;
+            MPI_Isend(&device->vba.in[i][src_idx], count, datatype, send_pid, i, MPI_COMM_WORLD,
+                      &request);
+        }
+        { // Send back
+            // ...|ooooooo|xxx <- ...|xxxoooo|...
+            const size_t src_idx = acVertexBufferIdx(0, 0, NGHOST, device->local_config);
+            const int send_pid = (pid + num_processes - 1) % num_processes;
+
+            MPI_Request request;
+            MPI_Isend(&device->vba.in[i][src_idx], count, datatype, send_pid,
+                      i + NUM_VTXBUF_HANDLES, MPI_COMM_WORLD, &request);
+        }
+    }
+    MPI_Waitall(NUM_VTXBUF_HANDLES, recv_requests, MPI_STATUSES_IGNORE);
     return AC_SUCCESS;
 }
+*/
 
 // 1D decomp
 static AcResult