device_gemm_bias_add_reduce_xdl_cshuffle.hpp Source File#
device_gemm_bias_add_reduce_xdl_cshuffle.hpp
Go to the documentation of this file.
23// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Because non c-shuffle
80struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()>
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__global__ void kernel_gemm_bias_add_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC0 *__restrict__ p_bias_grid, const FloatC1 *__restrict__ p_d0_grid, ReducePtrsGlobal p_reduces_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const C1ElementwiseOperation c1_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c0_grid_desc_mblock_mperblock_nblock_nperblock, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:45
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:180
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp:279
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:455
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:506
ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle::Argument::b_grid_desc_bk0_n_bk1_
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:505
CElementwiseOperation c_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:513
CDataType * p_c_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:500
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, const BiasDataType *p_bias_grid, const D0DataType *p_d0_grid, ReducePtrsGlobal p_reduces_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideC1, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, D0ElementwiseOperation d0_element_op, ReduceInElementwiseOperations reduce_in_element_ops, ReduceAccElementwiseOperations reduce_out_element_ops)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:456
const BiasDataType * p_bias_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:501
GridwiseGemm64::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:510
AElementwiseOperation a_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:511
ReducePtrsGlobal p_reduces_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:503
const D0DataType * p_d0_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:502
ReduceAccElementwiseOperations reduce_out_element_ops_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:516
D0ElementwiseOperation d0_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:514
BElementwiseOperation b_element_op_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:512
C0GridDesc_M_N c0_grid_desc_m_n_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:507
ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle::Argument::a_grid_desc_ak0_m_ak1_
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:504
const ADataType * p_a_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:498
C1GridDesc_M_N c1_grid_desc_m_n_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:508
ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle::Argument::reduce_in_element_ops_
ReduceInElementwiseOperations reduce_in_element_ops_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:515
ReduceGridDesc_M reduce_grid_desc_m_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:509
const BDataType * p_b_grid_
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:499
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:521
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:525
DeviceOp::Argument Argument
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:522
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:662
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:81
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:195
GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, BiasDataType, D0DataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, D0ElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, C0GridDesc_M_N, C1GridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemmBase
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:391
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:387
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:669
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:451
static constexpr int NumReduce
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:710
static constexpr auto I2
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:90
static auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:357
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:859
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) C1GridDesc_M_N
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:386
static constexpr auto I0
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:88
static auto MakeInvoker()
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:782
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:705
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:298
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:675
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:382
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:92
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:450
std::string GetTypeString() const override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:865
decltype(MakeCGridDescriptor_M_N(1, 1, 0)) C0GridDesc_M_N
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:385
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:85
static constexpr auto NXdlPerWave32
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:86
DeviceGemmBiasAddReduce_Xdl_CShuffle DeviceOp
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:82
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 1 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 1 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 1 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, index_t=1) override
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:786
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:383
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:384
static constexpr auto I1
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:89
static auto MakeArgument(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 1 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 1 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 1 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op)
Definition device_gemm_bias_add_reduce_xdl_cshuffle.hpp:711
Definition device_gemm_reduce.hpp:17