block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp Source File

block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp Source File

Composable Kernel: block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp Source File
block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
12 : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
13 /* AsyncCopy = */ false,
14 /* NumPrefetchK = */ -1,
15 /* NumPrefetchV = */ 2>
16{
17 static constexpr index_t NumPrefetchV = 2;
18
19 template <typename Problem>
21 {
22 return Problem::BlockFmhaShape::kM0 <= 64;
23 };
24
// Number of K LDS buffers used by this policy; always 2, independent of Problem.
// MakeKLdsBlockDescriptor reads this value as its NumKLdsBuffers constant.
// NOTE(review): this is CK_TILE_DEVICE, while MakeKLdsBlockDescriptor (which calls it)
// and the sibling GetNumVLdsBuffers are listed as CK_TILE_HOST_DEVICE - confirm the
// device-only qualifier is intentional.
25 template <typename Problem>
26 CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers()
27 {
28 return 2;
29 }
30
31 template <typename Problem>
32 CK_TILE_DEVICE static constexpr auto GetNumPrefetchV()
33 {
35
36 constexpr index_t kN0 = BlockFmhaShape::kN0;
37 constexpr index_t kK1 = BlockFmhaShape::kK1;
38
39 constexpr index_t k1_loops = kN0 / kK1;
40
41 return min(NumPrefetchV, k1_loops);
42 }
43
44 template <typename Problem>
46 {
47 return 2;
48 };
49
50 template <typename Problem>
52 {
54
55 return BlockGemm::template MakeABlockTileDistribution<
56 Problem::BlockFmhaShape::kM0,
57 Problem::BlockFmhaShape::kQKHeaddim>();
58 }
59
60 template <typename Problem>
61 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
62 {
64 return 8 / sizeof(KDataType);
65 }
66
67 template <typename Problem>
69 {
70 constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
71 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
72 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
73 constexpr index_t kKPack = GetSmemKPackK<Problem>();
74 constexpr index_t kKVector = GetAlignmentK<Problem>();
75
76 static_assert(kKVector % kKPack == 0);
77
78 constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
80 number<kKPerBlock / kKVector>{},
81 number<kKVector / kKPack>{},
88 number<1>{}),
90 number<1>{});
91
92 constexpr auto k_lds_block_desc = transform_tensor_descriptor(
93 k_lds_block_desc_0,
97 number<kKVector / kKPack>{},
98 number<kKPack>{}))),
101
102 return k_lds_block_desc;
103 }
104
105 template <typename Problem>
107 {
109
110 constexpr index_t kBlockSize = Problem::kBlockSize;
111 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
112 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
113
114 constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
115
116 constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
117 static_assert(0 < ElemPerThread);
118 constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
119
120 constexpr index_t KPerThread = kMaxVecLoad;
121 constexpr index_t KThreads = kKPerBlock / KPerThread;
122 constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
123 constexpr index_t NumWarps = kBlockSize / get_warp_size();
124 constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
125
133 sequence<0, 1>>{});
134 }
135
136 template <typename Problem>
138 {
140
141 constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
142
143 constexpr index_t Banks = get_n_lds_banks();
144 constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
145 constexpr index_t kKPack = GetSmemKPackV<Problem>();
146 static_assert(PixelsPerRow % kKPack == 0);
147 constexpr index_t NPerRow = PixelsPerRow / kKPack;
148 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
149 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
150 static_assert(kNPerBlock % NPerRow == 0);
151 static_assert(kKPerBlock % kKPack == 0);
152
153 constexpr index_t VSingleSmemElementSpaceSize =
154 (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
155
156 constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
158 number<kKPerBlock / kKPack>{},
159 number<kNPerBlock / NPerRow>{},
163 number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
166 number<1>{}),
168 number<1>{});
169
170 constexpr auto v_lds_block_desc = transform_tensor_descriptor(
171 v_lds_block_desc_0,
174 number<NumVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
178
179 return v_lds_block_desc;
180 }
181
182 template <typename Problem>
184 {
186
187 constexpr index_t kBlockSize = Problem::kBlockSize;
188 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
189 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
190
191 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
192 {
193 constexpr index_t N1 = GetAlignmentV<Problem>();
194 constexpr index_t N0 = kNPerBlock / N1; // P
195
196 constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
197 static_assert(ElemPerThread % N1 == 0);
198 constexpr index_t K3 = ElemPerThread / N1;
199 constexpr index_t kKPack = GetSmemKPackV<Problem>();
200 static_assert(kKPack % K3 == 0);
201 constexpr index_t K2 = kKPack / K3;
202 if constexpr(get_warp_size() % (K2 * N0) == 0)
203 {
204 constexpr index_t K1 = get_warp_size() / (K2 * N0);
205 constexpr index_t K0 = kBlockSize / get_warp_size();
206 static_assert(kKPerBlock == K0 * K1 * K2 * K3);
213 sequence<3, 1>>{});
214 }
215 else
216 {
217 constexpr index_t K1 = (K2 * N0) / get_warp_size();
218 constexpr index_t K2_m = K2 / K1;
219 constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
220 static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
227 sequence<3, 1>>{});
228 }
229 }
230 else
231 {
232 constexpr index_t K1 = GetAlignmentV<Problem>();
233 constexpr index_t K0 = kKPerBlock / K1;
234 constexpr index_t N2 = get_warp_size() / K0;
235 constexpr index_t N1 = kBlockSize / get_warp_size();
236 static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
237 static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
238 constexpr index_t N0 = kNPerBlock / (N2 * N1);
239 static_assert(N0 != 0);
240
247 sequence<0, 1>>{});
248 }
249 }
250
251 template <typename Problem>
252 CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
253 {
254 using GemmProblem =
255 BlockGemmProblem<typename Problem::QDataType,
256 typename Problem::KDataType,
257 typename Problem::SaccDataType,
258 Problem::kNumGemm0Warps * get_warp_size(),
259 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
260 Problem::BlockFmhaShape::kN0,
261 Problem::BlockFmhaShape::kK0>,
262 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
263 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
264
265 constexpr auto warp_gemm = []() {
266 if constexpr(get_warp_size() == 64 &&
267 std::is_same_v<typename Problem::QDataType, fp8_t> &&
268 std::is_same_v<typename Problem::KDataType, fp8_t> &&
269 std::is_same_v<typename Problem::SaccDataType, float>)
270 {
271 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32);
272 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}) == 32);
273 static_assert(Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}) == 32);
274
275 // TODO: hard coded here. Otherwise, it produces incorrect results
276 constexpr index_t swizzle_factor = 4;
278 swizzle_factor>{};
279 }
280 else
281 {
282 constexpr bool SwizzleA =
283 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 32;
284 return WarpGemmDispatcher<typename Problem::QDataType,
285 typename Problem::KDataType,
286 typename Problem::SaccDataType,
287 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
288 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
289 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
290 true, // TransposeC
291 SwizzleA>{};
292 }
293 }();
294
295 using BlockGemmPolicy =
296 BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QDataType,
297 typename Problem::KDataType,
298 typename Problem::SaccDataType,
299 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
300 decltype(warp_gemm)>;
301
302 if constexpr(1 < Problem::kNumGemm0Warps)
304 else
306 }
307
308 // leave some exclusive space so that the second v_lds buffer will never overlap with the first
309 // k_lds buffer
310 template <typename Problem>
312 {
313 constexpr index_t single_k_lds_buffer_size =
315 constexpr index_t single_v_lds_buffer_size =
317
318 if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size)
319 return 0;
320 else
321 return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64);
322 };
323
324 template <typename Problem>
326 {
328
329 constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
330 constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
331 constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
332
333 constexpr index_t last_v_lds_buffer_offset =
334 MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
335 ((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType);
336
337 constexpr index_t first_k_lds_buffer_size =
338 MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
339 sizeof(typename Problem::KDataType);
340
341 return GetExclusiveKLdsBytes<Problem>() + last_v_lds_buffer_offset <
342 first_k_lds_buffer_size;
343 };
344
345 template <typename Problem>
347 {
348 return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
349 sizeof(typename Problem::KDataType);
350 }
351
352 template <typename Problem>
354 {
355 return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
356 sizeof(typename Problem::VDataType);
357 }
358
359 template <typename Problem>
361 {
362 // assume V can reuse the other shared memory by K except the first
363 // assume Dropout can reuse the shared memory by V
367 }
368};
369
370} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
WarpGemmImpl< WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB< WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8< WGAttrCtlEnum::Default_ >, 2, swizzle_factor > > WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution
Definition warp_gemm.hpp:394
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:16
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:106
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeK()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:346
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t IsPreloadWholeNextIterationK()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:20
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetNumVLdsBuffers()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:45
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:325
static constexpr index_t NumPrefetchV
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:17
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeV()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:353
static CK_TILE_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:183
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:137
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:61
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsBlockDescriptor()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:68
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:252
static CK_TILE_DEVICE constexpr auto GetNumKLdsBuffers()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:26
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetExclusiveKLdsBytes()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:311
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:360
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:51
static CK_TILE_DEVICE constexpr auto GetNumPrefetchV()
Definition block_fmha_pipeline_qr_ks_vs_whole_k_prefetch_default_policy.hpp:32
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:346
static CK_TILE_HOST_DEVICE constexpr std::enable_if_t< std::is_convertible_v< decltype(Problem::kHasDropout), bool >, ck_tile::index_t > GetSmemSizeDropout(int)
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:687
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:388
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:373
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:16
Definition block_gemm_areg_bsmem_creg_v2_custom_policy.hpp:16
Definition block_gemm_areg_bsmem_creg_v2.hpp:16
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192