Simplified the logic used for calculating reductions
This commit is contained in:
@@ -847,7 +847,7 @@ oob(const int& i, const int& j, const int& k)
|
|||||||
|
|
||||||
template <ReduceInitialScalFunc reduce_initial>
|
template <ReduceInitialScalFunc reduce_initial>
|
||||||
__global__ void
|
__global__ void
|
||||||
_kernel_reduce_scal(const __restrict__ AcReal* src, AcReal* dst)
|
_kernel_reduce_initial_scal(const __restrict__ AcReal* src, AcReal* dst)
|
||||||
{
|
{
|
||||||
const int i = threadIdx.x + blockIdx.x * blockDim.x;
|
const int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
const int j = threadIdx.y + blockIdx.y * blockDim.y;
|
const int j = threadIdx.y + blockIdx.y * blockDim.y;
|
||||||
@@ -867,7 +867,7 @@ _kernel_reduce_scal(const __restrict__ AcReal* src, AcReal* dst)
|
|||||||
|
|
||||||
template <ReduceInitialVecFunc reduce_initial>
|
template <ReduceInitialVecFunc reduce_initial>
|
||||||
__global__ void
|
__global__ void
|
||||||
_kernel_reduce_vec(const __restrict__ AcReal* src_a,
|
_kernel_reduce_initial_vec(const __restrict__ AcReal* src_a,
|
||||||
const __restrict__ AcReal* src_b,
|
const __restrict__ AcReal* src_b,
|
||||||
const __restrict__ AcReal* src_c, AcReal* dst)
|
const __restrict__ AcReal* src_c, AcReal* dst)
|
||||||
{
|
{
|
||||||
@@ -964,40 +964,26 @@ reduce_scal(const cudaStream_t stream,
|
|||||||
const int bpg2 = (unsigned int)ceil(AcReal(scratchpad_size) /
|
const int bpg2 = (unsigned int)ceil(AcReal(scratchpad_size) /
|
||||||
AcReal(ELEMS_PER_THREAD * BLOCK_SIZE));
|
AcReal(ELEMS_PER_THREAD * BLOCK_SIZE));
|
||||||
|
|
||||||
switch (rtype) {
|
if (rtype == RTYPE_MAX || rtype == RTYPE_MIN) {
|
||||||
case RTYPE_MAX:
|
_kernel_reduce_initial_scal<dvalue><<<bpg, tpb, 0, stream>>>(vertex_buffer, reduce_scratchpad);
|
||||||
_kernel_reduce_scal<dvalue>
|
} else if (rtype == RTYPE_RMS) {
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer, reduce_scratchpad);
|
_kernel_reduce_initial_scal<dsquared><<<bpg, tpb, 0, stream>>>(vertex_buffer, reduce_scratchpad);
|
||||||
_kernel_reduce<dmax>
|
} else if (rtype == RTYPE_RMS_EXP) {
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
_kernel_reduce_initial_scal<dexp_squared><<<bpg, tpb, 0, stream>>>(vertex_buffer, reduce_scratchpad);
|
||||||
_kernel_reduce_block<dmax>
|
} else {
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
ERROR("Unrecognized RTYPE");
|
||||||
break;
|
}
|
||||||
case RTYPE_MIN:
|
|
||||||
_kernel_reduce_scal<dvalue>
|
if (rtype == RTYPE_MAX) {
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer, reduce_scratchpad);
|
_kernel_reduce<dmax><<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
_kernel_reduce<dmin>
|
_kernel_reduce_block<dmax><<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
} else if (rtype == RTYPE_MIN) {
|
||||||
_kernel_reduce_block<dmin>
|
_kernel_reduce<dmin><<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
_kernel_reduce_block<dmin><<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
break;
|
} else if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP) {
|
||||||
case RTYPE_RMS:
|
_kernel_reduce<dsum><<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
_kernel_reduce_scal<dsquared>
|
_kernel_reduce_block<dsum><<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer, reduce_scratchpad);
|
} else {
|
||||||
_kernel_reduce<dsum>
|
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
_kernel_reduce_block<dsum>
|
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
break;
|
|
||||||
case RTYPE_RMS_EXP:
|
|
||||||
_kernel_reduce_scal<dexp_squared>
|
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer, reduce_scratchpad);
|
|
||||||
_kernel_reduce<dsum>
|
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
_kernel_reduce_block<dsum>
|
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
ERROR("Unrecognized RTYPE");
|
ERROR("Unrecognized RTYPE");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1008,10 +994,9 @@ reduce_scal(const cudaStream_t stream,
|
|||||||
|
|
||||||
AcReal
|
AcReal
|
||||||
reduce_vec(const cudaStream_t stream,
|
reduce_vec(const cudaStream_t stream,
|
||||||
const ReductionType& rtype, const int& nx, const int& ny,
|
const ReductionType& rtype, const int& nx, const int& ny, const int& nz,
|
||||||
const int& nz, const AcReal* vertex_buffer_a,
|
const AcReal* vec0, const AcReal* vec1, const AcReal* vec2,
|
||||||
const AcReal* vertex_buffer_b, const AcReal* vertex_buffer_c,
|
AcReal* reduce_scratchpad, AcReal* reduce_result)
|
||||||
AcReal* reduce_scratchpad, AcReal* reduce_result)
|
|
||||||
{
|
{
|
||||||
const dim3 tpb(32, 4, 1);
|
const dim3 tpb(32, 4, 1);
|
||||||
const dim3 bpg(int(ceil(float(nx) / tpb.x)),
|
const dim3 bpg(int(ceil(float(nx) / tpb.x)),
|
||||||
@@ -1037,44 +1022,26 @@ reduce_vec(const cudaStream_t stream,
|
|||||||
ERRCHK_ALWAYS(is_power_of_two(ny));
|
ERRCHK_ALWAYS(is_power_of_two(ny));
|
||||||
ERRCHK_ALWAYS(is_power_of_two(nz));
|
ERRCHK_ALWAYS(is_power_of_two(nz));
|
||||||
|
|
||||||
switch (rtype) {
|
if (rtype == RTYPE_MAX || rtype == RTYPE_MIN) {
|
||||||
case RTYPE_MAX:
|
_kernel_reduce_initial_vec<dlength_vec><<<bpg, tpb, 0, stream>>>(vec0, vec1, vec2, reduce_scratchpad);
|
||||||
_kernel_reduce_vec<dlength_vec>
|
} else if (rtype == RTYPE_RMS) {
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer_a, vertex_buffer_b, vertex_buffer_c,
|
_kernel_reduce_initial_vec<dsquared_vec><<<bpg, tpb, 0, stream>>>(vec0, vec1, vec2, reduce_scratchpad);
|
||||||
reduce_scratchpad);
|
} else if (rtype == RTYPE_RMS_EXP) {
|
||||||
_kernel_reduce<dmax>
|
_kernel_reduce_initial_vec<dexp_squared_vec><<<bpg, tpb, 0, stream>>>(vec0, vec1, vec2, reduce_scratchpad);
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
} else {
|
||||||
_kernel_reduce_block<dmax>
|
ERROR("Unrecognized RTYPE");
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
}
|
||||||
break;
|
|
||||||
case RTYPE_MIN:
|
if (rtype == RTYPE_MAX) {
|
||||||
_kernel_reduce_vec<dlength_vec>
|
_kernel_reduce<dmax><<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer_a, vertex_buffer_b, vertex_buffer_c,
|
_kernel_reduce_block<dmax><<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
reduce_scratchpad);
|
} else if (rtype == RTYPE_MIN) {
|
||||||
_kernel_reduce<dmin>
|
_kernel_reduce<dmin><<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
_kernel_reduce_block<dmin><<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
_kernel_reduce_block<dmin>
|
} else if (rtype == RTYPE_RMS || rtype == RTYPE_RMS_EXP) {
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
_kernel_reduce<dsum><<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
break;
|
_kernel_reduce_block<dsum><<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
||||||
case RTYPE_RMS:
|
} else {
|
||||||
_kernel_reduce_vec<dsquared_vec>
|
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer_a, vertex_buffer_b, vertex_buffer_c,
|
|
||||||
reduce_scratchpad);
|
|
||||||
_kernel_reduce<dsum>
|
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
_kernel_reduce_block<dsum>
|
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
break;
|
|
||||||
case RTYPE_RMS_EXP:
|
|
||||||
_kernel_reduce_vec<dexp_squared_vec>
|
|
||||||
<<<bpg, tpb, 0, stream>>>(vertex_buffer_a, vertex_buffer_b, vertex_buffer_c,
|
|
||||||
reduce_scratchpad);
|
|
||||||
_kernel_reduce<dsum>
|
|
||||||
<<<bpg2, BLOCK_SIZE, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
_kernel_reduce_block<dsum>
|
|
||||||
<<<1, 1, 0, stream>>>(reduce_scratchpad, reduce_result);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
ERROR("Unrecognized RTYPE");
|
ERROR("Unrecognized RTYPE");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user