13#if !defined(CK_TILE_HAS_ROW_NEWBCAST)
18#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__)
19#if defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__)
21#define CK_TILE_HAS_ROW_NEWBCAST 0
25#define CK_TILE_HAS_ROW_NEWBCAST 1
29#define CK_TILE_HAS_ROW_NEWBCAST 0
// Packs a (token id, top-k slot id) pair into one 32-bit "mock" sorted-token id:
//   bits [0, 24)  : token_id_ (truncated to its low 24 bits)
//   bits [24, 32) : topk_id_  (truncated to its low 8 bits)
// Both operands are converted to uint32_t *before* masking/shifting so that the
// `<< 24` is always performed on an unsigned value. Shifting a signed int
// (e.g. topk_id_ >= 128) into the sign bit is implementation-defined in C++17
// and undefined in C++11; the resulting bit pattern is otherwise identical.
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_)                               \
    static_cast<uint32_t>((static_cast<uint32_t>(token_id_) & 0x00ffffffu) |   \
                          ((static_cast<uint32_t>(topk_id_) & 0xffu) << 24))
38#ifndef MOE_SORTING_USE_EX_KERNEL
39#define MOE_SORTING_USE_EX_KERNEL 1
42#ifndef MOE_SORTING_FUSE_MP_01
43#define MOE_SORTING_FUSE_MP_01 1
47#ifndef MOE_SORTING_FMOE_2D_BUF
48#define MOE_SORTING_FMOE_2D_BUF 1
147 int smem_cols = num_experts_ + 1;
148 int smem_rows = [&](){
151 constexpr index_t sub_unroll = 8;
152 constexpr index_t cumsum_bufs = 2;
156 int r = total_ / target_occupancy_ / smem_cols;
159 if(r < (cumsum_bufs + sub_unroll))
163 int r_for_sub_token = r - cumsum_bufs;
164 r_for_sub_token = r_for_sub_token / sub_unroll * sub_unroll;
165 int r_token_min = (tokens_ + sub_unroll - 1) / sub_unroll * sub_unroll;
166 r_for_sub_token =
min(r_for_sub_token, r_token_min);
169 if( ((r_for_sub_token + cumsum_bufs) * smem_cols * target_occupancy_ ) > total_ ) {
170 throw std::runtime_error(
"can't run this kernel, request LDS over size");
173 return r_for_sub_token + cumsum_bufs;
183 auto sub_token_ = r_ - 2;
210#if MOE_SORTING_FMOE_2D_BUF
224template <
typename Problem_>
252#if MOE_SORTING_FMOE_2D_BUF
269 hipDeviceProp_t dev_prop;
273 return dev_prop.multiProcessorCount;
280#if MOE_SORTING_FMOE_2D_BUF
291#if MOE_SORTING_USE_EX_KERNEL
302#if MOE_SORTING_USE_EX_KERNEL
304 return smem_rows * smem_cols *
sizeof(
index_t);
326#if MOE_SORTING_FMOE_2D_BUF
327 k.moe_buf_interm_dim = h.moe_buf_interm_dim;
328 k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
353 template <
typename data_t,
int wave_size>
357 constexpr int row_mask = 0xf;
358 constexpr int bank_mask = 0xf;
359 constexpr bool bound_ctrl =
true;
360 auto reduce_op = [&](
auto x_,
auto y_) {
return x_ + y_; };
362 if constexpr(wave_size > 1)
364 thread_data = reduce_op(
366 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
373 if constexpr(wave_size > 2)
375 thread_data = reduce_op(
377 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
383 if constexpr(wave_size > 4)
386 reduce_op(thread_data,
387 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
393 if constexpr(wave_size == 8) {
397 reduce_op(thread_data,
398 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
403#if CK_TILE_HAS_ROW_NEWBCAST
404 data_t xxx =__builtin_bit_cast(data_t,
405 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
411 data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
412 thread_data = thread_data - yyy;
416 int broadcast_src_lane = (__lane_id() & ~15) + 7;
417 int broadcast_addr = broadcast_src_lane << 2;
418 int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(
int, thread_data));
421 if ((__lane_id() / 8) % 2 != 0) {
422 thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
427 if constexpr(wave_size > 8)
430 reduce_op(thread_data,
431 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
438 if constexpr(wave_size > 16)
441 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(
int, thread_data));
442 v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
443 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
446 if constexpr(wave_size > 32)
449 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(
int, thread_data));
450 v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
451 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
456 template <
typename T,
typename F, index_t wave_size_ = get_warp_size()>
462 constexpr int reduce_stage = [](){
463 if constexpr(wave_size_ == 2)
return 1;
464 else if constexpr(wave_size_ == 4)
return 2;
465 else if constexpr(wave_size_ == 8)
return 3;
466 else if constexpr(wave_size_ == 16)
return 4;
467 else if constexpr(wave_size_ == 32)
return 5;
468 else if constexpr(wave_size_ == 64)
return 6;
473#pragma unroll reduce_stage
474 for(
int i_stage = 0; i_stage < reduce_stage; i_stage++)
476 int src_lane = __lane_id() ^ (1 << i_stage);
480 v_local = reduce_f(v_local, v_remote);
487 return row * total_col + col;
492 const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
493 if(
offset < buf_bytes / 16)
503 const long_index_t total_bytes = total_pixels * elem_bytes;
507 vector_type* p_buf =
reinterpret_cast<vector_type*
>(buf);
508 auto zero_ = vector_type{0};
522 index_t* p_total_tokens_post_pad,
524 const index_t tokens_per_thread,
526 const mdiv unit_size_mdiv,
527 const mdiv topk_mdiv,
531 const index_t start_idx = tid * tokens_per_thread;
535 index_t* tokens_cnts = shared_mem;
536 index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1);
538 for(
int i = 0; i < num_experts; ++i)
540 tokens_cnts[
calc_index(num_experts + 1, tid + 1, i)] = 0;
543#pragma unroll Problem_::InternalLoadUnroll
544 for(
int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
546 ++tokens_cnts[
calc_index(num_experts + 1, tid + 1, topk_id[i])];
550#if MOE_SORTING_FUSE_MP_01
551 if(tid < num_experts)
553 tokens_cnts[
calc_index(num_experts + 1, 0, tid)] = 0;
557 for(
int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
559 local_c[0] = tokens_cnts[
calc_index(num_experts + 1, i + 0, tid)];
560 local_c[1] = tokens_cnts[
calc_index(num_experts + 1, i + 1, tid)];
561 local_c[2] = tokens_cnts[
calc_index(num_experts + 1, i + 2, tid)];
562 local_c[3] = tokens_cnts[
calc_index(num_experts + 1, i + 3, tid)];
563 local_c[4] = tokens_cnts[
calc_index(num_experts + 1, i + 4, tid)];
564 local_c[5] = tokens_cnts[
calc_index(num_experts + 1, i + 5, tid)];
565 local_c[6] = tokens_cnts[
calc_index(num_experts + 1, i + 6, tid)];
566 local_c[7] = tokens_cnts[
calc_index(num_experts + 1, i + 7, tid)];
568 local_c[0] += prev_c;
569 local_c[1] += local_c[0];
570 local_c[2] += local_c[1];
571 local_c[3] += local_c[2];
572 local_c[4] += local_c[3];
573 local_c[5] += local_c[4];
574 local_c[6] += local_c[5];
575 local_c[7] += local_c[6];
578 tokens_cnts[
calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
579 tokens_cnts[
calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
580 tokens_cnts[
calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
581 tokens_cnts[
calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
582 tokens_cnts[
calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
583 tokens_cnts[
calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
584 tokens_cnts[
calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
585 tokens_cnts[
calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
592 if(tid < num_experts)
593 tokens_cnts[
calc_index(num_experts + 1, 0, tid)] = 0;
594 for(
int i = 0; i < num_experts; i += 8)
598 for(
int j = 0; j < 8; j++)
600 local_c[j] = tokens_cnts[
calc_index(num_experts + 1, tid + 1, i + j)];
604 for(
int j = 0; j < 8; j++)
610 for(
int j = 0; j < 8; j++)
612 tokens_cnts[
calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
619 if constexpr(Problem::ExpertTile == 0)
624 for(
int i = 1; i <= num_experts; ++i)
626 auto current_units = [&]() {
632 cumsum[i] = cumsum[i - 1] + current_units;
634 *p_total_tokens_post_pad = cumsum[num_experts];
641 int local_cnt = tokens_cnts[
calc_index(num_experts + 1, blockDim.x, tid)];
642 int blocks_pers_expert = unit_size_mdiv.
div(local_cnt + unit_size_mdiv.
divisor - 1);
643 int padded_tokens_per_expert =
max(blocks_pers_expert, 1) * unit_size_mdiv.
divisor;
644 int local_cumsum = padded_tokens_per_expert;
647 if(tid == (num_experts - 1))
650 *p_total_tokens_post_pad = local_cumsum;
652 if(tid < num_experts)
654 cumsum[tid + 1] = local_cumsum;
659 if(tid < num_experts)
661 int e_start = cumsum[tid];
662 int e_end = cumsum[tid + 1];
663 for(
int i = e_start; i < e_end; i += unit_size_mdiv.
divisor)
665 p_sorted_expert_ids[unit_size_mdiv.
div(i)] = tid;
669#pragma unroll Problem_::InternalLoadUnroll
670 for(
int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
672 index_t expert_id = topk_id[i];
674 index_t rank_post_pad = local_cnt + cumsum[expert_id];
675#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
676 uint32_t curr_token_id, curr_topk_id;
677 topk_mdiv.
divmod(i, curr_token_id, curr_topk_id);
680 p_sorted_token_ids[rank_post_pad] = topk_mdiv.
div(i);
682 p_sorted_weights[rank_post_pad] = weights[i];
683 tokens_cnts[
calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
686 if constexpr(Problem::ExpertTile == 0)
688 const index_t prefill_token = topk_mdiv.
div(numel);
689 if(tid < num_experts)
692 cumsum[tid] + tokens_cnts[
calc_index(num_experts + 1, blockDim.x, tid)];
693 index_t expert_end = cumsum[tid + 1];
694 while(expert_offset < expert_end)
696#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
697 p_sorted_token_ids[expert_offset] =
700 p_sorted_token_ids[expert_offset] = prefill_token;
702 p_sorted_weights[expert_offset] =
static_cast<WeightType>(0.0);
709 const index_t prefill_token = topk_mdiv.
div(numel);
713 index_t eid = tid / experts_per_wave;
714 index_t expert_offset = cumsum[eid] +
715 tokens_cnts[
calc_index(num_experts + 1, blockDim.x, eid)] +
716 tid % experts_per_wave;
717 index_t expert_end = cumsum[eid + 1];
718 if(eid < num_experts)
720 while(expert_offset < expert_end)
722#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
723 p_sorted_token_ids[expert_offset] =
726 p_sorted_token_ids[expert_offset] = prefill_token;
728 p_sorted_weights[expert_offset] =
static_cast<WeightType>(0.0);
729 expert_offset += experts_per_wave;
765 const IndexType* __restrict__ local_expert_mask,
769 index_t* p_total_tokens_post_pad,
772 const mdiv unit_size_mdiv,
773 const mdiv topk_mdiv,
774 const mdiv expert_mdiv,
780 const index_t lid = __lane_id();
781 constexpr index_t block_size = 256;
782 const index_t sub_tokens = smem_rows - 2;
784 auto f_sum = [](
auto x_,
auto y_) {
return x_ + y_; };
786 const index_t smem_cols = num_experts + 1;
794 for(
int i = tid; i < (sub_tokens * num_experts); i += block_size)
796 uint32_t curr_token_id, curr_expert_id;
797 expert_mdiv.
divmod(i, curr_token_id, curr_expert_id);
798 smem_tokens(curr_token_id, curr_expert_id) = 0;
802 for(
int i_token = 0; i_token < tokens; i_token += sub_tokens)
805 for(
int i = tid; i < (sub_tokens * topk); i += block_size)
807 uint32_t curr_token_id, curr_topk_id;
808 topk_mdiv.
divmod(i, curr_token_id, curr_topk_id);
809 int i_t = i_token + curr_token_id;
813 int eid = topk_id[i_t * topk + curr_topk_id];
815 if constexpr(Problem::SubTokenOneShot)
816 smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
818 smem_tokens(curr_token_id, eid)++;
833 constexpr int lane_group_sz = 8;
834 int lane_group_id = tid / lane_group_sz;
835 int lane_group_os = tid % lane_group_sz;
836 constexpr int lane_group_nm = block_size / lane_group_sz;
838 for(
int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
840 index_t local_c[Problem::SubTokenTile];
843 for(
int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
845#pragma unroll Problem::SubTokenTile
846 for(
int j = 0; j < Problem::SubTokenTile; j++)
848 local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
849 if constexpr(Problem::SubTokenOneShot)
851 local_c[j] = local_c[j] != 0 ? 1 : 0;
855#pragma unroll Problem::SubTokenTile
856 for(
int j = 0; j < Problem::SubTokenTile; j++)
861 if(lane_group_os == 0)
862 smem_cumsum(i_e + 1) = cnt;
866 if constexpr(Problem::LocalExpertMasking)
869 for(
int i_e = tid; i_e < num_experts; i_e += block_size)
872 smem_cumdup(i_e + 1) = local_expert_mask[i_e];
883 int local_cumsum_ = 0;
886 int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
887 int local_cnt = smem_cumsum(i_e_ + lid + 1);
888 int blocks_pers_expert =
889 unit_size_mdiv.
div(local_cnt + unit_size_mdiv.
divisor - 1);
891 int pre_cumsum_masking = [&]() {
892 if constexpr(Problem::LocalExpertMasking)
893 return smem_cumdup(lid == 0 ? i_e_ : 0);
897 int local_masking = [&]() {
898 if constexpr(Problem::LocalExpertMasking)
899 return smem_cumdup(i_e_ + lid + 1);
903 int padded_tokens_per_expert = [&]() {
905 if constexpr(Problem::SkipExpertsWithZeroTokens)
909 return blocks_pers_expert * unit_size_mdiv.
divisor;
913 return max(blocks_pers_expert, 1) * unit_size_mdiv.
divisor;
916 if constexpr(Problem::LocalExpertMasking)
918 return local_masking ? x_ : 0;
924 local_cumsum_ = padded_tokens_per_expert;
925 local_cumsum_ += pre_cumsum_;
931 if((i_e_ + lid) < num_experts)
932 smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
934 if constexpr(Problem::LocalExpertMasking)
936 local_masking += pre_cumsum_masking;
938 if((i_e_ + lid) < num_experts)
939 smem_cumdup(i_e_ + lid + 1) = local_masking;
949 *p_total_tokens_post_pad = local_cumsum_;
950 p_total_tokens_post_pad[1] = tokens;
956 for(
int i_e = tid; i_e < num_experts; i_e += block_size)
958 int e_start = smem_cumsum(i_e);
959 int e_end = smem_cumsum(i_e + 1);
961 int expert_id = [&]() {
962 if constexpr(Problem::LocalExpertMasking)
965 return smem_cumdup(i_e);
971 smem_cumdup(i_e) = e_start;
972 if constexpr(Problem::SkipExpertsWithZeroTokens)
978 if constexpr(Problem::LocalExpertMasking)
980 if(local_expert_mask[i_e] == 0)
984 for(
int i = e_start; i < e_end; i += unit_size_mdiv.
divisor)
986 p_sorted_expert_ids[unit_size_mdiv.
div(i)] = expert_id;
989 smem_cumdup(num_experts) = smem_cumsum(num_experts);
992 for(
int i_token = 0; i_token < tokens; i_token += sub_tokens)
994 if constexpr(!Problem::SubTokenOneShot)
997 for(
int i = tid; i < (sub_tokens * num_experts); i += block_size)
999 uint32_t curr_token_id, curr_expert_id;
1000 expert_mdiv.
divmod(i, curr_token_id, curr_expert_id);
1001 smem_tokens(curr_token_id, curr_expert_id) = 0;
1006 for(
int i = tid; i < (sub_tokens * topk); i += block_size)
1008 uint32_t curr_token_id_, curr_topk_id_;
1009 topk_mdiv.
divmod(i, curr_token_id_, curr_topk_id_);
1010 int curr_token_id =
static_cast<int>(curr_token_id_);
1011 int curr_topk_id =
static_cast<int>(curr_topk_id_);
1012 int i_t = i_token + curr_token_id;
1015 int eid = topk_id[i_t * topk + curr_topk_id];
1016 smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
1023 constexpr int lane_group_sz = 8;
1024 int lane_group_id = tid / lane_group_sz;
1025 int lane_group_os = tid % lane_group_sz;
1026 constexpr int lane_group_nm = block_size / lane_group_sz;
1027 for(
int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
1029 if constexpr(Problem::LocalExpertMasking)
1031 if(local_expert_mask[eid] == 0)
1034 int position = smem_cumsum(eid);
1035 for(
int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
1036 i_sub_token += lane_group_sz)
1038 auto x = smem_tokens(i_sub_token, eid);
1040 int local_cnt_cache = x != 0 ? 1 : 0;
1041 int local_cnt = local_cnt_cache;
1046#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
1047 p_sorted_token_ids[position + local_cnt - 1] =
1050 p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
1052 p_sorted_weights[position + local_cnt - 1] =
1053 weights[(i_token + i_sub_token) * topk + x - 1];
1056 int remote_cnt = __builtin_amdgcn_ds_bpermute(
1057 (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
1059 position += remote_cnt;
1061 smem_cumsum(eid) = position;
1068 for(
int eid = tid; eid < num_experts; eid += block_size)
1070 int e_start = smem_cumsum(eid);
1071 int e_end = smem_cumdup(eid + 1);
1072 if constexpr(Problem::SkipExpertsWithZeroTokens)
1074 if(e_start == e_end)
1077 while(e_start < e_end)
1079#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
1082 p_sorted_token_ids[e_start] = tokens;
1084 p_sorted_weights[e_start] =
static_cast<WeightType>(0.0);
1093 if constexpr(Problem::LocalToken)
1107#if MOE_SORTING_FMOE_2D_BUF
1109 kargs.
p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes);
1118 extern __shared__
char smem[];
1120#if MOE_SORTING_USE_EX_KERNEL
1162 return (tokens + chunk - 1) / chunk * chunk;
1187 index_t elem = num_experts * row_size;
1194 index_t row_size = num_experts + 1;
1195 return (row_size + chunk - 1) / chunk * chunk *
sizeof(
index_t);
1201 return chunk *
sizeof(
index_t);
1204template <
typename T,
typename F, index_t wave_size_ = get_warp_size()>
1210 constexpr int reduce_stage = [](){
1211 if constexpr(wave_size_ == 2)
return 1;
1212 else if constexpr(wave_size_ == 4)
return 2;
1213 else if constexpr(wave_size_ == 8)
return 3;
1214 else if constexpr(wave_size_ == 16)
return 4;
1215 else if constexpr(wave_size_ == 32)
return 5;
1216 else if constexpr(wave_size_ == 64)
return 6;
1221#pragma unroll reduce_stage
1222 for(
int i_stage = 0; i_stage < reduce_stage; i_stage++)
1224 int src_lane = __lane_id() ^ (1 << i_stage);
1228 v_local = reduce_f(v_local, v_remote);
1235template <
typename data_t,
int wave_size>
1239 constexpr int row_mask = 0xf;
1240 constexpr int bank_mask = 0xf;
1241 constexpr bool bound_ctrl =
true;
1242 auto reduce_op = [&](
auto x_,
auto y_) {
return x_ + y_; };
1244 if constexpr(wave_size > 1)
1246 thread_data = reduce_op(
1248 __builtin_bit_cast(data_t,
1249 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
1256 if constexpr(wave_size > 2)
1258 thread_data = reduce_op(
1260 __builtin_bit_cast(data_t,
1261 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
1267 if constexpr(wave_size > 4)
1269 thread_data = reduce_op(
1271 __builtin_bit_cast(data_t,
1272 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
1278 if constexpr(wave_size == 8)
1282 thread_data = reduce_op(
1284 __builtin_bit_cast(data_t,
1285 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
1290#if CK_TILE_HAS_ROW_NEWBCAST
1292 __builtin_bit_cast(data_t,
1293 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
1299 data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
1300 thread_data = thread_data - yyy;
1304 int broadcast_src_lane = (__lane_id() & ~15) + 7;
1305 int broadcast_addr = broadcast_src_lane << 2;
1307 __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(
int, thread_data));
1310 if((__lane_id() / 8) % 2 != 0)
1312 thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
1316 if constexpr(wave_size > 8)
1318 thread_data = reduce_op(
1320 __builtin_bit_cast(data_t,
1321 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(
int, thread_data),
1328 if constexpr(wave_size > 16)
1331 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2,
1332 __builtin_bit_cast(
int, thread_data));
1333 v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
1334 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
1337 if constexpr(wave_size > 32)
1340 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2,
1341 __builtin_bit_cast(
int, thread_data));
1342 v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
1343 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
1347template <index_t kBlockSize = 256>
1352 if(
offset < buf_bytes / 16)
1358template <index_t kBlockSize = 256>
1363 const long_index_t total_bytes = total_pixels * elem_bytes;
1367 vector_type* p_buf =
reinterpret_cast<vector_type*
>(buf);
1368 auto zero_ = vector_type{0};
1370 for(
long_index_t i = gid * kBlockSize + threadIdx.x; i < total_elems; i += blocks * kBlockSize)
1382#if CK_TILE_WA_ISSUE_2028
1383 if(tokens_ >= 65536 * 2)
1389 bool is_sub_token_onshot = tokens_ <= sub_token_;
1390 return is_sub_token_onshot;
1398#if MOE_SORTING_FUSE_MP_01
1412 int dispatch_policy_)
1416 if(dispatch_policy_ == 0)
1427 else if(dispatch_policy_ == 1)
1440template <
typename Problem_>
1463 hipDeviceProp_t dev_prop;
1467 return dev_prop.multiProcessorCount;
1494 if constexpr(Problem::LocalToken)
1505 if constexpr(Problem::LocalToken)
1515 index_t row_size = mesh_stride;
1518 index_t total_elems = total_bytes / 16;
1521 vector_type* p_expert_mesh =
reinterpret_cast<vector_type*
>(kargs.
p_expert_mesh);
1522 auto zero_ = vector_type{0};
1527 p_expert_mesh[i] = zero_;
1576template <
typename Problem_>
1607 hipDeviceProp_t dev_prop;
1611 return dev_prop.multiProcessorCount;
1640 const topk_id_t* p_topk_ids =
reinterpret_cast<const topk_id_t*
>(kargs.
p_topk_ids);
1643 if constexpr(Problem::LocalToken)
1652 index_t rounded_tokens = [&]() {
1653 if constexpr(Problem::LocalToken)
1655 return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
1656 Problem::SubTokenTile;
1662 if constexpr(Problem::LocalToken)
1673#pragma unroll Problem::SubTokenTile
1677 auto x = p_topk_ids[i];
1680 uint32_t curr_token_id, curr_topk_id;
1681 kargs.
topk_mdiv.
divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
1684 if constexpr(Problem::LocalToken)
1686 if(
static_cast<index_t>(curr_token_id) < tokens)
1687 p_expert_mesh[eid * mesh_stride + curr_token_id] =
1688 (curr_topk_id + 1) & 0xffff;
1691 p_expert_mesh[eid * mesh_stride + curr_token_id] =
1692 (curr_topk_id + 1) & 0xffff;
1698template <
typename Problem_>
1731 hipDeviceProp_t dev_prop;
1735 return dev_prop.multiProcessorCount;
1747 reinterpret_cast<char*
>(h.
p_ws) +
1770 constexpr index_t index_pack = Problem::SubTokenTile;
1773 const int eid = blockIdx.x;
1774 const topk_id_t* p_topk_ids =
reinterpret_cast<const topk_id_t*
>(kargs.
p_topk_ids);
1780 const index_t tokens = [&]() {
1781 if constexpr(Problem::LocalToken)
1790 index_t rounded_tokens = [&]() {
1791 if constexpr(Problem::LocalToken)
1793 return (tokens + index_pack - 1) / index_pack * index_pack;
1799 if constexpr(Problem::LocalToken)
1810 if constexpr(Problem::LocalExpertMasking)
1812 mask = p_local_expert_mask[eid];
1818 p_expert_mesh[i] = 0;
1824#pragma unroll index_pack
1827 auto x = p_topk_ids[i];
1832 uint32_t curr_token_id, curr_topk_id;
1833 kargs.
topk_mdiv.
divmod(i * index_pack + j, curr_token_id, curr_topk_id);
1834 if constexpr(Problem::LocalToken)
1836 if(
static_cast<index_t>(curr_token_id) < tokens)
1837 p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
1840 p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
1849 auto f_sum = [](
auto x_,
auto y_) {
return x_ + y_; };
1850 const r_t* p_expert_mesh_r =
reinterpret_cast<r_t*
>(p_expert_mesh);
1854 if(Problem::LocalToken && mask == 0)
1857 for(
int i = 0; i < loops; i++)
1861 if(position < (mesh_stride / index_pack))
1862 v = p_expert_mesh_r[position];
1865 [&](
auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
1877 if(threadIdx.x == 0)
1884 p_expert_cumsum[eid] = c;
1891template <
typename Problem_>
1922 reinterpret_cast<char*
>(h.
p_ws) +
1943 int eid = blockIdx.x;
1944 constexpr index_t index_pack = Problem::SubTokenTile;
1951 auto f_sum = [](
auto x_,
auto y_) {
return x_ + y_; };
1954 if constexpr(Problem::LocalToken)
1965 if constexpr(Problem::LocalToken)
1975 r_t* p_expert_mesh =
reinterpret_cast<r_t*
>(
1980 if constexpr(Problem::LocalExpertMasking)
1982 IndexType mask = p_local_expert_mask[eid];
1988 for(
int i = 0; i < loops; i++)
1992 if(position < (mesh_stride / index_pack))
1993 v = p_expert_mesh[position];
1996 [&](
auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
2011 if(threadIdx.x == 0)
2018 p_expert_cumsum[eid] = c;
2023#if MOE_SORTING_FUSE_MP_01
2024template <
typename Problem_>
2025struct MoeSortingMultiPhaseKernel_P01
2027 using Problem = remove_cvref_t<Problem_>;
2029 using IndexType =
typename Problem::IndexType;
2030 using WeightType =
typename Problem::WeightType;
2031 using MeshType =
typename Problem::MeshType;
2033 static constexpr index_t kBlockSize = 256;
2034 static constexpr index_t OCCUPANCY = 2;
2036 typedef MoeSortingHostArgs MoeSortingKargs;
2038 using Hargs = MoeSortingHostArgs;
2042 const void* p_topk_ids;
2043 const void* p_local_expert_mask;
2044 const void* p_local_tokens;
2045 void* p_expert_mesh;
2046 void* p_expert_cumsum;
2049 index_t num_experts;
2050 index_t mesh_stride;
2058 hipDeviceProp_t dev_prop;
2062 return dev_prop.multiProcessorCount;
2067 CK_TILE_HOST static constexpr auto MakeKargs(
const Hargs& h)
2070 k.p_topk_ids = h.p_topk_ids;
2071 k.p_local_expert_mask = h.p_local_expert_mask;
2072 k.p_local_tokens = h.p_local_tokens;
2073 k.p_expert_mesh = h.p_ws;
2074 k.p_expert_cumsum =
reinterpret_cast<void*
>(
2075 reinterpret_cast<char*
>(h.p_ws) +
2076 impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
2077 k.p_expert_sem =
reinterpret_cast<void*
>(
2078 reinterpret_cast<char*
>(h.p_ws) +
2079 impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk) +
2080 impl::moe_sorting_mp_cumsum_smem_size(h.num_experts));
2081 k.tokens = h.tokens;
2082 k.num_experts = h.num_experts;
2083 k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
2084 k.wg_count = [&]() {
2085 if constexpr(Problem::LocalToken)
2094 k.topk_mdiv = mdiv{
static_cast<uint32_t>(h.topk)};
2098 CK_TILE_HOST static constexpr auto GridSize(
const Hargs&) {
return get_num_cu() * OCCUPANCY; }
2100 CK_TILE_HOST static constexpr auto BlockSize(
const Hargs&) {
return dim3(kBlockSize); }
2102 CK_TILE_HOST static constexpr auto WGCounts(
const Hargs& h)
2104 index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile;
2105 index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
2108 return min(elem_cnt, GridSize(h));
2119 workgroup_barrier wb{
reinterpret_cast<uint32_t*
>(kargs.p_expert_sem)};
2121 if constexpr(Problem::LocalToken)
2123 return reinterpret_cast<const index_t*
>(kargs.p_local_tokens)[0];
2127 return kargs.tokens;
2130 index_t rounded_tokens = [&]() {
2131 if constexpr(Problem::LocalToken)
2133 return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
2134 Problem::SubTokenTile;
2140 if constexpr(Problem::LocalToken)
2142 index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
2143 index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
2146 return min(elem_cnt, kargs.wg_count);
2150 return kargs.wg_count;
2155 using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
2157 const topk_id_t* p_topk_ids =
reinterpret_cast<const topk_id_t*
>(kargs.p_topk_ids);
2158 IndexType* p_expert_mesh =
reinterpret_cast<IndexType*
>(kargs.p_expert_mesh);
2159 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
2161#pragma unroll Problem::SubTokenTile
2162 for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
2163 i += kBlockSize * gridDim.x)
2165 auto x = p_topk_ids[i];
2166 static_for<0, Problem::SubTokenTile, 1>{}([&](
auto j) {
2167 IndexType eid = x[j.value];
2168 uint32_t curr_token_id, curr_topk_id;
2169 kargs.topk_mdiv.divmod(
2170 i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
2172 if constexpr(Problem::LocalToken)
2174 if(
static_cast<index_t>(curr_token_id) < tokens)
2175 p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
2176 (curr_topk_id + 1) & 0xffff;
2179 p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
2180 (curr_topk_id + 1) & 0xffff;
2183 if(
static_cast<index_t>(blockIdx.x) < wg_count)
2190 __shared__
char smem[GetSmemSize()];
2191 int eid = blockIdx.x;
2194 if(eid >= kargs.num_experts)
2197 wb.wait_lt(wg_count);
2199 for(; eid < kargs.num_experts; eid += gridDim.x)
2207 constexpr index_t index_pack = 4;
2208 using r_t = ext_vector_t<IndexType, index_pack>;
2209 r_t* p_expert_mesh =
reinterpret_cast<r_t*
>(
2210 reinterpret_cast<index_t*
>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
2212 const IndexType* p_local_expert_mask =
2213 static_cast<const IndexType*
>(kargs.p_local_expert_mask);
2214 IndexType* p_expert_cumsum =
reinterpret_cast<IndexType*
>(kargs.p_expert_cumsum);
2216 auto f_sum = [](
auto x_,
auto y_) {
return x_ + y_; };
2218 int loops = (kargs.mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
2220 if constexpr(Problem::LocalExpertMasking)
2222 IndexType mask = p_local_expert_mask[eid];
2228 for(
int i = 0; i < loops; i++)
2230 int position = i * kBlockSize + threadIdx.x;
2232 if(position < (kargs.mesh_stride / index_pack))
2233 v = p_expert_mesh[position];
2235 static_for<0, index_pack, 1>{}(
2236 [&](
auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
2237 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
2244 IndexType* s =
reinterpret_cast<IndexType*
>(smem);
2252 if(threadIdx.x == 0)
2259 p_expert_cumsum[eid] = c;
2268template <
typename Problem_>
2304 k.p_expert_cumsum =
reinterpret_cast<void*
>(
2305 reinterpret_cast<char*
>(h.
p_ws) +
2317#if MOE_SORTING_FMOE_2D_BUF
2318 k.moe_buf_interm_dim = h.moe_buf_interm_dim;
2319 k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
2330 hipDeviceProp_t dev_prop;
2334 return dev_prop.multiProcessorCount;
2341#if MOE_SORTING_FMOE_2D_BUF
2363#if MOE_SORTING_FMOE_2D_BUF
2366 kargs.moe_buf_interm_dim,
2367 kargs.moe_buf_elem_bytes,
2373 reinterpret_cast<uint8x16_t*
>(kargs.p_moe_buf),
2374 kargs.moe_buf_bytes,
2383 static_cast<const IndexType*
>(kargs.p_local_expert_mask);
2386 reinterpret_cast<IndexType*
>(kargs.p_total_tokens_post_pad);
2387 IndexType* p_sorted_expert_ids =
reinterpret_cast<IndexType*
>(kargs.p_sorted_expert_ids);
2396 for(
index_t i = 0; i < loops; i++)
2401 if(position < kargs.num_experts)
2403 a_ = p_expert_cumsum[position];
2404 if constexpr(Problem::LocalExpertMasking)
2405 b_ = p_local_expert_mask[position];
2408 int blocks_pers_expert =
2409 kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1);
2411 int padded_blocks_per_expert = [&]() {
2413 if constexpr(Problem::SkipExpertsWithZeroTokens)
2417 return blocks_pers_expert;
2421 return max(blocks_pers_expert, 1);
2424 if constexpr(Problem::LocalExpertMasking)
2432 IndexType cumsum_a = padded_blocks_per_expert;
2442 s[4 + wave_id] = cumsum_a;
2452 prev_a = wave_id > i_w ? prev_a : 0;
2453 prev_b = wave_id > i_w ? prev_b : 0;
2459 cumsum_a += prev_cumsum_a;
2460 cumsum_b += prev_cumsum_b;
2468 IndexType out_0 = cumsum_a - padded_blocks_per_expert;
2472 prev_cumsum_a = s[2];
2473 prev_cumsum_b = s[3];
2475 if(position < kargs.num_experts)
2477 p_expert_cumsum[position] = out_0 * kargs.unit_size_mdiv.divisor;
2481 if constexpr(Problem::LocalExpertMasking)
2485 for(
int j = 0; j < blocks_pers_expert; j++)
2487 p_sorted_expert_ids[out_0 + j] = out_1;
2493 for(
int j = 0; j < blocks_pers_expert; j++)
2495 p_sorted_expert_ids[out_0 + j] = position;
2501 if(threadIdx.x == 0)
2503 auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
2504 p_total_tokens_post_pad[0] = total_tokens_post_pad;
2505 p_expert_cumsum[kargs.num_experts] = total_tokens_post_pad;
2510template <
typename Problem_>
2550 k.p_expert_mesh = h.
p_ws;
2551 k.p_expert_cumsum =
reinterpret_cast<void*
>(
2552 reinterpret_cast<char*
>(h.
p_ws) +
2576 static_cast<const IndexType*
>(kargs.p_local_expert_mask);
2579 IndexType* p_sorted_token_ids =
reinterpret_cast<IndexType*
>(kargs.p_sorted_token_ids);
2585 if constexpr(Problem::LocalToken)
2587 return reinterpret_cast<const index_t*
>(kargs.p_local_tokens)[0];
2591 return kargs.tokens;
2594 int eid = blockIdx.x;
2597 int e_start = p_expert_cumsum[eid];
2598 int e_end = p_expert_cumsum[eid + 1];
2599 if constexpr(Problem::SkipExpertsWithZeroTokens)
2601 if(e_start == e_end)
2605 if constexpr(Problem::LocalExpertMasking)
2607 int e_mask = p_local_expert_mask[eid];
2614 int prev_cumsum = 0;
2615 for(
int i = 0; i < loops; i++)
2619 if(i_token < tokens)
2621 x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
2624 int i_show = x != 0 ? 1 : 0;
2625 int cumsum = i_show;
2631 s[4 + wave_id] = cumsum;
2638 prev = wave_id > i_w ? prev : 0;
2641 cumsum += prev_cumsum;
2648 int position = cumsum - i_show;
2653#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
2656 p_sorted_token_ids[e_start + position] = i_token;
2658 p_sorted_weights[e_start + position] =
2659 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
2663 for(
index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i +=
kBlockSize)
2665#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
2668 p_sorted_token_ids[i] = tokens;
2670 p_sorted_weights[i] =
static_cast<WeightType>(0.0);
2679 constexpr index_t kBlockSize = 256;
2680 const index_t expert_cumsum_elem = num_experts_ + 1;
2681 return (4 + 2 * kBlockSize /
get_warp_size() + expert_cumsum_elem) *
sizeof(int);
2686template <
typename Problem_>
2720#if MOE_SORTING_FMOE_2D_BUF
2739 k.p_expert_mesh = h.
p_ws;
2740 k.p_expert_cumsum =
reinterpret_cast<void*
>(
2741 reinterpret_cast<char*
>(h.
p_ws) +
2757#if MOE_SORTING_FMOE_2D_BUF
2758 k.moe_buf_interm_dim = h.moe_buf_interm_dim;
2759 k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
2770 hipDeviceProp_t dev_prop;
2774 return dev_prop.multiProcessorCount;
2781#if MOE_SORTING_FMOE_2D_BUF
2797 return max(smem_23, smem_sf);
2804 if constexpr(Problem::LocalToken)
2806 return reinterpret_cast<const index_t*
>(kargs.p_local_tokens)[0];
2810 return kargs.tokens;
2814 if(
static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
2816#if MOE_SORTING_FMOE_2D_BUF
2819 kargs.moe_buf_interm_dim,
2820 kargs.moe_buf_elem_bytes,
2821 blockIdx.x - kargs.num_experts,
2822 gridDim.x - kargs.num_experts);
2826 reinterpret_cast<uint8x16_t*
>(kargs.p_moe_buf),
2827 kargs.moe_buf_bytes,
2828 blockIdx.x - kargs.num_experts);
2833 extern __shared__
char smem[];
2838 static_cast<const IndexType*
>(kargs.p_local_expert_mask);
2842 reinterpret_cast<IndexType*
>(kargs.p_total_tokens_post_pad);
2844 reinterpret_cast<IndexType*
>(kargs.p_sorted_expert_ids);
2853 for(
index_t i = 0; i < loops; i++)
2858 if(position < kargs.num_experts)
2860 a_ = p_expert_cumsum[position];
2861 if constexpr(Problem::LocalExpertMasking)
2862 b_ = p_local_expert_mask[position];
2865 int blocks_pers_expert =
2866 kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1);
2868 int padded_blocks_per_expert = [&]() {
2870 if constexpr(Problem::SkipExpertsWithZeroTokens)
2874 return blocks_pers_expert;
2878 return max(blocks_pers_expert, 1);
2881 if constexpr(Problem::LocalExpertMasking)
2889 IndexType cumsum_a = padded_blocks_per_expert;
2899 s[4 + wave_id] = cumsum_a;
2909 prev_a = wave_id > i_w ? prev_a : 0;
2910 prev_b = wave_id > i_w ? prev_b : 0;
2916 cumsum_a += prev_cumsum_a;
2917 cumsum_b += prev_cumsum_b;
2925 IndexType out_0 = cumsum_a - padded_blocks_per_expert;
2929 prev_cumsum_a = s[2];
2930 prev_cumsum_b = s[3];
2932 if(position < kargs.num_experts)
2934 p_expert_cumsum_smem[position] = out_0 * kargs.unit_size_mdiv.divisor;
2940 if constexpr(Problem::LocalExpertMasking)
2944 for(
int j = 0; j < blocks_pers_expert; j++)
2946 p_sorted_expert_ids[out_0 + j] = out_1;
2952 for(
int j = 0; j < blocks_pers_expert; j++)
2954 p_sorted_expert_ids[out_0 + j] = position;
2961 if(threadIdx.x == 0)
2963 auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
2966 p_total_tokens_post_pad[0] = total_tokens_post_pad;
2967 p_total_tokens_post_pad[1] = tokens;
2969 p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad;
2976 static_cast<const IndexType*
>(kargs.p_local_expert_mask);
2978 MeshType* p_expert_mesh =
reinterpret_cast<MeshType*
>(kargs.p_expert_mesh);
2979 IndexType* p_sorted_token_ids =
reinterpret_cast<IndexType*
>(kargs.p_sorted_token_ids);
2984 int eid = blockIdx.x;
2987 int e_start = p_expert_cumsum_smem[eid];
2988 int e_end = p_expert_cumsum_smem[eid + 1];
2989 if constexpr(Problem::SkipExpertsWithZeroTokens)
2991 if(e_start == e_end)
2995 if constexpr(Problem::LocalExpertMasking)
2997 int e_mask = p_local_expert_mask[eid];
3003 if constexpr(Problem::LocalToken)
3009 return kargs.mesh_stride;
3014 constexpr index_t index_pack = Problem::SubTokenTile;
3019 int prev_cumsum = 0;
3021 for(
int i = 0; i < loops; i++)
3023 int i_token_pack = i *
kBlockSize + threadIdx.x;
3025 if(i_token_pack < (tokens + index_pack - 1) / index_pack)
3027 x_v =
reinterpret_cast<r_t*
>(p_expert_mesh + eid * mesh_stride)[i_token_pack];
3032 if constexpr(index_pack != 1)
3036 reinterpret_cast<r_t*
>(s)[threadIdx.x] = x_v;
3040 constexpr auto j = j_.value;
3050 for(
int j = 0; j < index_pack / 2; j++)
3055 int i_show = x != 0 ? 1 : 0;
3056 int cumsum = i_show;
3062 s[4 + wave_id] = cumsum;
3069 prev = wave_id > i_w ? prev : 0;
3072 cumsum += prev_cumsum;
3079 int position = cumsum - i_show;
3084#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3085 p_sorted_token_ids[e_start + position] =
3088 p_sorted_token_ids[e_start + position] = i_token;
3090 p_sorted_weights[e_start + position] =
3091 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
3099 int cumsum_store = 0;
3102 constexpr auto j = j_.value;
3103 i_topk[j] =
static_cast<index_t>(x_r[j] - 1);
3104 i_show[j] =
static_cast<index_t>(x_r[j] != 0 ? 1 : 0);
3105 cumsum_store += i_show[j];
3107 int cumsum = cumsum_store;
3113 s[4 + wave_id] = cumsum;
3120 prev = wave_id > i_w ? prev : 0;
3123 cumsum += prev_cumsum;
3131 int position = cumsum - cumsum_store;
3133 constexpr auto j = j_.value;
3137 i *
kBlockSize * index_pack + threadIdx.x * index_pack + j;
3141#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3142 p_sorted_token_ids[e_start + position] =
3145 p_sorted_token_ids[e_start + position] = i_token;
3147 p_sorted_weights[e_start + position] =
3148 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk[j]];
3150 position += i_show[j];
3158 int i_topk_0 = x0 - 1;
3159 int i_show_0 = x0 != 0 ? 1 : 0;
3160 int i_topk_1 = x1 - 1;
3161 int i_show_1 = x1 != 0 ? 1 : 0;
3162 int cumsum = i_show_0 + i_show_1;
3168 s[4 + wave_id] = cumsum;
3175 prev = wave_id > i_w ? prev : 0;
3178 cumsum += prev_cumsum;
3185 int position_0 = cumsum - i_show_0 - i_show_1;
3190#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3191 p_sorted_token_ids[e_start + position_0] =
3194 p_sorted_token_ids[e_start + position_0] = i_token;
3196 p_sorted_weights[e_start + position_0] =
3197 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0];
3200 int position_1 = cumsum - i_show_1;
3204#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3205 p_sorted_token_ids[e_start + position_1] =
3208 p_sorted_token_ids[e_start + position_1] = i_token + 1;
3210 p_sorted_weights[e_start + position_1] =
3211 p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1];
3218 for(
index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i +=
kBlockSize)
3220#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3223 p_sorted_token_ids[i] = tokens;
3225 p_sorted_weights[i] =
static_cast<WeightType>(0.0);
3231#undef MOE_SORTING_MOCK_ID
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_smem_size(index_t num_experts)
Definition moe_sorting_kernel.hpp:1191
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t *buf, long_index_t buf_bytes, index_t gid)
Definition moe_sorting_kernel.hpp:1348
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
Definition moe_sorting_kernel.hpp:1157
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number< wave_size_ >={})
Definition moe_sorting_kernel.hpp:1205
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(void *buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
Definition moe_sorting_kernel.hpp:1359
CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t &thread_data)
Definition moe_sorting_kernel.hpp:1236
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
Definition moe_sorting_kernel.hpp:1198
CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
Definition moe_sorting_kernel.hpp:2677
CK_TILE_HOST index_t moe_sorting_mesh_byte_size(index_t tokens_, index_t, index_t topk_)
Definition moe_sorting_kernel.hpp:1166
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_smem_size(index_t tokens, index_t num_experts, index_t topk)
Definition moe_sorting_kernel.hpp:1182
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
Definition arch.hpp:328
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
uint8_t uint8x16_t
Definition vector_type.hpp:202
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_)
Definition moe_sorting_kernel.hpp:180
CK_TILE_DEVICE void s_waitcnt()
Definition arch.hpp:241
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_, int dispatch_policy_)
Definition moe_sorting_kernel.hpp:1409
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_, int topk_)
Definition moe_sorting_kernel.hpp:1394
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition arch.hpp:121
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
Definition moe_sorting_kernel.hpp:1380
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_experts_)
Definition moe_sorting_kernel.hpp:133
int32_t index_t
Definition ck.hpp:299
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_)
Definition reference_moe_sorting.hpp:11
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
Definition moe_sorting_kernel.hpp:1450
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1456
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1452
index_t tokens
Definition moe_sorting_kernel.hpp:1453
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1451
index_t mesh_byte_size
Definition moe_sorting_kernel.hpp:1457
index_t num_experts
Definition moe_sorting_kernel.hpp:1455
Definition moe_sorting_kernel.hpp:1442
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:1460
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1447
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1444
static CK_TILE_HOST constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1489
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:1445
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1486
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1491
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1443
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1472
static CK_TILE_HOST constexpr auto GridSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1484
Definition moe_sorting_kernel.hpp:189
void * p_ws
Definition moe_sorting_kernel.hpp:203
index_t tokens
Definition moe_sorting_kernel.hpp:206
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:219
void * p_moe_buf
Definition moe_sorting_kernel.hpp:202
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:199
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:190
const void * p_weights
Definition moe_sorting_kernel.hpp:191
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:196
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:197
index_t unit_size
Definition moe_sorting_kernel.hpp:207
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:193
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:198
index_t topk
Definition moe_sorting_kernel.hpp:209
index_t num_experts
Definition moe_sorting_kernel.hpp:208
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:194
Definition moe_sorting_kernel.hpp:240
void * p_moe_buf
Definition moe_sorting_kernel.hpp:249
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:248
mdiv unit_size_mdiv
Definition moe_sorting_kernel.hpp:260
index_t smem_rows
Definition moe_sorting_kernel.hpp:259
index_t tokens
Definition moe_sorting_kernel.hpp:250
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:246
index_t num_experts
Definition moe_sorting_kernel.hpp:251
mdiv expert_mdiv
Definition moe_sorting_kernel.hpp:262
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:256
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:245
index_t tokens_per_thread
Definition moe_sorting_kernel.hpp:258
const void * p_weights
Definition moe_sorting_kernel.hpp:242
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:243
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:244
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:261
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:247
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:241
Definition moe_sorting_kernel.hpp:738
CK_TILE_DEVICE simple_smem_indexer(index_t *smem_, index_t row_stride_)
Definition moe_sorting_kernel.hpp:743
CK_TILE_DEVICE index_t & operator()(index_t i_row, index_t i_col)
Definition moe_sorting_kernel.hpp:751
index_t * smem
Definition moe_sorting_kernel.hpp:739
CK_TILE_DEVICE simple_smem_indexer(index_t *smem_)
Definition moe_sorting_kernel.hpp:757
CK_TILE_DEVICE index_t & operator()(index_t idx)
Definition moe_sorting_kernel.hpp:759
index_t row_stride
Definition moe_sorting_kernel.hpp:740
CK_TILE_DEVICE const index_t & operator()(index_t idx) const
Definition moe_sorting_kernel.hpp:758
CK_TILE_DEVICE const index_t & operator()(index_t i_row, index_t i_col) const
Definition moe_sorting_kernel.hpp:747
Definition moe_sorting_kernel.hpp:226
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:229
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:236
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(void *buf, index_t row, index_t col, index_t elem_bytes) const
Definition moe_sorting_kernel.hpp:500
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:230
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType *__restrict__ topk_id, const WeightType *__restrict__ weights, index_t *p_sorted_token_ids, WeightType *p_sorted_weights, index_t *p_sorted_expert_ids, index_t *p_total_tokens_post_pad, const index_t num_experts, const index_t tokens_per_thread, const index_t numel, const mdiv unit_size_mdiv, const mdiv topk_mdiv, void *smem) const
Definition moe_sorting_kernel.hpp:517
CK_TILE_DEVICE void moe_align_block_size_kernel_ex(const IndexType *__restrict__ topk_id, const WeightType *__restrict__ weights, const IndexType *__restrict__ local_expert_mask, index_t *p_sorted_token_ids, WeightType *p_sorted_weights, index_t *p_sorted_expert_ids, index_t *p_total_tokens_post_pad, const index_t num_experts, const index_t tokens, const mdiv unit_size_mdiv, const mdiv topk_mdiv, const mdiv expert_mdiv, const index_t smem_rows, void *smem) const
Definition moe_sorting_kernel.hpp:763
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:278
__device__ void wave_cumsum(data_t &thread_data) const
Definition moe_sorting_kernel.hpp:354
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t *buf, long_index_t buf_bytes) const
Definition moe_sorting_kernel.hpp:490
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:266
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:232
static __device__ constexpr T wave_reduce(T local, F reduce_f, number< wave_size_ >={})
Definition moe_sorting_kernel.hpp:457
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
Definition moe_sorting_kernel.hpp:485
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:227
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:237
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:312
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:289
static CK_TILE_HOST constexpr auto GetSmemSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:300
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1090
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:234
Definition moe_sorting_kernel.hpp:1593
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1596
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1595
index_t tokens
Definition moe_sorting_kernel.hpp:1597
index_t num_experts
Definition moe_sorting_kernel.hpp:1599
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1600
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:1594
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:1601
Definition moe_sorting_kernel.hpp:1578
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:1583
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:1586
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1631
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:1582
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1636
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1616
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1585
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:1581
static CK_TILE_HOST constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1634
static CK_TILE_HOST constexpr auto GridSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1629
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1590
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1579
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:1588
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:1604
Definition moe_sorting_kernel.hpp:1714
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:1723
index_t num_experts
Definition moe_sorting_kernel.hpp:1725
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1720
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1717
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:1715
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:1724
index_t tokens
Definition moe_sorting_kernel.hpp:1718
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:1721
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1716
Definition moe_sorting_kernel.hpp:1700
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:1703
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:1704
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1768
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:1728
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1763
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1701
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:1705
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1711
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1707
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1759
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:1757
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:1709
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1740
Definition moe_sorting_kernel.hpp:1907
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:1911
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1909
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1910
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:1908
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1912
Definition moe_sorting_kernel.hpp:1893
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:1897
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:1898
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1894
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1905
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1939
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:1901
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:1896
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1934
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:1929
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1900
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1915
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1931
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:1903
Definition moe_sorting_kernel.hpp:2702
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:2707
const void * p_weights
Definition moe_sorting_kernel.hpp:2703
index_t num_experts
Definition moe_sorting_kernel.hpp:2716
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:2729
index_t tokens
Definition moe_sorting_kernel.hpp:2715
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:2708
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:2705
index_t mesh_stride
Definition moe_sorting_kernel.hpp:2717
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:2709
void * p_moe_buf
Definition moe_sorting_kernel.hpp:2713
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:2706
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:2711
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:2704
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:2712
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:2719
mdiv unit_size_mdiv
Definition moe_sorting_kernel.hpp:2718
Definition moe_sorting_kernel.hpp:2688
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:2692
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:2801
static CK_TILE_HOST constexpr auto GetSmemSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2793
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:2691
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2779
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:2689
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:2700
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:2733
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:2696
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:2790
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:2693
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:2698
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:2695
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:2767
Definition moe_sorting_kernel.hpp:2284
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:2285
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:2296
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:2287
index_t num_experts
Definition moe_sorting_kernel.hpp:2293
index_t tokens
Definition moe_sorting_kernel.hpp:2292
index_t mesh_stride
Definition moe_sorting_kernel.hpp:2294
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:2289
void * p_moe_buf
Definition moe_sorting_kernel.hpp:2291
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:2288
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:2290
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:2286
mdiv unit_size_mdiv
Definition moe_sorting_kernel.hpp:2295
Definition moe_sorting_kernel.hpp:2270
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:2273
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2339
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:2271
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:2278
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:2349
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:2359
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:2327
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:2277
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:2299
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:2280
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:2275
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:2274
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:2352
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:2282
Definition moe_sorting_kernel.hpp:2527
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:2532
index_t mesh_stride
Definition moe_sorting_kernel.hpp:2538
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:2529
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:2534
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:2531
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:2530
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:2539
const void * p_weights
Definition moe_sorting_kernel.hpp:2528
index_t tokens
Definition moe_sorting_kernel.hpp:2536
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:2533
index_t num_experts
Definition moe_sorting_kernel.hpp:2537
Definition moe_sorting_kernel.hpp:2512
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:2520
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:2516
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2561
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:2519
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:2563
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:2513
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:2517
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:2524
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:2566
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:2522
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:2515
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:2571
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:2542
Definition magic_div.hpp:186
CK_TILE_HOST_DEVICE void divmod(uint32_t dividend_, uint32_t "ient_, uint32_t &remainder_) const
Definition magic_div.hpp:218
uint32_t divisor
Definition magic_div.hpp:188
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
Definition magic_div.hpp:212
Definition coordinate_transform.hpp:1392
Definition tile/core/utility/functional.hpp:43