moe_flatmm_pipeline_agmem_bgmem_creg.hpp Source File

moe_flatmm_pipeline_agmem_bgmem_creg.hpp Source File#

Composable Kernel: moe_flatmm_pipeline_agmem_bgmem_creg.hpp Source File
moe_flatmm_pipeline_agmem_bgmem_creg.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9#include <cwchar>
10
11namespace ck_tile {
12
13template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
15{
20
24
27
28 static constexpr auto config =
29 BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
30
31 using WG = remove_cvref_t<decltype(config.template at<0>())>;
32
// Tuning knobs for how far ahead LDS writes are issued and how many LDS reads are preloaded.
// NOTE(review): DsWritePreIssue is set to 3 although its note records 2 as the default;
// confirm the non-default value is intentional.
33 static constexpr index_t DsWritePreIssue = 3; // noted default: 2 (ds write at MIter - 2)
34 static constexpr index_t DsReadPreload = 2; // noted default: 2 (preload 2 ds reads)
35
36 static constexpr index_t BlockSize = Problem::kBlockSize;
37 static constexpr index_t WaveSize = get_warp_size();
38
39 static constexpr index_t kMPerBlock = BlockGemmShape::kM;
40 static constexpr index_t kNPerBlock = BlockGemmShape::kN;
41 static constexpr index_t kKPerBlock = BlockGemmShape::kK;
42
43 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
44 static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
45
46 static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
47 static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
48 static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
49
50 static constexpr bool kPadM = Problem::kPadM;
51 static constexpr bool kPadN = Problem::kPadN;
52 static constexpr bool kPadK = Problem::kPadK;
53
54 static constexpr index_t kLdsAlignmentInBytes = 16;
55 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
56 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
57
58 static constexpr auto I0 = number<0>();
59 static constexpr auto I1 = number<1>();
60 static constexpr auto I2 = number<2>();
61 static constexpr auto idxM = I0;
62 static constexpr auto idxN = I1;
63 static constexpr auto idxK = I2;
67
68 static constexpr index_t MWarp = config.template at<1>();
69 static constexpr index_t NWarp = config.template at<2>();
70
// Number of warp-gemm steps each warp performs inside one block tile.
// M and N: block extent divided by what all warps along that dimension cover in one
// warp-gemm (MWarp * WG::kM, NWarp * WG::kN). K: block extent divided by WG::kK only,
// i.e. no warp count enters the K division.
71 static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
72 static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
73 static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
74
77
80
// AK1 / BK1: Problem::VectorLoadSize divided by the storage size of one A / B element
// (assumes VectorLoadSize is expressed in bytes, matching sizeof — TODO confirm).
// BK1 is additionally multiplied by MXFP4PackedSize; presumably each BDataType storage
// element packs two MXFP4 values — NOTE(review): verify against the BDataType definition.
81 static constexpr int MXFP4PackedSize = 2;
82 static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
83 static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * MXFP4PackedSize;
87
88 static constexpr bool HasHotLoop = Problem::HasHotLoop;
89 static constexpr auto TailNum = Problem::TailNum;
90
// mfma_per_wg is selected at preprocessing time: 2 when compiling for gfx942, 1 otherwise.
91#ifdef __gfx942__
92 static constexpr index_t mfma_per_wg = 2;
93#else
94 static constexpr index_t mfma_per_wg = 1;
95#endif
// dsread_per_wg: storage size of one WG::kM x WG::kK tile of ADataType, split across the
// WaveSize lanes and then divided by Problem::VectorLoadSize. The schedulers in this file
// use it as the per-M-iteration DS-read count.
96 static constexpr index_t dsread_per_wg =
97 WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize;
// Guard: the per-lane share must be a whole multiple of VectorLoadSize so the division
// above does not truncate.
98 static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0);
99
// Per-M-iteration load counts consumed by the hot-loop schedulers.
// Aload_rep simply mirrors dswrite_rep (declared elsewhere in this struct).
104 static constexpr index_t Aload_rep = dswrite_rep;
// Bload_num_perK: kNPerBlock * WG::kK divided by NWarp, by BK1 and by WaveSize
// (presumably the number of B vector loads per lane for one K step — TODO confirm).
105 static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
// Disabled scale-load bookkeeping, kept for reference (not compiled):
106 // static constexpr index_t ScaleBload_K1 = ContinuousScaleNPerThread *
107 // ContinuousScaleKPerThread; static constexpr index_t ScaleBload_num =
108 // kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 /
109 // WaveSize; // BlockN * BlockK / NWarp / ScalePerK / ScaleB_K1 / wavesize
110 // static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
// HalfMIter is MIterPerWarp / 2 rounded up; Bload_rep is Bload_num_perK / HalfMIter
// rounded up. HotLoopScheduler only adds B loads while mIter < HalfMIter.
111 static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
112 static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
113
117
118 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
119 {
120 // clang-format off
121 return concat('_', "pipeline_AGmemBGmemCRegV1",
123 concat('x', WG::kM, WG::kN, WG::kK),
125 concat('x', kPadM, kPadN, kPadK));
126 // clang-format on
127 }
128
129 // For the basic GEMM pipeline, DoubleSmemBuffer is naturally set to false.
130 static constexpr bool DoubleSmemBuffer = false;
131
132 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
133
135 {
136 return PipelinePolicy::template GetSmemSize<Problem>();
137 }
138
139 CK_TILE_HOST_DEVICE static constexpr auto
140 SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
141 {
142 // Init inst order
143 index_t max_data_inst = dsread_perM > load_perM
144 ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
145 : (load_perM > dswrite_perM ? load_perM : dswrite_perM);
146 index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
147 index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK;
148
149 index_t inst_order[NIterPerWarp * 10];
150 _Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; }
151
152 index_t index = 0;
153 _Pragma("unroll") for(int j = 0; j < max_data_inst; j++)
154 {
155 if(dswrite_perM > j)
156 {
157 inst_order[index] = 1;
158 index++;
159 }
160 if(load_perM > j)
161 {
162 inst_order[index] = 2;
163 index++;
164 }
165 if(dsread_perM > j)
166 {
167 inst_order[index] = 3;
168 index++;
169 }
170 }
171
172 // Schedule IGLP
173 _Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++)
174 {
175 index_t inst_idx = 0;
176 if(j == 0)
177 ;
178 else if(j == 1)
179 inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
180 else if(j == 2)
181 inst_idx = mfma_perM_perK - 1;
182 else
183 inst_idx = mfma_perM_perK - j;
184
185 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
186
187 _Pragma("unroll") for(int r = 0; r < round_data_inst; r++)
188 {
189 if(r % 2 == 0)
190 {
191 if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
192 {
193 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
194 }
195 if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
196 {
197 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
198 }
199 if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
200 {
201 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
202 }
203 }
204 else
205 {
206 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
207 {
208 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
209 }
210 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
211 {
212 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
213 }
214 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
215 {
216 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
217 }
218 }
219 }
220 }
221 }
222 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
223 {
224 // The key to optimizing this pipeline is balancing the workload over time.
225 // instruction schedule example(128X256X256, 1X4, 16X16X128):
226 // Iter MNK MFMA ds_read ds_write A_load b_load
227 // -1 M6N0: 57 - 8 - -
228 // -1 M6N1: 58 1 - - -
229 // -1 M6N2: 59 - - 7 -
230 // -1 M6N3: 60 2 - - -
231 // -1 M7N0: 61 - - - -
232 // -1 M7N1: 62 3 - - -
233 // -1 M7N2: 63 - - 8 -
234 // -1 M7N3: 64 4 - - -
235 // 0 M0N0K0: 1 - - - 1
236 // 0 M0N1: 2 5 - - -
237 // 0 M0N2: 3 - - - 2
238 // 0 M0N3: 4 6 - - -
239 // 0 M1N0: 5 - - - 3
240 // 0 M1N1: 6 7 - - -
241 // 0 M1N2: 7 - - - 4
242 // 0 M1N3: 8 8 - - -
243 // 0 M2N0: 9 - - - 5
244 // 0 M2N1: 10 9 - - -
245 // 0 M2N2: 11 - - - 6
246 // 0 M2N3: 12 10 - - -
247 // 0 M3N0: 13 - 1 - 7
248 // 0 M3N1: 14 11 - - -
249 // 0 M3N2: 15 - - - 8
250 // 0 M3N3: 16 12 - - -
251 // 0 M4N0: 17 - 2 - -
252 // 0 M4N1: 18 13 - - -
253 // 0 M4N2: 19 - - 1 -
254 // 0 M4N3: 20 14 - - -
255 // 0 M5N0: 21 - 3 - -
256 // 0 M5N1: 22 15 - - -
257 // 0 M5N2: 23 - - 2 -
258 // 0 M5N3: 24 16 - - -
259 // 0 M6N0: 25 - 4 - -
260 // 0 M6N1: 26 17 - - -
261 // 0 M6N2: 27 - - 3 -
262 // 0 M6N3: 28 18 - - -
263 // 0 M7N0: 29 - - - -
264 // 0 M7N1: 30 19 - - -
265 // 0 M7N2: 31 - - 4 -
266 // 0 M7N3: 32 20 - - -
267 // 0 M0N0K1: 33 - - - 9
268 // 0 M0N1: 34 21 - - -
269 // 0 M0N2: 35 - - - 10
270 // 0 M0N3: 36 22 - - -
271 // 0 M1N0: 37 - - - 11
272 // 0 M1N1: 38 23 - - -
273 // 0 M1N2: 39 - - - 12
274 // 0 M1N3: 40 24 - - -
275 // 0 M2N0: 41 - - - 13
276 // 0 M2N1: 42 25 - - -
277 // 0 M2N2: 43 - - - 14
278 // 0 M2N3: 44 26 - - -
279 // 0 M3N0: 45 - 5 - 15
280 // 0 M3N1: 46 27 - - -
281 // 0 M3N2: 47 - - - 16
282 // 0 M3N3: 48 28 - - -
283 // 0 M4N0: 49 - 6 - -
284 // 0 M4N1: 50 29 - - -
285 // 0 M4N2: 51 - - 5 -
286 // 0 M4N3: 52 30 - - -
287 // 0 M5N0: 53 - 7 - -
288 // 0 M5N1: 54 31 - - -
289 // 0 M5N2: 55 - - 6 -
290 // 0 M5N3: 56 32 - - -
291 // 0 M6N0: 57 - 8 - -
292 // 0 M6N1: 58 1 - - -
293 // 0 M6N2: 59 - - 7 -
294 // 0 M6N3: 60 2 - - -
295 // 0 M7N0: 61 - - - -
296 // 0 M7N1: 62 3 - - -
297 // 0 M7N2: 63 - - 8 -
298 // 0 M7N3: 64 4 - - -
299
300 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
301 {
302 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
303 {
304 index_t dsread_perM = 0;
305 index_t dswrite_perM = 0;
306 index_t load_perM = 0;
307
308 // Calculate ds_read number per M
309 dsread_perM = dsread_per_wg;
310
311 // Calculate ds_write number per M
312 if(mIter == 0)
313 {
314 dswrite_perM =
317 : 0;
318 }
319 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
320 {
321 dswrite_perM = 0;
322 }
323 else
324 {
325 dswrite_perM = (dswrite_num_perK -
326 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
328 : 0;
329 }
330 // Add ds write when ds write data > needed
331 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
332 {
333 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
334 dswrite_perM = 1;
335 }
336
337 // Calculate buffer_load number per M
338 if(mIter < HalfMIter)
339 {
340 load_perM =
341 ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
342 : 0) +
343 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
344 : 0);
345 }
346 else
347 {
348 load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
349 ? Aload_rep
350 : 0;
351 }
352 // if((kIter % KPerScaleLoad == 0) && (mIter == 0))
353 // {
354 // load_perM = load_perM + 1;
355 // }
356 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
357 }
358 }
359 // Add Aload when Aload data > needed
360 if(Aload_num_perK == 0)
361 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
362 __builtin_amdgcn_sched_barrier(0);
363 }
364
366 {
367 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
368 {
369 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
370 {
371 index_t dsread_perM = 0;
372 index_t dswrite_perM = 0;
373 index_t load_perM = 0;
374
375 // Calculate ds_read number per M
376 dsread_perM = dsread_per_wg;
377
378 // Calculate ds_write number per M
379 if(mIter == 0)
380 {
381 dswrite_perM =
384 : 0;
385 }
386 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
387 {
388 dswrite_perM = 0;
389 }
390 else
391 {
392 dswrite_perM = (dswrite_num_perK -
393 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
395 : 0;
396 }
397 // Add ds write when ds write data > needed
398 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
399 {
400 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
401 dswrite_perM = 1;
402 }
403
404 // Calculate buffer_load number per M
405 if(mIter < HalfMIter)
406 {
407 load_perM =
408 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
409 : 0);
410 }
411 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
412 }
413 }
414 __builtin_amdgcn_sched_barrier(0);
415 }
416
418 {
419 _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++)
420 {
421 _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++)
422 {
423 index_t dsread_perM = 0;
424 index_t dswrite_perM = 0;
425 index_t load_perM = 0;
426
427 // Calculate ds_read number per M
428 if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
429 dsread_perM = dsread_per_wg;
430
431 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
432 }
433 }
434 // __builtin_amdgcn_sched_barrier(0);
435 }
436
438 {
439 return PipelinePolicy::template MakeADramTileDistribution<Problem>();
440 }
441
442 template <typename ADramBlockWindowTmp,
443 typename AElementFunction,
444 typename BFlatBlockWindowTmp,
445 int IsGateUpMode>
446 CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
447 const AElementFunction& a_element_func,
448 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
450 index_t num_loop,
451 void* p_smem_ping,
452 void* p_smem_pong) const
453 {
454 static_assert(
455 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
456 "wrong!");
457
458 static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
459 "wrong!");
460 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
461 "wrong!");
462
463 constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
464 const index_t iMWarp = get_warp_id() / NWarp;
465
466 using CWarpDstr = typename WG::CWarpDstr;
467 using CWarpTensor = typename WG::CWarpTensor;
468
469 constexpr auto c_warp_y_lengths =
470 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
471 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
472
473 __builtin_amdgcn_sched_barrier(0);
474
475 // A tile in LDS
476 ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
477 ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
478
479 constexpr auto a_lds_block_desc =
480 PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
481
482 auto a_lds_block_ping =
483 make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
484 auto a_lds_block_pong =
485 make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
486
487 auto a_copy_dram_window = ck_tile::make_tile_scatter_gather(
488 a_dram_block_window_tmp.get_bottom_tensor_view(),
490 a_dram_block_window_tmp.get_window_origin(),
491 PipelinePolicy::template MakeADramTileDistribution<Problem>(),
492 a_dram_block_window_tmp.page_idx_); // K DRAM tile window for
493
494 auto a_copy_lds_window_ping =
495 make_tile_window(a_lds_block_ping,
497 {0, 0},
498 PipelinePolicy::template MakeADramTileDistribution<Problem>());
499
500 auto a_copy_lds_window_pong =
501 make_tile_window(a_lds_block_pong,
503 {0, 0},
504 PipelinePolicy::template MakeADramTileDistribution<Problem>());
505
506 // ping-pong window for A LDS
507 auto a_warp_window_ping_tmp =
508 make_tile_window(a_lds_block_ping,
510 {iMWarp * WG::kM, 0},
511 PipelinePolicy::template MakeALDS_WarpTileDistribution<Problem>());
512
513 auto a_warp_window_pong_tmp =
514 make_tile_window(a_lds_block_pong,
516 {iMWarp * WG::kM, 0},
517 PipelinePolicy::template MakeALDS_WarpTileDistribution<Problem>());
518
520 statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
522 a_warp_windows_ping;
523
525 statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
527 a_warp_windows_pong;
528
529 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
530 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
531 a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
532
533 move_tile_window(a_warp_windows_ping(mIter)(kIter),
534 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
535 });
536 });
537
538 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
539 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
540 a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
541
542 move_tile_window(a_warp_windows_pong(mIter)(kIter),
543 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
544 });
545 });
546
547 // Block GEMM
548 auto block_flatmm = BlockFlatmm();
549 // Acc register tile
550 auto c_block_tile = block_flatmm.MakeCBlockTile();
551
552 // B flat DRAM window for load
553 auto b_flat_distribution =
554 PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
555 auto b_flat_dram_window = // tile_window_with_static_distribution
557 b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
559 b_flat_dram_block_window_tmp.get_window_origin(),
560 b_flat_distribution);
561
562 // pingpong buffer for B
564 statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
566 b_flat_dram_windows;
567
569 statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
571 b_warp_tensor_ping;
572
574 statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
576 b_warp_tensor_pong;
577
578 // HEAD
579 // Prefetch A0
580 auto a_block_tile = load_tile(a_copy_dram_window);
581 // move A window to next k
582 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
583
584 if constexpr(IsGateUpMode)
585 static_assert(NIterPerWarp % 2 == 0);
586 auto up_weight_stride = b_flat_dram_window.get_bottom_tensor_view()
587 .get_tensor_descriptor()
588 .get_lengths()[number<0>{}] /
589 2;
590
591 // prefetch B
592 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
593 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
594 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
595
596 if constexpr(!IsGateUpMode)
597 move_tile_window(b_flat_dram_windows(nIter)(kIter),
598 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
599 else
600 {
601 if constexpr(nIter % 2 == 0)
603 b_flat_dram_windows(nIter)(kIter),
604 {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
605 else
606 move_tile_window(b_flat_dram_windows(nIter)(kIter),
607 {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
608 kIter * KFlatPerBlockPerIter});
609 }
610 b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
611 });
612 });
613 // move B window to next flat K
614 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
615
616 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
617 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
618 __builtin_amdgcn_sched_barrier(0);
619
620 // Prefetch A1
621 a_block_tile = load_tile(a_copy_dram_window);
622 // move A window to next k
623 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
624
625 // initialize C
626 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
627
629
630 // preload A00,A10... from lds
631 statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
632 m_preload>
633 a_warp_tensor;
634
635 static_for<0, m_preload, 1>{}([&](auto loadIter) {
636 constexpr auto mIter = loadIter % MIterPerWarp;
637 constexpr auto kIter = loadIter / MIterPerWarp;
638 a_warp_tensor(loadIter) =
639 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
640 });
641 __builtin_amdgcn_sched_barrier(0);
642
643 // MAIN LOOP
644 index_t iCounter = (num_loop - 1) / 2;
645 while(iCounter > 0)
646 {
647 // prefetch B(2i+1)
648 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
649 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
650 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
651
652 if constexpr(!IsGateUpMode)
654 b_flat_dram_windows(nIter)(kIter),
655 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
656 else
657 {
658 if constexpr(nIter % 2 == 0)
660 b_flat_dram_windows(nIter)(kIter),
661 {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
662 else
663 move_tile_window(b_flat_dram_windows(nIter)(kIter),
664 {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
665 kIter * KFlatPerBlockPerIter});
666 }
667
668 b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
669 });
670 });
671
672 // Prefill A(2i+1)
673 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
674 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
675
676 // Prefetch A(2i+2)
677 a_block_tile = load_tile(a_copy_dram_window);
678 // move A window to next k
679 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
680
681 // GEMM 2i
682 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
683 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
684 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
685 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
686 // read C warp tensor from C block tensor
687 CWarpTensor c_warp_tensor;
688
689 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
690 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
691 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
692
693 // warp GEMM
694 WG{}(c_warp_tensor,
695 a_warp_tensor(number<AwarpIter>{}),
696 b_warp_tensor_ping(nIter)(kIter));
697
698 // write C warp tensor into C block tensor
699 c_block_tile.set_y_sliced_thread_data(
700 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
701 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
702 c_warp_tensor.get_thread_buffer());
703 });
704 // preload next A from lds
705 if constexpr((kIter * MIterPerWarp + mIter) <
707 {
708 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
709 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
710 a_warp_tensor(number<AwarpIter>{}) =
711 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
712 }
713
714 // barrier
715 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
716 {
718 }
719 });
720 });
721
722 // move B window to next flat K
723 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
724
725 static_for<0, m_preload, 1>{}([&](auto loadIter) {
726 constexpr auto mIter = loadIter % MIterPerWarp;
727 constexpr auto kIter = loadIter / MIterPerWarp;
728 a_warp_tensor(loadIter) =
729 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
730 });
732
733 // Next K
734
735 // prefetch B(2i+2)
736 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
737 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
738 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
739
740 if constexpr(!IsGateUpMode)
742 b_flat_dram_windows(nIter)(kIter),
743 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
744 else
745 {
746 if constexpr(nIter % 2 == 0)
748 b_flat_dram_windows(nIter)(kIter),
749 {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
750 else
751 move_tile_window(b_flat_dram_windows(nIter)(kIter),
752 {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
753 kIter * KFlatPerBlockPerIter});
754 }
755
756 b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
757 });
758 });
759
760 // Prefill A(2i+2)
761 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
762 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
763
764 // Prefetch A(2i+3)
765 a_block_tile = load_tile(a_copy_dram_window);
766 // move A window to next k
767 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
768
769 // GEMM 2i+1
770 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
771 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
772 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
773 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
774 // read C warp tensor from C block tensor
775 CWarpTensor c_warp_tensor;
776 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
777 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
778 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
779
780 // warp GEMM
781 WG{}(c_warp_tensor,
782 a_warp_tensor(number<AwarpIter>{}),
783 b_warp_tensor_pong(nIter)(kIter));
784
785 // write C warp tensor into C block tensor
786 c_block_tile.set_y_sliced_thread_data(
787 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
788 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
789 c_warp_tensor.get_thread_buffer());
790 });
791 // preload next A from lds
792 if constexpr((kIter * MIterPerWarp + mIter) <
794 {
795 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
796 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
797 a_warp_tensor(number<AwarpIter>{}) =
798 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
799 }
800
801 // barrier
802 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
803 {
805 }
806 });
807 });
808
809 // move B window to next flat K
810 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
811
812 static_for<0, m_preload, 1>{}([&](auto loadIter) {
813 constexpr auto mIter = loadIter % MIterPerWarp;
814 constexpr auto kIter = loadIter / MIterPerWarp;
815 a_warp_tensor(loadIter) =
816 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
817 });
819
820 iCounter--;
821 }
822
823 // TAIL
824 if constexpr(TailNum == TailNumber::Even)
825 {
826 // prefetch B(loopK)
827 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
828 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
829 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
830
831 if constexpr(!IsGateUpMode)
833 b_flat_dram_windows(nIter)(kIter),
834 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
835 else
836 {
837 if constexpr(nIter % 2 == 0)
839 b_flat_dram_windows(nIter)(kIter),
840 {nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
841 else
842 move_tile_window(b_flat_dram_windows(nIter)(kIter),
843 {nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
844 kIter * KFlatPerBlockPerIter});
845 }
846
847 b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
848 });
849 });
850
851 // Prefill A(loopK)
852 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
853 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
854
855 // GEMM loopK-1
856 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
857 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
858 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
859 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
860 // read C warp tensor from C block tensor
861 CWarpTensor c_warp_tensor;
862
863 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
864 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
865 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
866
867 // warp GEMM
868 WG{}(c_warp_tensor,
869 a_warp_tensor(number<AwarpIter>{}),
870 b_warp_tensor_ping(nIter)(kIter));
871
872 // write C warp tensor into C block tensor
873 c_block_tile.set_y_sliced_thread_data(
874 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
875 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
876 c_warp_tensor.get_thread_buffer());
877 });
878 // preload next A from lds
879 if constexpr((kIter * MIterPerWarp + mIter) <
881 {
882 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
883 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
884 a_warp_tensor(number<AwarpIter>{}) =
885 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
886 }
887
888 // barrier
889 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
890 {
892 }
893 });
894 });
895
896 static_for<0, m_preload, 1>{}([&](auto loadIter) {
897 constexpr auto mIter = loadIter % MIterPerWarp;
898 constexpr auto kIter = loadIter / MIterPerWarp;
899 a_warp_tensor(loadIter) =
900 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
901 });
902
904
905 // GEMM loopK
906 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
907 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
908 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
909 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
910 // read C warp tensor from C block tensor
911 CWarpTensor c_warp_tensor;
912
913 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
914 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
915 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
916
917 // warp GEMM
918 WG{}(c_warp_tensor,
919 a_warp_tensor(number<AwarpIter>{}),
920 b_warp_tensor_pong(nIter)(kIter));
921
922 // write C warp tensor into C block tensor
923 c_block_tile.set_y_sliced_thread_data(
924 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
925 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
926 c_warp_tensor.get_thread_buffer());
927 });
928 if constexpr((kIter * MIterPerWarp + mIter) <
930 {
931 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
932 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
933 a_warp_tensor(number<AwarpIter>{}) =
934 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
935 }
936 // barrier
937 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
938 {
940 }
941 });
942 });
944 }
945 else if constexpr(TailNum == TailNumber::Odd)
946 {
947 // GEMM loopK
948 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
949 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
950 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
951 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
952 // read C warp tensor from C block tensor
953 CWarpTensor c_warp_tensor;
954
955 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
956 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
957 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
958
959 // warp GEMM
960 WG{}(c_warp_tensor,
961 a_warp_tensor(number<AwarpIter>{}),
962 b_warp_tensor_ping(nIter)(kIter));
963
964 // write C warp tensor into C block tensor
965 c_block_tile.set_y_sliced_thread_data(
966 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
967 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
968 c_warp_tensor.get_thread_buffer());
969 });
970 // preload next A from lds
971 if constexpr((kIter * MIterPerWarp + mIter) <
973 {
974 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
975 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
976 a_warp_tensor(number<AwarpIter>{}) =
977 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
978 }
979
980 // barrier
981 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
982 {
984 }
985 });
986 });
988 }
989
990 return c_block_tile;
991 }
992
993 template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, int IsGateUpMode>
994 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
995 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
996 number<IsGateUpMode> is_gate_up_mode,
997 index_t num_loop,
998 void* p_smem_ping,
999 void* p_smem_pong) const
1000 {
1001 return operator()(
1002 a_dram_block_window_tmp,
1003 [](const ADataType & a) { return a; },
1004 b_flat_dram_block_window_tmp,
1005 is_gate_up_mode,
1006 num_loop,
1007 p_smem_ping,
1008 p_smem_pong);
1009 }
1010};
1011
1012} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE constexpr auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition tile_scatter_gather.hpp:906
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:15
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:65
static constexpr index_t Bload_num_perK
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:105
static constexpr index_t kNPerBlock
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:40
static constexpr index_t KIterPerWarp
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:73
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, number< IsGateUpMode > is_gate_up_mode, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:994
static constexpr index_t BlockSize
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:36
static constexpr index_t Aload_rep
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:104
static constexpr auto idxN
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:62
static constexpr bool DoubleSmemBuffer
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:130
static constexpr auto config
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:28
static constexpr index_t NWarp
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:69
static constexpr bool UsePersistentKernel
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:56
static constexpr index_t DsReadPreload
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:34
remove_cvref_t< decltype(PipelinePolicy::template GetBlockFlatmm< Problem >())> BlockFlatmm
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:25
static constexpr index_t flatNPerWarp
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:44
static constexpr index_t KPerBlockPerIter
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:79
static constexpr index_t m_preload
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:84
static constexpr index_t MWarp
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:68
static constexpr index_t dswrite_rep
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:102
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:66
static constexpr index_t mfma_per_wg
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:94
static CK_TILE_HOST_DEVICE constexpr auto GetADramTileDistribution()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:437
static constexpr index_t WaveSize
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:37
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:19
static constexpr index_t AK1
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:82
static constexpr index_t GetVectorSizeC()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:48
static constexpr auto idxK
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:63
static constexpr index_t mfma_perM_perK
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:114
static constexpr index_t KFlatPerBlockPerIter
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:75
static constexpr index_t kKPerBlock
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:41
static constexpr bool kPadK
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:52
static CK_TILE_HOST_DEVICE constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:140
static constexpr bool kPadM
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:50
remove_cvref_t< typename Problem::ADataType > ADataType
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:16
static constexpr index_t HalfMIter
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:111
static constexpr int MXFP4PackedSize
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:81
static constexpr index_t GetVectorSizeA()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:46
static constexpr index_t dsread_num_perK
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:100
static constexpr index_t kLdsAlignmentInBytes
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:54
static constexpr index_t NIterPerWarp
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:72
static constexpr index_t DsWritePreIssue
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:33
static constexpr index_t flatKPerWarp
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:43
static constexpr index_t Aload_num_perK
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:103
static constexpr index_t dswrite_mIter
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:115
static CK_TILE_HOST_DEVICE constexpr auto LastHotLoopScheduler()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:417
static constexpr index_t dswrite_num_perK
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:101
static constexpr index_t BK1
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:83
static constexpr index_t kMPerBlock
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:39
static constexpr index_t NumWaveGroups
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:55
static constexpr auto I1
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:59
static constexpr bool HasHotLoop
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:88
static constexpr auto idxM
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:61
static constexpr auto TailNum
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:89
static constexpr index_t dsread_per_wg
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:96
static constexpr index_t MIterPerWarp
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:71
remove_cvref_t< typename Problem::CLayout > CLayout
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:23
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:132
static constexpr bool kPadN
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:51
static CK_TILE_HOST const std::string GetName()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:118
static constexpr auto I2
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:60
static constexpr index_t MPerBlockPerIter
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:78
static constexpr index_t NFlatPerBlockPerIter
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:76
remove_cvref_t< typename Problem::BDataType > BDataType
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:17
static constexpr index_t dswrite_kIter
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:116
static CK_TILE_HOST_DEVICE constexpr auto HotLoopScheduler()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:222
remove_cvref_t< typename Problem::BLayout > BLayout
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:22
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:64
static constexpr index_t GetVectorSizeB()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:47
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, number< IsGateUpMode >, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:446
static constexpr index_t Bload_rep
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:112
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:134
static CK_TILE_HOST_DEVICE constexpr auto Last2ndHotLoopScheduler()
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:365
static constexpr auto I0
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:58
remove_cvref_t< typename Problem::CDataType > CDataType
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:18
remove_cvref_t< typename Problem::ALayout > ALayout
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:21
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition moe_flatmm_pipeline_agmem_bgmem_creg.hpp:31
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43