block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp Source File

block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp Source File#

Composable Kernel: block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp Source File
block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
14template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
16{
34 using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
35
37
// Launch-related constants forwarded unchanged from Problem.
38 static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
39 static constexpr index_t kBlockSize = Problem::kBlockSize;
40
// Tile extents forwarded from BlockFmhaShape. As checked by the static_asserts in operator():
//  - kM0 is the dim-0 window length of the Q / Bias / OGrad / LSE / D / QGrad / BiasGrad windows;
//  - kN0 is the dim-0 window length of the K / V windows and the dim-1 length of the Bias windows;
//  - kM0 must equal kK1 and kK3, kK0 must not exceed kQKHeaddim, kK2 must not exceed kVHeaddim;
//  - kN0 / kK4 is the number of slices the Gemm4 step is split into.
41 static constexpr index_t kM0 = BlockFmhaShape::kM0;
42 static constexpr index_t kN0 = BlockFmhaShape::kN0;
43 static constexpr index_t kK0 = BlockFmhaShape::kK0;
44 static constexpr index_t kK1 = BlockFmhaShape::kK1;
45 static constexpr index_t kK2 = BlockFmhaShape::kK2;
46 static constexpr index_t kK3 = BlockFmhaShape::kK3;
47 static constexpr index_t kK4 = BlockFmhaShape::kK4;
48 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
49 static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
50
// Compile-time feature switches forwarded from Problem.
// Note: kPadHeadDimQ / kPadHeadDimV are index_t, not bool; the kAlignment* constants test them
// for non-zero and, when non-zero, use their value directly as the alignment.
51 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
52 static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
53 static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
54 static constexpr auto BiasEnum = Problem::BiasEnum;
55 static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
56 static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
57 static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
// Instantiating this pipeline with a Problem that sets kUseTrLoad is a compile-time error.
58 static_assert(!kUseTrLoad, "This pipeline does not use trload!");
59
60 // Last-dimension vector length used to create the tensor view (and to decide the buffer_load
61 // vector length) together with the tensor distribution. The tensor distribution should be able to override this.
// Selection rule for the padded tensors below: if the corresponding head-dim padding constant
// (kPadHeadDimQ for Q/K/KGrad, kPadHeadDimV for V/OGrad/VGrad) is non-zero, that value is used
// as the alignment; otherwise the value comes from the matching Policy::GetAlignment*<Problem>().
62 static constexpr index_t kAlignmentQ =
63 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
64 static constexpr index_t kAlignmentK =
65 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
66 static constexpr index_t kAlignmentV =
67 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
68 static constexpr index_t kAlignmentOGrad =
69 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
// QGrad alignment is fixed at 1 (not derived from padding or from the policy).
70 static constexpr index_t kAlignmentQGrad = 1;
71 static constexpr index_t kAlignmentKGrad =
72 kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
73 static constexpr index_t kAlignmentVGrad =
74 kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
// Bias alignment is likewise fixed at 1.
75 static constexpr index_t kAlignmentBias = 1;
76
77 static constexpr const char* name = "kr_ktr_vr_iglp";
78
80 {
// Forwards the policy's GetSmemSize<Problem>() result unchanged; no pipeline-specific adjustment.
// NOTE(review): the signature line of this function is not visible in this listing — confirm
// its name and qualifiers against the original header.
81 return Policy::template GetSmemSize<Problem>();
82 }
83
84 template <typename QDramBlockWindowTmp,
85 typename KDramBlockWindowTmp,
86 typename VDramBlockWindowTmp,
87 typename BiasDramBlockWindowTmp,
88 typename RandValDramBlockWindowTmp,
89 typename OGradDramBlockWindowTmp,
90 typename LSEDramBlockWindowTmp,
91 typename DDramBlockWindowTmp,
92 typename QGradDramBlockWindowTmp,
93 typename BiasGradDramBlockWindowTmp,
94 typename PositionEncoding>
96 operator()(void* smem_ptr,
97 const QDramBlockWindowTmp& q_dram_block_window_tmp,
98 const KDramBlockWindowTmp& k_dram_block_window_tmp,
99 const VDramBlockWindowTmp& v_dram_block_window_tmp,
100 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
101 const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
102 const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
103 const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
104 const DDramBlockWindowTmp& d_dram_block_window_tmp,
105 const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
106 const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
107 FmhaMask mask,
108 PositionEncoding position_encoding,
109 float raw_scale,
110 float scale,
111 float rp_undrop,
112 float scale_rp_undrop,
113 FmhaDropout& dropout) const
114 {
115 static_assert(
116 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
117 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
118 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
119 std::is_same_v<OGradDataType,
121 std::is_same_v<LSEDataType,
123 std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
124 "wrong!");
125
126 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
127 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
128 kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
129 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
130 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
131 kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
132 kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
133 kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
134 kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
135 kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
136 kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
137 "wrong!");
138
139 // Block GEMM
140 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
141 constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
142 constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
143 constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
144 constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
145
146 // init VGrad & KGrad
147 auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
148 auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
149
150 // K, HBM ->LDS ->Reg
151 auto k_dram_window =
152 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
153 k_dram_block_window_tmp.get_window_lengths(),
154 k_dram_block_window_tmp.get_window_origin(),
155 Policy::template MakeKDramTileDistribution<Problem>());
156
157 const auto k_origin = k_dram_window.get_window_origin();
158 // Early termination
159 const auto [seqlen_q_start, seqlen_q_end] =
160 mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
161
162 const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
163
164 // check early exit if masked and no work to do.
165 if constexpr(FmhaMask::IsMasking)
166 {
167 if(num_total_loop <= 0)
168 {
169 // Note: here dk_acc & dv_acc are all cleared, return them
170 // Note: v loaded but no fence, ignore it.
171 return make_tuple(dk_acc, dv_acc);
172 }
173 }
174 KDataType* k_lds_ptr =
175 static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
177 k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
178
179 auto k_lds_write_window =
181
182 auto k_lds_read_window =
183 make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
185 k_lds_write_window.get_window_origin(),
186 Policy::template MakeKRegBlockDescriptor<Problem>());
187
189 Policy::template MakeKRegBlockDescriptor<Problem>());
190
191 //------------------------------------------------------------------
192 // V, HBM ->LDS ->Reg
193 auto v_dram_window =
194 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
195 v_dram_block_window_tmp.get_window_lengths(),
196 v_dram_block_window_tmp.get_window_origin(),
197 Policy::template MakeVDramTileDistribution<Problem>());
198
199 VDataType* v_lds_ptr =
200 static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
201
203 v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
204
205 auto v_lds_write_window =
207
208 auto v_lds_read_window =
209 make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
211 v_lds_write_window.get_window_origin(),
212 Policy::template MakeVRegBlockDescriptor<Problem>());
213
214 //------------------------------------------------------------------
215 // KT, Reg ->LDS ->Reg
216 auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
217 Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
218
219 KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
220 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
221
222 auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
223 kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
224
225 auto shuffled_k_lds_write_window = make_tile_window(
226 shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
227
229 kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
230
231 auto kt_lds_read_window =
232 make_tile_window(kt_lds_read,
234 {0, 0},
235 Policy::template MakeKTRegBlockDescriptor<Problem>());
236
237 //------------------------------------------------------------------
238 // Pre-Load KV into Registers
239 auto k_block_tile = load_tile(k_dram_window);
240 auto v_block_tile = load_tile(v_dram_window);
241
242 store_tile(k_lds_write_window, k_block_tile);
243 shuffle_tile(shuffled_k_block_tile, k_block_tile);
244 store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
245
247 k_reg_tensor = load_tile(k_lds_read_window);
249
250 auto kt_reg_tensor = load_tile(kt_lds_read_window);
251
252 store_tile(v_lds_write_window, v_block_tile);
253
255
256 auto v_reg_tensor = load_tile(v_lds_read_window);
257 //---------------------------- Loop Load in ----------------------------//
258 // Q: HBM ->Reg ->LDS
259 auto q_dram_window =
260 make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
261 q_dram_block_window_tmp.get_window_lengths(),
262 {seqlen_q_start, 0},
263 Policy::template MakeQDramTileDistribution<Problem>());
264
265 QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
266 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
267 Policy::template GetSmemSizeOGrad<Problem>() +
268 Policy::template GetSmemSizeOGradT<Problem>()));
269
271 q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
272
273 auto q_lds_window =
275
276 auto q_lds_read_window =
277 make_tile_window(q_lds_window.get_bottom_tensor_view(),
279 q_lds_window.get_window_origin(),
280 Policy::template MakeQRegSliceBlockDescriptor<Problem>());
281
283 Policy::template MakePTRegSliceBlockDescriptor<Problem>());
284 // QT: Reg -> Reg-> LDS
285 auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
286 Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
287
288 QDataType* qt_lds_ptr =
289 static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
290
291 auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
292 qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
293
294 auto shuffled_q_lds_write_window = make_tile_window(
295 shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
296
298 qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
299
300 auto qt_lds_read_window =
301 make_tile_window(qt_lds_read,
303 {0, 0},
304 Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
305
306 // dO: HBM ->Reg ->LDS
307 auto do_dram_window =
308 make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
309 do_dram_block_window_tmp.get_window_lengths(),
310 {seqlen_q_start, 0},
311 Policy::template MakeOGradDramTileDistribution<Problem>());
312
313 OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
314 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
315
317 do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
318
319 auto do_lds_window =
321
322 auto do_lds_read_window =
323 make_tile_window(do_lds_window.get_bottom_tensor_view(),
325 do_lds_window.get_window_origin(),
326 Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
327 // dOT: Reg ->Reg ->LDS
328 auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
329 Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
330
331 OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
332 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
333 Policy::template GetSmemSizeOGrad<Problem>()));
334
335 auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
336 dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
337
338 auto shuffled_do_lds_write_window = make_tile_window(
339 shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
340
342 dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
343
344 auto dot_lds_read_window =
345 make_tile_window(dot_read_lds,
347 {0, 0},
348 Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
349
350 // dS: Reg -> Reg -> LDS
351 GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
352 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
353 Policy::template GetSmemSizeOGrad<Problem>() +
354 Policy::template GetSmemSizeOGradT<Problem>() +
355 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
356 Policy::template GetSmemSizeD<Problem>()));
357
359 ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
360
361 auto ds_lds_window =
362 make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
363
364 auto ds_lds_read_window =
365 make_tile_window(ds_lds_window.get_bottom_tensor_view(),
367 ds_lds_window.get_window_origin(),
368 Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
369
371 Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
372 // Bias: HBM ->Reg ->Reg ->LDS
373 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
374
375 auto bias_dram_window =
376 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
377 bias_dram_block_window_tmp.get_window_lengths(),
378 {seqlen_q_start, bias_origin.at(number<1>{})},
379 Policy::template MakeBiasTileDistribution<Problem>());
380
381 BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
382 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
383 Policy::template GetSmemSizeOGrad<Problem>() +
384 Policy::template GetSmemSizeOGradT<Problem>() +
385 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
386 Policy::template GetSmemSizeD<Problem>()));
387
389 bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
390
391 auto bias_lds_write_window =
392 make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
393
394 auto bias_s_lds_read_window =
395 make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
396 bias_lds_write_window.get_window_lengths(),
397 bias_lds_write_window.get_window_origin(),
398 Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
399
400 static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
401 "BiasDataType and BiasGradDataType should be the same!");
402
403 // LSE: HBM -> LDS ->Reg
404 auto lse_dram_window = make_tile_window(
405 lse_dram_block_window_tmp.get_bottom_tensor_view(),
406 lse_dram_block_window_tmp.get_window_lengths(),
407 {seqlen_q_start},
408 Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
409
410 LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
411 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
412 Policy::template GetSmemSizeOGrad<Problem>() +
413 Policy::template GetSmemSizeOGradT<Problem>() +
414 Policy::template GetSmemSizeQ<Problem>()));
415
417 lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
418
419 auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
420
421 auto lse_lds_read_window = make_tile_window(
422 lse_lds,
424 {0},
425 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
426
427 // D: HBM ->Reg
428 auto d_dram_window = make_tile_window(
429 d_dram_block_window_tmp.get_bottom_tensor_view(),
430 d_dram_block_window_tmp.get_window_lengths(),
431 {seqlen_q_start},
432 Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
433
434 DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
435 static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
436 Policy::template GetSmemSizeOGrad<Problem>() +
437 Policy::template GetSmemSizeOGradT<Problem>() +
438 Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
439
441 d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
442
443 auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
444
445 auto d_lds_read_window = make_tile_window(
446 d_lds,
448 {0},
449 Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
450
451 // RandVal: HBM ->Reg
452 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
453 randval_dram_block_window_tmp, seqlen_q_start);
454
455 // BiasGrad
456 // Reg ->LDS ->Reg ->HBM
457 const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
458
459 auto dbias_dram_window =
460 make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
461 dbias_dram_block_window_tmp.get_window_lengths(),
462 {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
463
464 auto dbias_lds_read_window =
465 make_tile_window(bias_lds,
467 {0, 0},
468 Policy::template MakeShuffledBiasTileDistribution<Problem>());
469
470 // ----------------------------Loop write out------------------------------//
471 auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
472 dq_dram_block_window_tmp.get_window_lengths(),
473 {seqlen_q_start, 0});
474
475 using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
476 using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
477 using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
478
479 index_t i_total_loops = 0;
480 index_t seqlen_q_step = seqlen_q_start;
481 static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
482 static_assert(kM0 == kK1, "kM0 should equal to kK1");
483 static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
484 static_assert(kM0 == kK3, "kM0 should equal to kK3");
485 constexpr index_t k4_loops = kN0 / kK4;
486
487 /*
488 * Prefetch Q, LSE, dO, D
489 */
490 auto q_block_tile = load_tile(q_dram_window);
491 move_tile_window(q_dram_window, {kM0, 0});
492 auto lse_block_tile = load_tile(lse_dram_window);
493 move_tile_window(lse_dram_window, {kM0});
494
495 auto do_block_tile = load_tile(do_dram_window);
496 move_tile_window(do_dram_window, {kM0, 0});
497
498 auto d_block_tile = load_tile(d_dram_window);
499 move_tile_window(d_dram_window, {kM0});
500
501 /*
502 * Store prefetched data into LDS
503 */
505 store_tile(q_lds_window, q_block_tile);
506 shuffle_tile(shuffled_q_block_tile, q_block_tile);
507 store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
508
509 store_tile(lse_lds_write_window, lse_block_tile);
510
511 store_tile(do_lds_window, do_block_tile);
512 shuffle_tile(shuffled_do_block_tile, do_block_tile);
513 store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
514
515 store_tile(d_lds_write_window, d_block_tile);
517
518 /*
519 * Prefetch LDS data into registers so that data movement runs asynchronously with the MFMA pipeline
520 */
521
522 auto q_reg_tensor = load_tile(q_lds_read_window);
523 auto lse = load_tile(lse_lds_read_window);
524 auto do_reg_tensor = load_tile(do_lds_read_window);
525 auto d = load_tile(d_lds_read_window);
526
527 clear_tile(dv_acc);
528 clear_tile(dk_acc);
529
530 __builtin_amdgcn_sched_barrier(0);
531 // Hot loop
532 while(i_total_loops < (num_total_loop - 1))
533 {
534 // STAGE 1, Q@K Gemm0
535 auto s_acc = SPBlockTileType{};
536
537 q_block_tile = load_tile(q_dram_window);
538 move_tile_window(q_dram_window, {kM0, 0});
539
540 lse_block_tile = load_tile(lse_dram_window);
541 move_tile_window(lse_dram_window, {kM0});
542
543 do_block_tile = load_tile(do_dram_window);
544 move_tile_window(do_dram_window, {kM0, 0});
545
546 d_block_tile = load_tile(d_dram_window);
547 move_tile_window(d_dram_window, {kM0});
548
549 s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
550
551 auto dot_reg_tensor = load_tile(dot_lds_read_window);
552
553 HotLoopScheduler::template GemmStagedScheduler<0>();
554 __builtin_amdgcn_sched_barrier(0);
555 // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
556 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
557 {
558 const auto bias_tile = load_tile(bias_dram_window);
559 auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
560 Policy::template MakeShuffledBiasTileDistribution<Problem>());
561 shuffle_tile(shuffled_bias_tile, bias_tile);
562 // SGrad and Bias use the same address in LDS, finish loading ds on the previous
563 // iteration to reuse LDS.
565 store_tile(bias_lds_write_window, shuffled_bias_tile);
567 auto bias_s_tile = load_tile(bias_s_lds_read_window);
569 [&](auto& x, const auto& y) {
571 },
572 s_acc,
573 bias_s_tile);
574 move_tile_window(bias_dram_window, {kM0, 0});
575 __builtin_amdgcn_sched_barrier(0);
576 }
577 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
578 {
579 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
580 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
581 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
582 const auto tile_idx = get_x_indices_from_distributed_indices(
583 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
584
585 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
586 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
587 constexpr auto i_j_idx = make_tuple(idx0, idx1);
588
589 s_acc(i_j_idx) *= scale;
590 position_encoding.update(s_acc(i_j_idx), row, col);
591 });
592 });
593 }
594
595 {
596 bool need_perpixel_check = mask.IsEdgeTile(
597 seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
598 if(need_perpixel_check)
599 {
600 set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
601 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
602 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
603 return mask.IsOutOfBound(row, col);
604 });
605 }
606 }
607
608 static const auto get_validated_lse = [](LSEDataType raw_lse) {
609 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
610 FmhaMask::IsMasking)
611 {
612 return raw_lse == -numeric<LSEDataType>::infinity()
614 : raw_lse;
615 }
616 else
617 {
618 return raw_lse;
619 }
620 };
621
622 auto p = SPBlockTileType{};
623 constexpr auto p_spans = decltype(p)::get_distributed_spans();
624 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
625 constexpr auto i_idx = make_tuple(idx0);
626 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
627
628 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
629 constexpr auto i_j_idx = make_tuple(idx0, idx1);
630
631 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
633 {
634 p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
635 }
636 else
637 {
638 p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
639 }
640 });
641 });
642
643 if constexpr(FmhaDropout::IsDropout)
644 {
645 dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
646 seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
647 }
648 const auto p_gemm = [&]() {
649 if constexpr(FmhaDropout::IsDropout)
650 {
651 return tile_elementwise_in(
652 [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
653 p);
654 }
655 else
656 {
657 return cast_tile<GemmDataType>(p);
658 }
659 }();
660
661 // STAGE 3, P^T@OGrad^T Gemm1
662 Policy::template PTFromGemm0CToGemm1A<Problem>(pt_reg_tensor, p_gemm);
663 gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
664
665 auto qt_reg_tensor = load_tile(qt_lds_read_window);
666
667 HotLoopScheduler::template GemmStagedScheduler<1>();
668 __builtin_amdgcn_sched_barrier(0);
669 // STAGE 4, OGrad@V Gemm2
670 auto dp_acc = SPGradBlockTileType{};
671
672 dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
673
675
676 store_tile(q_lds_window, q_block_tile);
677 shuffle_tile(shuffled_q_block_tile, q_block_tile);
678 store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
679
680 store_tile(lse_lds_write_window, lse_block_tile);
681
682 store_tile(do_lds_window, do_block_tile);
683 shuffle_tile(shuffled_do_block_tile, do_block_tile);
684 store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
685
686 store_tile(d_lds_write_window, d_block_tile);
687
688 HotLoopScheduler::template GemmStagedScheduler<2>();
689 __builtin_amdgcn_sched_barrier(0);
690 // STAGE 5, P^T(PGrad^T - D)
691 auto ds = SPGradBlockTileType{};
692 constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
693 sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
694 constexpr auto i_idx = make_tuple(idx0);
695 sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
696 constexpr auto i_j_idx = make_tuple(idx0, idx1);
697 bool undrop_flag = p[i_j_idx] >= 0;
698 ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
699 ? (dp_acc[i_j_idx] - d[i_idx])
700 : d[i_idx]);
701 });
702 });
703
704 if constexpr(kHasBiasGrad)
705 {
706 const auto dbias = [&]() {
707 if constexpr(FmhaDropout::IsDropout)
708 {
709 return tile_elementwise_in(
710 [&rp_undrop](const auto& x) {
711 return type_convert<BiasGradDataType>(x * rp_undrop);
712 },
713 ds);
714 }
715 else
716 {
718 }
719 }();
720 store_tile(bias_lds_write_window, dbias);
722 auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
724 Policy::template MakeBiasTileDistribution<Problem>());
725 shuffle_tile(dbias_tile, shuffled_dbias_tile);
726 store_tile(dbias_dram_window, dbias_tile);
727 move_tile_window(dbias_dram_window, {kM0, 0});
728 __builtin_amdgcn_sched_barrier(0);
729 }
730
731 // STAGE 6, SGrad^T@Q^T Gemm3
732 const auto ds_gemm = cast_tile<GemmDataType>(ds);
733
734 Policy::template SGradTFromGemm2CToGemm3A<Problem>(dst_reg_tensor, ds_gemm);
735
736 gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
737
738 if constexpr(kHasBiasGrad)
739 {
740 // SGrad and BiasGrad use the same address in LDS.
742 }
743 store_tile(ds_lds_window, ds_gemm);
744
746
747 auto ds_reg_tensor = load_tile(ds_lds_read_window);
748 auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
749 move_tile_window(ds_lds_read_window, {0, kK4});
750 q_reg_tensor = load_tile(q_lds_read_window);
751 lse = load_tile(lse_lds_read_window);
752
753 HotLoopScheduler::template GemmStagedScheduler<3>();
754 __builtin_amdgcn_sched_barrier(0);
755 // STAGE7 SGrad@K^T Gemm4
756 auto dq_acc = QGradBlockTileType{};
757 clear_tile(dq_acc);
758
759 static_for<0, k4_loops, 1>{}([&](auto i_k4) {
760 if constexpr(i_k4 < k4_loops - 1)
761 {
762 ds_reg_tensor_next = load_tile(ds_lds_read_window);
763 move_tile_window(ds_lds_read_window, {0, kK4});
764 }
765 auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
767 sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
768 gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
769
770 if constexpr(i_k4 < k4_loops - 1)
771 {
772 ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
773 }
774 });
775 move_tile_window(ds_lds_read_window, {0, -kN0});
776
777 do_reg_tensor = load_tile(do_lds_read_window);
778 d = load_tile(d_lds_read_window);
779
780 HotLoopScheduler::template GemmStagedScheduler<4>();
781
782 // QGrad Scale
783 if constexpr(FmhaDropout::IsDropout)
784 {
785 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
786 dq_acc);
787 }
788 else
789 {
790 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
791 }
792 if constexpr(kIsDeterministic)
793 {
794 store_tile(dq_dram_window, dq_acc);
795 }
796 else
797 {
798 update_tile(dq_dram_window, dq_acc);
799 }
800 move_tile_window(dq_dram_window, {kM0, 0});
801
802 i_total_loops += 1;
803 seqlen_q_step += kM0;
804 }
805 __builtin_amdgcn_sched_barrier(0);
806
807 // Tail
808 auto s_acc = SPBlockTileType{};
809
810 // STAGE 1, Q@K Gemm0
811 s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
812
813 // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
814 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
815 {
816 const auto bias_tile = load_tile(bias_dram_window);
817 auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
818 Policy::template MakeShuffledBiasTileDistribution<Problem>());
819 shuffle_tile(shuffled_bias_tile, bias_tile);
820 // SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to
821 // reuse LDS.
823 store_tile(bias_lds_write_window, shuffled_bias_tile);
825 auto bias_s_tile = load_tile(bias_s_lds_read_window);
827 [&](auto& x, const auto& y) {
829 },
830 s_acc,
831 bias_s_tile);
832 __builtin_amdgcn_sched_barrier(0);
833 }
834 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
835 {
836 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
837 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
838 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
839 const auto tile_idx = get_x_indices_from_distributed_indices(
840 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
841
842 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
843 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
844 constexpr auto i_j_idx = make_tuple(idx0, idx1);
845
846 s_acc(i_j_idx) *= scale;
847 position_encoding.update(s_acc(i_j_idx), row, col);
848 });
849 });
850 }
851
852 {
853 bool need_perpixel_check = mask.IsEdgeTile(
854 seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
855 if(need_perpixel_check)
856 {
857 set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
858 const auto row = seqlen_q_step + tile_idx.at(number<0>{});
859 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
860 return mask.IsOutOfBound(row, col);
861 });
862 }
863 }
864
865 static const auto get_validated_lse = [](LSEDataType raw_lse) {
866 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
867 FmhaMask::IsMasking)
868 {
870 : raw_lse;
871 }
872 else
873 {
874 return raw_lse;
875 }
876 };
877
878 auto p = SPBlockTileType{};
879 constexpr auto p_spans = decltype(p)::get_distributed_spans();
880 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
881 constexpr auto i_idx = make_tuple(idx0);
882 auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
883
884 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
885 constexpr auto i_j_idx = make_tuple(idx0, idx1);
886 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
888 {
889 p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
890 }
891 else
892 {
893 p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
894 }
895 });
896 });
897
898 if constexpr(FmhaDropout::IsDropout)
899 {
900 dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
901 seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
902 }
903
904 // STAGE 3, P^T@OGrad^T Gemm1
905 const auto p_gemm = [&]() {
906 if constexpr(FmhaDropout::IsDropout)
907 {
908 return tile_elementwise_in(
909 [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); }, p);
910 }
911 else
912 {
913 return cast_tile<GemmDataType>(p);
914 }
915 }();
916
917 Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(p_gemm)>(
918 pt_reg_tensor, p_gemm);
919 auto dot_reg_tensor = load_tile(dot_lds_read_window);
920 gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
921
922 HotLoopScheduler::template GemmStagedScheduler<1>();
923 __builtin_amdgcn_sched_barrier(0);
924
925 // STAGE 4, OGrad@V Gemm2
926 auto dp_acc = SPGradBlockTileType{};
927
928 auto qt_reg_tensor = load_tile(qt_lds_read_window);
929
930 dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
931
932 HotLoopScheduler::template GemmStagedScheduler<2>();
933 __builtin_amdgcn_sched_barrier(0);
934
935 // STAGE 5, P^T(PGrad^T - D)
936 auto ds = SPGradBlockTileType{};
937 constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
938 sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
939 constexpr auto i_idx = make_tuple(idx0);
940 sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
941 constexpr auto i_j_idx = make_tuple(idx0, idx1);
942 bool undrop_flag = p[i_j_idx] >= 0;
943 ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
944 ? (dp_acc[i_j_idx] - d[i_idx])
945 : d[i_idx]);
946 });
947 });
948
949 if constexpr(kHasBiasGrad)
950 {
951 const auto dbias = [&]() {
952 if constexpr(FmhaDropout::IsDropout)
953 {
954 return tile_elementwise_in(
955 [&rp_undrop](const auto& x) {
956 return type_convert<BiasGradDataType>(x * rp_undrop);
957 },
958 ds);
959 }
960 else
961 {
963 }
964 }();
965 // Finish loading bias_s to reuse LDS.
967 store_tile(bias_lds_write_window, dbias);
969 auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
971 Policy::template MakeBiasTileDistribution<Problem>());
972 shuffle_tile(dbias_tile, shuffled_dbias_tile);
973 store_tile(dbias_dram_window, dbias_tile);
974 __builtin_amdgcn_sched_barrier(0);
975 }
976
977 // STAGE 6, SGrad^T@Q^T Gemm3
978 const auto ds_gemm = cast_tile<GemmDataType>(ds);
979
980 Policy::template SGradTFromGemm2CToGemm3A<Problem,
981 decltype(dst_reg_tensor),
982 decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
983
984 gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
985
986 // SGrad and Bias/BiasGrad share the same LDS address: finish loading bias/dbias (or, when
987 // bias is not used, finish loading ds in the hot loop) before reusing that LDS region.
989 store_tile(ds_lds_window, ds_gemm);
990
992
993 auto ds_reg_tensor = load_tile(ds_lds_read_window);
994 auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
995 move_tile_window(ds_lds_read_window, {0, kK4});
996
997 HotLoopScheduler::template GemmStagedScheduler<3>();
998 __builtin_amdgcn_sched_barrier(0);
999 // STAGE 7, SGrad@K^T Gemm4
1000 auto dq_acc = QGradBlockTileType{};
1001 clear_tile(dq_acc);
1002
1003 static_for<0, k4_loops, 1>{}([&](auto i_k4) {
1004 if constexpr(i_k4 < k4_loops - 1)
1005 {
1006 ds_reg_tensor_next = load_tile(ds_lds_read_window);
1007 move_tile_window(ds_lds_read_window, {0, kK4});
1008 }
1009 auto kt_reg_tensor_slice = get_slice_tile(
1010 kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
1011
1012 gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
1013 if constexpr(i_k4 < k4_loops - 1)
1014 {
1015 ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
1016 }
1017 });
1018
1019 HotLoopScheduler::template GemmStagedScheduler<4>();
1020 __builtin_amdgcn_sched_barrier(0);
1021
1022 // Results Scale
1023 if constexpr(FmhaDropout::IsDropout)
1024 {
1025 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
1026 dq_acc);
1027 tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
1028 dk_acc);
1029 tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
1030 }
1031 else
1032 {
1033 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
1034 tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
1035 }
1036
1037 if constexpr(kIsDeterministic)
1038 {
1039 store_tile(dq_dram_window, dq_acc);
1040 }
1041 else
1042 {
1043 update_tile(dq_dram_window, dq_acc);
1044 }
1045
1046 return make_tuple(dk_acc, dv_acc);
1047 }
1048};
1049
1050} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:16
static constexpr index_t kBlockPerCu
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:38
static constexpr index_t kAlignmentK
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:64
remove_cvref_t< typename Problem::DDataType > DDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:24
static constexpr index_t kAlignmentOGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:68
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:26
static constexpr index_t kK3
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:46
typename Policy::template HotLoopScheduler< Problem > HotLoopScheduler
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:34
static constexpr index_t kAlignmentVGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:73
CK_TILE_HOST_DEVICE auto operator()(void *smem_ptr, const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const RandValDramBlockWindowTmp &randval_dram_block_window_tmp, const OGradDramBlockWindowTmp &do_dram_block_window_tmp, const LSEDramBlockWindowTmp &lse_dram_block_window_tmp, const DDramBlockWindowTmp &d_dram_block_window_tmp, const QGradDramBlockWindowTmp &dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp &dbias_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float raw_scale, float scale, float rp_undrop, float scale_rp_undrop, FmhaDropout &dropout) const
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:96
static constexpr index_t kAlignmentV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:66
static constexpr index_t kPadHeadDimV
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:53
static constexpr index_t kAlignmentKGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:71
remove_cvref_t< typename Problem::BiasGradDataType > BiasGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:31
static constexpr index_t kN0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:42
remove_cvref_t< typename Problem::QGradDataType > QGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:28
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:32
static constexpr bool kIsDeterministic
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:56
static constexpr index_t kAlignmentBias
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:75
remove_cvref_t< typename Problem::FmhaDropout > FmhaDropout
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:33
remove_cvref_t< typename Problem::KGradDataType > KGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:29
static constexpr index_t kPadHeadDimQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:52
static constexpr index_t kVHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:49
static constexpr const char * name
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:77
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:18
static constexpr index_t kAlignmentQGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:70
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:23
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:21
static constexpr bool kHasBiasGrad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:55
static constexpr auto BiasEnum
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:54
static constexpr index_t kM0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:41
static constexpr bool kUseTrLoad
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:57
remove_cvref_t< typename Problem::VGradDataType > VGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:30
static constexpr index_t kBlockSize
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:39
static constexpr index_t kAlignmentQ
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:62
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:25
static constexpr index_t kK1
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:44
remove_cvref_t< typename Problem::GemmDataType > GemmDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:20
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:19
static constexpr bool kIsGroupMode
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:51
remove_cvref_t< typename Problem::OGradDataType > OGradDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:27
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:79
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:17
static constexpr index_t kK4
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:47
static constexpr index_t kK0
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:43
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:36
static constexpr index_t kQKHeaddim
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:48
static constexpr index_t kK2
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:45
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp:22
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43