fused_moegemm_kernel.hpp Source File
fused_moegemm_kernel.hpp
Go to the documentation of this file.
28 // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
30 // sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
35 // currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
37 // different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
43 // the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
44 // the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
182 _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
183 _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
#define _TS_
#define _SS_
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength &up_lengths, const Indices &indices)
Definition coordinate_transform.hpp:1680
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fused_moegemm_kernel.hpp:98
const void * sorted_expert_ids_ptr
Definition fused_moegemm_kernel.hpp:110
const void * num_sorted_tiles_ptr
Definition fused_moegemm_kernel.hpp:111
const void * a_scale_ptr
Definition fused_moegemm_kernel.hpp:100
const void * g_scale_ptr
Definition fused_moegemm_kernel.hpp:103
const void * sorted_weight_ptr
Definition fused_moegemm_kernel.hpp:109
const void * d_scale_ptr
Definition fused_moegemm_kernel.hpp:104
index_t num_experts
Definition fused_moegemm_kernel.hpp:116
const void * sorted_token_ids_ptr
Definition fused_moegemm_kernel.hpp:108
index_t intermediate_size
Definition fused_moegemm_kernel.hpp:114
index_t hidden_size
Definition fused_moegemm_kernel.hpp:113
const void * y_smooth_scale_ptr
Definition fused_moegemm_kernel.hpp:105
index_t stride_token
Definition fused_moegemm_kernel.hpp:119
Definition fused_moegemm_kernel.hpp:191
index_t topk
Definition fused_moegemm_kernel.hpp:210
void * o_ptr
Definition fused_moegemm_kernel.hpp:199
const void * sorted_expert_ids_ptr
Definition fused_moegemm_kernel.hpp:203
index_t intermediate_size
Definition fused_moegemm_kernel.hpp:207
index_t hidden_size
Definition fused_moegemm_kernel.hpp:206
const void * y_smooth_scale_ptr
Definition fused_moegemm_kernel.hpp:198
const void * a_ptr
Definition fused_moegemm_kernel.hpp:192
index_t num_tokens
Definition fused_moegemm_kernel.hpp:208
const void * g_scale_ptr
Definition fused_moegemm_kernel.hpp:196
const void * d_ptr
Definition fused_moegemm_kernel.hpp:195
index_t num_experts
Definition fused_moegemm_kernel.hpp:209
const void * a_scale_ptr
Definition fused_moegemm_kernel.hpp:193
const void * sorted_weight_ptr
Definition fused_moegemm_kernel.hpp:202
const void * g_ptr
Definition fused_moegemm_kernel.hpp:194
index_t stride_token
Definition fused_moegemm_kernel.hpp:212
const void * num_sorted_tiles_ptr
Definition fused_moegemm_kernel.hpp:204
const void * d_scale_ptr
Definition fused_moegemm_kernel.hpp:197
const void * sorted_token_ids_ptr
Definition fused_moegemm_kernel.hpp:201
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:160
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:162
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:158
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:159
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:161
static constexpr const char * name
Definition fused_moegemm_kernel.hpp:163
Definition fused_moegemm_kernel.hpp:157
Definition fused_moegemm_kernel.hpp:125
static constexpr bool UseUK
Definition fused_moegemm_kernel.hpp:149
typename Pipeline::Problem::ADataType ADataType
Definition fused_moegemm_kernel.hpp:135
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fused_moegemm_kernel.hpp:238
typename Pipeline::Problem::GDataType GDataType
Definition fused_moegemm_kernel.hpp:136
typename Pipeline::Problem::Traits Traits
Definition fused_moegemm_kernel.hpp:148
static constexpr bool PadIntermediateSize
Definition fused_moegemm_kernel.hpp:154
typename Pipeline::Problem::TopkWeightDataType TopkWeightDataType
Definition fused_moegemm_kernel.hpp:144
remove_cvref_t< Partitioner_ > Partitioner
Definition fused_moegemm_kernel.hpp:126
typename Pipeline::Problem::DDataType DDataType
Definition fused_moegemm_kernel.hpp:137
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition fused_moegemm_kernel.hpp:236
remove_cvref_t< Pipeline_ > Pipeline
Definition fused_moegemm_kernel.hpp:127
static constexpr bool UseSmoothQuant
Definition fused_moegemm_kernel.hpp:152
static constexpr index_t kBlockSize
Definition fused_moegemm_kernel.hpp:133
static CK_TILE_HOST constexpr auto BlockSize()
Definition fused_moegemm_kernel.hpp:234
typename Pipeline::Problem::DScaleDataType DScaleDataType
Definition fused_moegemm_kernel.hpp:142
typename Pipeline::Problem::AScaleDataType AScaleDataType
Definition fused_moegemm_kernel.hpp:140
typename Pipeline::Problem::GScaleDataType GScaleDataType
Definition fused_moegemm_kernel.hpp:141
typename Pipeline::Problem::AccDataType AccDataType
Definition fused_moegemm_kernel.hpp:138
static CK_TILE_HOST constexpr auto GridSize(const Hargs &hargs)
Definition fused_moegemm_kernel.hpp:225
typename Pipeline::Problem::ODataType ODataType
Definition fused_moegemm_kernel.hpp:139
typename Pipeline::Problem::IndexDataType IndexDataType
Definition fused_moegemm_kernel.hpp:145
static constexpr bool IsGateOnly
Definition fused_moegemm_kernel.hpp:151
static constexpr bool PadHiddenSize
Definition fused_moegemm_kernel.hpp:153
typename Pipeline::Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition fused_moegemm_kernel.hpp:143
remove_cvref_t< Epilogue_ > Epilogue
Definition fused_moegemm_kernel.hpp:128
static CK_TILE_HOST constexpr Kargs MakeKargs(const Hargs &hargs)
Definition fused_moegemm_kernel.hpp:219
typename Pipeline::Problem::YDataType YDataType
Definition fused_moegemm_kernel.hpp:146
typename Pipeline::BlockShape BlockShape
Definition fused_moegemm_kernel.hpp:132
static CK_TILE_HOST std::string GetName()
Definition fused_moegemm_kernel.hpp:166
Definition tile/core/container/sequence.hpp:49