diff --git a/src/core/kernels/kernels.cuh b/src/core/kernels/kernels.cuh index 9dda790..d7985c6 100644 --- a/src/core/kernels/kernels.cuh +++ b/src/core/kernels/kernels.cuh @@ -810,6 +810,8 @@ typedef AcReal (*ReduceInitialScalFunc)(const AcReal&); typedef AcReal (*ReduceInitialVecFunc)(const AcReal&, const AcReal&, const AcReal&); +typedef AcReal (*FilterFunc)(const AcReal&); + // clang-format off /* Comparison funcs */ static __device__ inline AcReal @@ -841,6 +843,83 @@ static __device__ inline AcReal dexp_squared_vec(const AcReal& a, const AcReal& b, const AcReal& c) { return dexp_squared(a) + dexp_squared(b) + dexp_squared(c); } // clang-format on +#include +template +__global__ void +kernel_filter(const __restrict__ AcReal* src, const int3 start, const int3 end, AcReal* dst) +{ + const int3 src_idx = (int3) { + start.x + threadIdx.x + blockIdx.x * blockDim.x, + start.y + threadIdx.y + blockIdx.y * blockDim.y, + start.z + threadIdx.z + blockIdx.z * blockDim.z + }; + + const int nx = end.x - start.x; + const int ny = end.y - start.y; + const int nz = end.z - start.z; + const int3 dst_idx = (int3) { + threadIdx.x + blockIdx.x * blockDim.x, + threadIdx.y + blockIdx.y * blockDim.y, + threadIdx.z + blockIdx.z * blockDim.z + }; + + assert(src_idx.x < DCONST_INT(AC_nx_max) && src_idx.y < DCONST_INT(AC_ny_max) && src_idx.z < DCONST_INT(AC_nz_max)); + assert(dst_idx.x < nx && dst_idx.y < ny && dst_idx.z < nz); + assert(dst_idx.x + dst_idx.y * nx + dst_idx.z * nx * ny < nx * ny * nz); + + dst[dst_idx.x + dst_idx.y * nx + dst_idx.z * nx * ny] = filter(src[IDX(src_idx)]); +} + +template +__global__ void +kernel_reduce(AcReal* scratchpad, const int num_elems) +{ + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + + extern __shared__ AcReal smem[]; + if (idx < num_elems) { + smem[threadIdx.x] = scratchpad[idx]; + } else { + smem[threadIdx.x] = NAN; + } + __syncthreads(); + + int offset = blockDim.x / 2; + assert(offset % 2 == 0); + while (offset > 0) { + if (threadIdx.x < offset) { + smem[threadIdx.x] = reduce(smem[threadIdx.x], smem[threadIdx.x + offset]); + } + offset /= 2; + __syncthreads(); + } + + if (threadIdx.x == 0) { + scratchpad[idx] = smem[threadIdx.x]; + } +} + +template +__global__ void +kernel_reduce_block(const __restrict__ AcReal* scratchpad, + const int num_blocks, const int block_size, + AcReal* result) +{ + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx != 0) { + return; + } + const int scratchpad_size = DCONST_INT(AC_nxyz); + AcReal res = scratchpad[0]; + for (int i = 1; i < num_blocks; ++i) { + assert(i * block_size < num_blocks * block_size); + assert(i * block_size < scratchpad_size); + res = reduce(res, scratchpad[i * block_size]); + } + *result = res; +} + + static __device__ inline bool oob(const int& i, const int& j, const int& k) { @@ -972,7 +1051,11 @@ reduce_scal(const cudaStream_t stream, AcReal(ELEMS_PER_THREAD * BLOCK_SIZE)); if (rtype == RTYPE_MAX || rtype == RTYPE_MIN) { - kernel_reduce_1of3<<>>(vtxbuf, scratchpad); + const int3 start = (int3){STENCIL_ORDER/2, STENCIL_ORDER/2, STENCIL_ORDER/2}; + const int3 end = (int3){start.x + nx, start.y + ny, start.z + nz}; + kernel_filter<<>>(vtxbuf, start, end, scratchpad); + ERRCHK_CUDA_KERNEL_ALWAYS(); + //kernel_reduce_1of3<<>>(vtxbuf, scratchpad); } else if (rtype == RTYPE_RMS) { kernel_reduce_1of3<<>>(vtxbuf, scratchpad); } else if (rtype == RTYPE_RMS_EXP) { @@ -982,8 +1065,19 @@ reduce_scal(const cudaStream_t stream, } if (rtype == RTYPE_MAX) { - kernel_reduce_2of3<<>>(scratchpad, reduce_result); - kernel_reduce_3of3<<<1, 1, 0, stream>>>(scratchpad, reduce_result); + const int3 start = (int3){STENCIL_ORDER/2, STENCIL_ORDER/2, STENCIL_ORDER/2}; + const int3 end = (int3){start.x + nx, start.y + ny, start.z + nz}; + const int num_elems = (end.x - start.x) * (end.y - start.y) * (end.z - start.z); + const int tpb = 128; + const int bpg = num_elems / tpb; + assert(num_elems % tpb == 0); + assert(tpb * bpg <= num_elems); + kernel_reduce<<>>(scratchpad, num_elems); + ERRCHK_CUDA_KERNEL_ALWAYS(); + kernel_reduce_block<<<1, 1, 0, stream>>>(scratchpad, bpg, tpb, reduce_result); + ERRCHK_CUDA_KERNEL_ALWAYS(); + // kernel_reduce_2of3<<>>(scratchpad, reduce_result); + // kernel_reduce_3of3<<<1, 1, 0, stream>>>(scratchpad, reduce_result); } else if (rtype == RTYPE_MIN) { kernel_reduce_2of3<<>>(scratchpad, reduce_result); kernel_reduce_3of3<<<1, 1, 0, stream>>>(scratchpad, reduce_result);