12 template <
typename Problem>
17 return 16 /
sizeof(QDataType);
20 template <
typename Problem>
25 return 16 /
sizeof(KDataType);
28 template <
typename Problem>
33 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
35 constexpr index_t kBlockSize = Problem::kBlockSize;
36 constexpr index_t kNPerBlock = Problem::kN0;
37 constexpr index_t kKPerBlock = Problem::kN1;
38 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
41 if constexpr(total_pixels > 4)
48 return 16 /
sizeof(VDataType);
52 template <
typename Problem>
55 using DataType =
typename Problem::QDataType;
60 return 16 /
sizeof(DataType);
64 return 16 /
sizeof(DataType);
68 template <
typename Problem>
76 static_assert(Problem::kK0 % KPerThread == 0);
77 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
80 return make_tuple(start_pos, start_pos + KPerThread);
85 static_assert(Problem::kK0 % KPerThread == 0);
86 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
89 return make_tuple(start_pos, start_pos + KPerThread);
93 template <
typename Problem>
96 constexpr index_t kBlockSize = Problem::kBlockSize;
97 constexpr index_t kMPerBlock = Problem::kM0;
98 constexpr index_t kKPerBlock = Problem::kK0;
101 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
104 constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp);
116 template <
typename Problem>
119 using DataType =
typename Problem::KDataType;
124 return 16 /
sizeof(DataType);
128 return 16 /
sizeof(DataType);
132 template <
typename Problem>
140 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
143 return make_tuple(start_pos, start_pos + KPerThread);
148 constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
151 return make_tuple(start_pos, start_pos + KPerThread);
155 template <
typename Problem>
158 constexpr index_t kBlockSize = Problem::kBlockSize;
159 constexpr index_t kNPerBlock = Problem::kN0;
160 constexpr index_t kKPerBlock = Problem::kK0;
163 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
166 constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
178 template <
typename Problem>
183 return 16 /
sizeof(VDataType);
186 template <
typename Problem>
192 constexpr index_t kBlockSize = Problem::kBlockSize;
193 constexpr index_t kNPerBlock = Problem::kN1;
194 constexpr index_t kKPerBlock = Problem::kN0;
196 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
199 constexpr index_t NPerThread = 16 /
sizeof(VDataType);
200 constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
203 constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp);
216 constexpr index_t KPerThread = 16 /
sizeof(VDataType);
217 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
220 constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
233 template <
typename Problem,
bool IsRotaryCosSinForQ>
236 constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
248 template <
typename Problem,
bool IsRotaryCosSinForQ>
251 using DataType = std::conditional_t<IsRotaryCosSinForQ,
252 typename Problem::QDataType,
253 typename Problem::KDataType>;
257 constexpr index_t kBlockSize = Problem::kBlockSize;
261 constexpr index_t KPerThread = []() {
265 return 16 /
sizeof(DataType);
269 return 8 /
sizeof(DataType);
272 constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
275 constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
@ INTERLEAVED
Definition block_rotary_embedding.hpp:14
@ HALF_ROTATED
Definition block_rotary_embedding.hpp:15
@ NONE
Definition block_rotary_embedding.hpp:13
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE index_t get_thread_id()
Definition arch.hpp:117
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:11
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:21
static CK_TILE_HOST_DEVICE constexpr auto MakeKnewDramTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:156
static CK_TILE_DEVICE auto GetKnewThreadRangeAlongK()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:133
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:179
static CK_TILE_HOST_DEVICE constexpr auto GetQNumElemsPerRead()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:53
static CK_TILE_HOST_DEVICE constexpr auto MakeRotaryCosSinTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:249
static CK_TILE_HOST_DEVICE constexpr auto GetRotaryCosSinTileSize()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:234
static CK_TILE_DEVICE auto GetQThreadRangeAlongK()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:69
static CK_TILE_HOST_DEVICE constexpr auto MakeVnewDramTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:187
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto GetKnewNumElemsPerRead()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:117
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:29
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_fwd_appendkv_pipeline_default_policy.hpp:94
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192