// NOTE(review): this listing is a partial extract of the original header. The leading
// integers are the original file's line numbers, and several original lines are absent
// (for example the `__global__ void kernel_reduce_multiblock(` line and this kernel's braces).
//
// Template parameter list of the multiblock reduction kernel entry point. The visible body
// does no reduction work itself; it only forwards to static members of GridwiseReduction.
// OutputIndex and HaveIndexInput (used below) are compile-time constants; presumably they are
// template parameters declared on lines missing from this extract -- TODO confirm.
16template <
typename GridwiseReduction,
22 typename IndexDataType,
23 typename InGridDesc_M_K,
24 typename OutGridDesc_M,
25 typename InElementwiseOperation,
26 typename AccElementwiseOperation>
// Kernel parameters. The output descriptor and both element-wise functors are taken by value.
28 const OutGridDesc_M out_grid_desc_m,
29 const InElementwiseOperation in_elementwise_op,
30 const AccElementwiseOperation acc_elementwise_op,
// Number of K tiles each block walks through; it bounds the do/while loops in Run/RunWithIndex.
32 index_t num_k_block_tile_iteration,
// Input values and input indices: read-only, non-aliasing (const ... const __restrict__).
34 const InDataType*
const __restrict__ p_in_value_global,
35 const IndexDataType*
const __restrict__ p_in_index_global,
// Output values and output indices: writable, non-aliasing.
37 OutDataType*
const __restrict__ p_out_value_global,
38 IndexDataType*
const __restrict__ p_out_index_global)
// Compile-time dispatch: when no index output is requested, both index pointers are
// explicitly marked unused and the value-only path GridwiseReduction::Run is taken.
40 if constexpr(!OutputIndex)
42 (void)p_in_index_global;
43 (void)p_out_index_global;
45 GridwiseReduction::Run(in_grid_desc_m_k,
50 num_k_block_tile_iteration,
// Index-producing path; HaveIndexInput is forwarded as a template argument.
// Presumably this is the `else` branch of the `if constexpr` above (the `else` line is not
// part of this extract).
58 GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
62 num_k_block_tile_iteration,
// Template parameter list of the gridwise reduction struct. The struct's own name line and
// several of its template parameters (e.g. BlockSize, MThreadSliceSize, KThreadSliceSize,
// InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, which are used below) are absent from
// this extract. Presumably this is the struct that owns the Run / RunWithIndex members shown
// further down -- TODO confirm against the full header.
72template <
typename InDataType,
75 typename IndexDataType,
76 typename InGridDesc_M_K,
77 typename OutGridDesc_M,
78 typename ReduceOperation,
79 typename InElementwiseOperation,
80 typename AccElementwiseOperation,
// Compile-time configuration check:
//  - InSrcVectorDim == 0 requires MThreadSliceSize to be a multiple of InSrcVectorSize;
//  - InSrcVectorDim == 1 requires KThreadSliceSize to be a multiple of InSrcVectorSize;
//  - any other InSrcVectorDim value fails the assertion;
//  - in every case MThreadSliceSize must be a multiple of OutDstVectorSize.
93 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
94 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
95 (MThreadSliceSize % OutDstVectorSize == 0),
96 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
// Value-only reduction path (no index output); the kernel entry point calls it when
// OutputIndex is false.
// NOTE(review): the symbol index at the end of this listing gives the full signature, which
// additionally has block_group_size, alpha and beta parameters; their lines are absent here.
139 __device__
static void Run(
const InGridDesc_M_K& in_grid_desc_m_k,
140 const OutGridDesc_M& out_grid_desc_m,
141 const InElementwiseOperation& in_elementwise_op,
142 const AccElementwiseOperation& acc_elementwise_op,
144 index_t num_k_block_tile_iteration,
146 const InDataType*
const __restrict__ p_in_value_global,
148 OutDataType*
const __restrict__ p_out_value_global)
// Identity element of ReduceOperation, expressed in the accumulation type.
150 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
// Block-shared scratch array holding BlockSize AccDataType elements.
153 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
// Trailing arguments of the global-buffer constructions (their opening lines are absent).
// The input buffer is additionally given the InDataType identity value -- presumably the
// value substituted for invalid/out-of-range reads; TODO confirm.
157 in_grid_desc_m_k.GetElementSpaceSize(),
158 ReduceOperation::template GetIdentityValue<InDataType>());
160 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
// Presumably a buffer view over p_reduce_work_buffer (right-hand side absent).
162 auto reduce_work_buf =
// Split the flat block id into a block-group id and a position inside that group:
// block_group_size consecutive block ids share one blkgroup_id.
174 const index_t blkgroup_id = block_global_id / block_group_size;
175 const index_t block_local_id = block_global_id % block_group_size;
// 2-D position of this thread in the thread cluster: component I0 is the M coordinate,
// component I1 is the K coordinate.
177 const auto thread_cluster_idx =
180 const auto thread_m_cluster_id = thread_cluster_idx[
I0];
181 const auto thread_k_cluster_id = thread_cluster_idx[
I1];
// Fragment of the threadwise source-load declaration. The visible K component of its origin
// is block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize.
192 decltype(thread_buffer_desc),
201 block_local_id * reduceSizePerBlock +
202 thread_k_cluster_id * KThreadSliceSize));
// Tile loop (do/while; the `do` line is absent): load one tile, process it, then advance
// the source window by in_thread_copy_step, until reducedTiles reaches
// num_k_block_tile_iteration.
209 threadwise_src_load.Run(in_grid_desc_m_k,
// Compile-time offset of element (iM, iK) inside the per-thread buffer.
218 constexpr auto offset = thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
226 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
229 }
while(reducedTiles < num_k_block_tile_iteration);
// Only threads with K cluster coordinate 0 post-process the accumulators: the accumulator
// element-wise op is applied in place, then the value is scaled by alpha (presumably both
// statements sit inside this guard; the braces are absent).
237 if(thread_k_cluster_id == 0)
239 acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
241 accu_value_buf(I) *= alpha;
// Again restricted to K cluster coordinate 0: a threadwise load reads through
// out_grid_desc_m (presumably the prior output used for beta blending -- TODO confirm,
// the beta lines are absent), then a threadwise store writes the reduced data using
// OutMemoryDataOperation.
245 if(thread_k_cluster_id == 0)
252 auto threadwise_dst_load =
256 decltype(reduced_data_desc),
265 thread_m_cluster_id * MThreadSliceSize));
267 threadwise_dst_load.Run(out_grid_desc_m,
278 auto threadwise_dst_store =
281 decltype(reduced_data_desc),
288 OutMemoryDataOperation,
293 thread_m_cluster_id * MThreadSliceSize),
296 threadwise_dst_store.Run(reduced_data_desc,
// Reduction path that produces both a reduced value and its index.
// HaveIndexInput selects at compile time whether indices are read from p_in_index_global
// or computed from the element's K position.
// NOTE(review): the symbol index at the end of this listing gives the full signature, which
// additionally has alpha and beta parameters (their lines are absent here) and, unlike Run,
// has no block_group_size parameter.
// NOTE(review): both element-wise functors are taken by value here, whereas Run takes them
// by const reference.
304 template <
bool HaveIndexInput>
305 __device__
static void RunWithIndex(
const InGridDesc_M_K& in_grid_desc_m_k,
306 const OutGridDesc_M& out_grid_desc_m,
307 const InElementwiseOperation in_elementwise_op,
308 const AccElementwiseOperation acc_elementwise_op,
309 index_t num_k_block_tile_iteration,
311 const InDataType*
const __restrict__ p_in_value_global,
312 const IndexDataType*
const __restrict__ p_in_index_global,
314 OutDataType*
const __restrict__ p_out_value_global,
315 IndexDataType*
const __restrict__ p_out_index_global)
// Block-level reduction type that carries an index with every value (definition absent).
317 using BlockwiseReduceWithIndex =
// in_elementwise_op is explicitly marked unused on this path.
331 (void)in_elementwise_op;
// Block-shared scratch arrays, BlockSize elements each: one for values, one for indices.
334 __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
335 __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
// Identity element of ReduceOperation, expressed in the accumulation type.
337 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
// Trailing arguments of the global-buffer constructions for input values, input indices,
// output values and output indices (their opening lines are absent). The input-value buffer
// is additionally given the InDataType identity value -- presumably the value substituted
// for invalid/out-of-range reads; TODO confirm.
341 in_grid_desc_m_k.GetElementSpaceSize(),
342 ReduceOperation::template GetIdentityValue<InDataType>());
344 p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
346 p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
348 p_out_index_global, out_grid_desc_m.GetElementSpaceSize());
// Presumably buffer views over the two shared scratch arrays (right-hand sides absent).
350 auto reduce_work_val_buf =
352 auto reduce_work_idx_buf =
360 MThreadSliceSize * KThreadSliceSize,
// 2-D position of this thread in the thread cluster: component I0 is the M coordinate,
// component I1 is the K coordinate.
370 const auto thread_cluster_idx =
373 const auto thread_m_cluster_id = thread_cluster_idx[
I0];
374 const auto thread_k_cluster_id = thread_cluster_idx[
I1];
// Threadwise loader for input values. The visible K component of its origin is
// thread_k_cluster_id * KThreadSliceSize (no block_local_id term, unlike Run).
380 auto threadwise_src_val_load =
384 decltype(thread_buffer_desc),
393 thread_m_cluster_id * MThreadSliceSize,
394 thread_k_cluster_id * KThreadSliceSize));
// Initialise the per-thread accumulators: value to the identity, index to 0.
397 accu_value_buf(I) = identityVal;
398 accu_index_buf(I) = 0;
// Branch taken when indices are supplied as input: a second threadwise loader reads them
// from the same grid positions as the values.
405 if constexpr(HaveIndexInput)
407 auto threadwise_src_idx_load =
411 decltype(thread_buffer_desc),
420 thread_m_cluster_id * MThreadSliceSize,
421 thread_k_cluster_id * KThreadSliceSize));
// Tile loop (do/while; the `do` line is absent): load values and indices for one tile.
426 threadwise_src_val_load.Run(in_grid_desc_m_k,
431 threadwise_src_idx_load.Run(in_grid_desc_m_k,
// Per-tile temporaries start from identity / index 0, are folded over the thread's
// elements with AccumulationWithIndex::Calculate, combined through
// BlockwiseReduceWithIndex::Reduce using the shared work buffers, and the result is then
// folded into the running accumulators for row iM.
438 AccDataType tmpValue = identityVal;
439 IndexDataType tmpIndex = 0;
442 constexpr auto offset =
443 thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
445 AccumulationWithIndex::Calculate(tmpValue,
451 BlockwiseReduceWithIndex::Reduce(
452 reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
454 AccumulationWithIndex::Calculate(
455 accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
// Advance both the value window and the index window to the next tile.
458 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
459 threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
462 }
while(reducedTiles < num_k_block_tile_iteration);
// Presumably the `else` branch (no index input; the `else` line is absent): only values
// are loaded, and an index expression is formed from the running indexOffset plus this
// thread's K position and the element's iK.
471 threadwise_src_val_load.Run(in_grid_desc_m_k,
479 constexpr auto offset =
480 thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
484 indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
// Same per-tile reduction sequence as in the index-input branch.
491 AccDataType tmpValue = identityVal;
492 IndexDataType tmpIndex = 0;
495 constexpr auto offset =
496 thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
498 AccumulationWithIndex::Calculate(tmpValue,
504 BlockwiseReduceWithIndex::Reduce(
505 reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
507 AccumulationWithIndex::Calculate(
508 accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
// Only the value window needs to move in this branch.
511 threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
515 }
while(reducedTiles < num_k_block_tile_iteration);
// Only threads with K cluster coordinate 0 post-process the accumulators: the accumulator
// element-wise op is applied in place, then the value is scaled by alpha (presumably both
// statements sit inside this guard; the braces are absent).
521 if(thread_k_cluster_id == 0)
524 acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
526 accu_value_buf(I) *= alpha;
// Again restricted to K cluster coordinate 0: a threadwise load reads through
// out_grid_desc_m (presumably the prior output used for beta blending -- TODO confirm,
// the beta lines are absent), then separate threadwise stores write the reduced values and
// the reduced indices.
530 if(thread_k_cluster_id == 0)
537 auto threadwise_dst_load =
541 decltype(reduced_data_desc),
550 thread_m_cluster_id * MThreadSliceSize));
552 threadwise_dst_load.Run(out_grid_desc_m,
563 auto threadwise_dst_val_store =
566 decltype(reduced_data_desc),
578 thread_m_cluster_id * MThreadSliceSize),
581 auto threadwise_dst_idx_store =
584 decltype(reduced_data_desc),
596 thread_m_cluster_id * MThreadSliceSize),
599 threadwise_dst_val_store.Run(reduced_data_desc,
604 threadwise_dst_idx_store.
Run(reduced_data_desc,
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_multiblock.hpp:27
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition gridwise_2d_reduction_multiblock.hpp:92
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_2d_reduction_multiblock.hpp:111
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_2d_reduction_multiblock.hpp:105
static constexpr bool reorder_thread_cluster
Definition gridwise_2d_reduction_multiblock.hpp:98
ThreadwiseReduction< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M, ReduceOperation, PropagateNan > ThreadwiseReduce
Definition gridwise_2d_reduction_multiblock.hpp:123
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation &in_elementwise_op, const AccElementwiseOperation &acc_elementwise_op, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition gridwise_2d_reduction_multiblock.hpp:139
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_2d_reduction_multiblock.hpp:102
static constexpr index_t M_BlockTileSize
Definition gridwise_2d_reduction_multiblock.hpp:134
static constexpr auto I0
Definition gridwise_2d_reduction_multiblock.hpp:131
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_2d_reduction_multiblock.hpp:100
static constexpr auto thread_cluster_desc
Definition gridwise_2d_reduction_multiblock.hpp:108
static constexpr auto I1
Definition gridwise_2d_reduction_multiblock.hpp:132
PartitionedBlockwiseReduction< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, ReduceOperation, PropagateNan > BlockwiseReduce
Definition gridwise_2d_reduction_multiblock.hpp:116
detail::AccumulateWithNanCheck< PropagateNan, ReduceOperation, AccDataType > Accumulation
Definition gridwise_2d_reduction_multiblock.hpp:137
static __device__ void RunWithIndex(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M &out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_multiblock.hpp:305
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_2d_reduction_multiblock.hpp:129
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_2d_reduction_multiblock.hpp:113
static constexpr index_t K_BlockTileSize
Definition gridwise_2d_reduction_multiblock.hpp:135
Definition reduction_functions_blockwise.hpp:28
static __device__ void Reduce(BufferType &work_buffer, AccDataType &in_out_value)
Definition reduction_functions_blockwise.hpp:44
Definition reduction_functions_blockwise.hpp:175
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition reduction_functions_threadwise.hpp:36
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/functional.hpp:100
Definition reduction_functions_accumulate.hpp:65
Definition reduction_functions_accumulate.hpp:28
Definition reduction_common.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340