device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp Source File

device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp Source File

Composable Kernel: device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp Source File
device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22
23template <typename GridwiseGemmWelford,
24 typename ABDataType,
25 typename DsPointer,
26 typename EMeanVarDataType,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename AGridDesc_AK0_M_AK1,
31 typename BGridDesc_BK0_N_BK1,
32 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34 typename MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
35 typename CountGridDescriptor_MBlock_MPerBlock_NBlock,
36 typename Block2ETileMap,
37 bool HasMainKBlockLoop>
38__global__ void
39#if CK_USE_LAUNCH_BOUNDS
41#endif
43 const ABDataType* __restrict__ p_a_grid,
44 const ABDataType* __restrict__ p_b_grid,
45 DsPointer p_ds_grid,
46 EMeanVarDataType* __restrict__ p_e_grid,
47 EMeanVarDataType* __restrict__ p_welford_mean_grid,
48 EMeanVarDataType* __restrict__ p_welford_var_grid,
49 int32_t* __restrict__ p_welford_count_grid,
50 const AElementwiseOperation a_element_op,
51 const BElementwiseOperation b_element_op,
52 const CDEElementwiseOperation cde_element_op,
53 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
54 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
55 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
56 ds_grid_desc_mblock_mperblock_nblock_nperblock,
57 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
58 e_grid_desc_mblock_mperblock_nblock_nperblock,
59 const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock
60 mean_var_grid_desc_mblock_mperblock_nblock,
61 const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock,
62 const Block2ETileMap block_2_etile_map,
63 index_t NRaw)
64{
65#if defined(__gfx9__) || defined(__gfx12__)
66 if constexpr(GridwiseGemmWelford::template IsValidCompilationParameter<>())
67 {
68 __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
69
70 GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
71 p_a_grid,
72 p_b_grid,
73 p_ds_grid,
74 p_e_grid,
75 p_welford_mean_grid,
76 p_welford_var_grid,
77 p_welford_count_grid,
78 p_shared,
79 a_element_op,
80 b_element_op,
81 cde_element_op,
82 a_grid_desc_ak0_m_ak1,
83 b_grid_desc_bk0_n_bk1,
84 ds_grid_desc_mblock_mperblock_nblock_nperblock,
85 e_grid_desc_mblock_mperblock_nblock_nperblock,
86 mean_var_grid_desc_mblock_mperblock_nblock,
87 count_grid_desc_mblock_mperblock_nblock,
88 block_2_etile_map,
89 NRaw);
90 }
91#else
92 ignore = p_a_grid;
93 ignore = p_b_grid;
94 ignore = p_ds_grid;
95 ignore = p_e_grid;
96 ignore = p_welford_mean_grid;
97 ignore = p_welford_var_grid;
98 ignore = p_welford_count_grid;
99 ignore = a_element_op;
100 ignore = b_element_op;
101 ignore = cde_element_op;
102 ignore = a_grid_desc_ak0_m_ak1;
103 ignore = b_grid_desc_bk0_n_bk1;
104 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
105 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
106 ignore = mean_var_grid_desc_mblock_mperblock_nblock;
107 ignore = count_grid_desc_mblock_mperblock_nblock;
108 ignore = block_2_etile_map;
109 ignore = NRaw;
110#endif
111}
112
113template <typename GridwiseWelfordLayernorm,
114 typename EMeanVarDataType,
115 typename HDataType,
116 typename GammaDataType,
117 typename BetaDataType,
118 typename ComputeDataType,
119 typename EHGridDesc_M_N,
120 typename LayernormMeanVarGridDesc_M_NBlock,
121 typename LayernormCountGridDesc_M_NBlock,
122 typename GammaBetaGridDesc_N,
123 typename HElementwiseOperation>
124__global__ void
125#if CK_USE_LAUNCH_BOUNDS
127#endif
129 const EMeanVarDataType* __restrict__ p_e_grid,
130 const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
131 const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
132 const int32_t* __restrict__ p_in_welford_count_grid,
133 const GammaDataType* __restrict__ p_gamma_grid,
134 const BetaDataType* __restrict__ p_beta_grid,
135 HDataType* __restrict__ p_h_grid,
136 const EHGridDesc_M_N e_grid_desc_m_n,
137 const EHGridDesc_M_N h_grid_desc_m_n,
138 const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
139 const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
140 const GammaBetaGridDesc_N gamma_grid_desc_n,
141 const GammaBetaGridDesc_N beta_grid_desc_n,
142 index_t numMeanVarCountBlockTileIteration_N,
143 index_t NBlockClusterLength,
144 ComputeDataType epsilon,
145 HElementwiseOperation h_element_op)
146{
147 GridwiseWelfordLayernorm::Run(p_e_grid,
148 p_in_welford_mean_grid,
149 p_in_welford_var_grid,
150 p_in_welford_count_grid,
151 p_gamma_grid,
152 p_beta_grid,
153 p_h_grid,
154 e_grid_desc_m_n,
155 h_grid_desc_m_n,
156 mean_var_grid_desc_m_nblock,
157 count_grid_desc_m_nblock,
158 gamma_grid_desc_n,
159 beta_grid_desc_n,
160 numMeanVarCountBlockTileIteration_N,
161 NBlockClusterLength,
162 epsilon,
163 h_element_op);
164}
165
166} // namespace ck
167
168namespace ck {
169namespace tensor_operation {
170namespace device {
171
172// GEMM:
173// input : A[M, K]
174// input : B[N, K]
175// input : D0[M, N], D1[M, N], ...
176// output : E[M, N]
177// output : H[M, N]
178// C = a_op(A) * b_op(B)
179// E = cde_op(C, D0, D1, ...)
180// H = layernorm(E)
181// Assume:
182// D0, D1, ... and E have the same layout
183// Calculate mean & variance along N dimension in layernorm(E)
184template <typename ALayout,
185 typename BLayout,
186 typename DsLayout,
187 typename HLayout,
188 typename ADataType,
189 typename BDataType,
190 typename AccDataType,
191 typename CShuffleDataType,
192 typename DsDataType,
193 typename EMeanVarDataType,
194 typename GammaDataType,
195 typename BetaDataType,
196 typename HDataType,
197 typename AElementwiseOperation,
198 typename BElementwiseOperation,
199 typename CDEElementwiseOperation,
200 typename HElementwiseOperation,
201 GemmSpecialization GemmSpec,
202 index_t NumGemmKPrefetchStage,
203 index_t BlockSize,
204 index_t MPerBlock,
205 index_t NPerBlock,
206 index_t KPerBlock,
207 index_t AK1,
208 index_t BK1,
209 index_t MPerXDL,
210 index_t NPerXDL,
211 index_t MXdlPerWave,
212 index_t NXdlPerWave,
213 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
214 typename ABlockTransferThreadClusterArrangeOrder,
215 typename ABlockTransferSrcAccessOrder,
216 index_t ABlockTransferSrcVectorDim,
217 index_t ABlockTransferSrcScalarPerVector,
218 index_t ABlockTransferDstScalarPerVector_AK1,
219 bool ABlockLdsExtraM,
220 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
221 typename BBlockTransferThreadClusterArrangeOrder,
222 typename BBlockTransferSrcAccessOrder,
223 index_t BBlockTransferSrcVectorDim,
224 index_t BBlockTransferSrcScalarPerVector,
225 index_t BBlockTransferDstScalarPerVector_BK1,
226 bool BBlockLdsExtraN,
227 index_t CShuffleMXdlPerWavePerShuffle,
228 index_t CShuffleNXdlPerWavePerShuffle,
229 typename PostShuffleThreadClusterSize_M_N,
230 index_t PostShuffleScalarPerVector,
231 typename LayernormThreadClusterSize_M_N,
232 index_t LayernormThreadSliceSize_M,
236 : public DeviceGemmMultipleDLayernorm<ALayout,
237 BLayout,
238 DsLayout,
239 HLayout,
240 ADataType,
241 BDataType,
242 DsDataType,
243 GammaDataType,
244 BetaDataType,
245 HDataType,
246 AElementwiseOperation,
247 BElementwiseOperation,
248 CDEElementwiseOperation,
249 HElementwiseOperation>
250{
251 // EDataType, MeanDataType and VarDataType must be the same.
252 // eg. M, N, K = [1, 1, 1],
253 // in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783
254 // if (x - mean) != 0, (x - mean) * divisor * gamma might be too large
255 // However, (x - mean) * divisor * gamma should be 0 in this case
256
258 using ELayout = HLayout;
259
261 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
262 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
263
264 static constexpr index_t NumDTensor = DsDataType::Size();
265 static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector;
266 static constexpr index_t LayernormGammaSrcVectorSize = PostShuffleScalarPerVector;
267 static constexpr index_t LayernormBetaSrcVectorSize = PostShuffleScalarPerVector;
268 static constexpr index_t LayernormESrcVectorSize = PostShuffleScalarPerVector;
269 static constexpr index_t LayernormThreadSliceSize_N = PostShuffleScalarPerVector;
271 Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M,
272 LayernormThreadClusterSize_M_N::At(1) * LayernormThreadSliceSize_N>;
273
274 static constexpr auto I0 = Number<0>{};
275 static constexpr auto I1 = Number<1>{};
276 static constexpr auto I2 = Number<2>{};
277
278 static constexpr auto matrix_padder =
279 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
280
281 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
282 {
283 const auto a_grid_desc_mraw_kraw = [&]() {
285 {
286 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
287 make_tuple(StrideA, I1));
288 }
290 {
291 return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
292 make_tuple(I1, StrideA));
293 }
294 }();
295
296 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
297 }
298
299 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
300 {
301 const auto b_grid_desc_nraw_kraw = [&]() {
303 {
304 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
305 make_tuple(I1, StrideB));
306 }
308 {
309 return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
310 make_tuple(StrideB, I1));
311 }
312 }();
313
314 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
315 }
316
317 template <typename DoPads, index_t MPerTile, index_t NPerTile>
319 {
320 // Only support row major for E and H
321 const auto grid_desc_m_n =
323 return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
324 }
325
326 static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
327 const std::array<index_t, NumDTensor>& NRaws,
328 const std::array<index_t, NumDTensor>& DsStride)
329 {
330 return generate_tuple(
331 [&](auto i) {
332 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
334
335 return DeviceOp::
336 MakeEHGridDescriptor_M_N<Sequence<true, true>, MPerBlock, NPerBlock>(
337 MRaws[i], NRaws[i], DsStride[i]);
338 },
340 }
341
342 template <typename DoPads, index_t MPerTile, index_t NPerTile>
344 {
345 const auto grid_desc_m_n =
347 return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
348 }
349
350 template <typename DoPads, index_t MPerTile, index_t NPerTile>
352 {
353 // We will broadcast [N] to [M, N] in this descriptor
354 // Hence, 1st stride is 0
355 const auto grid_desc_m_n =
357 return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
358 }
359
360 template <index_t XPerTile>
362 {
363 const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X));
364 return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
365 }
366
367 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
368 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
370 // We have to separate the mean/var descriptors for gemm and layernorm because of their different grid
371 // layouts (different padding)
373 decltype(MakeMeanVarDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
374
376 decltype(MakeCountDescriptor_M_N<Sequence<true, false>, MPerBlock, NPerBlock>(1, 1));
377
382
387
390
391 template <index_t NXdlPerWave_>
393 ADataType, // TODO: distinguish A/B datatype
394 AccDataType,
395 CShuffleDataType,
396 DsDataType,
397 EMeanVarDataType,
398 AElementwiseOperation,
399 BElementwiseOperation,
400 CDEElementwiseOperation,
408 NumGemmKPrefetchStage,
409 BlockSize,
410 MPerBlock,
411 NPerBlock,
412 KPerBlock,
413 AK1,
414 BK1,
415 MPerXDL,
416 NPerXDL,
417 MXdlPerWave,
418 NXdlPerWave_,
419 ABlockTransferThreadClusterLengths_AK0_M_AK1,
420 ABlockTransferThreadClusterArrangeOrder,
421 ABlockTransferSrcAccessOrder,
422 ABlockTransferSrcVectorDim,
423 ABlockTransferSrcScalarPerVector,
424 ABlockTransferDstScalarPerVector_AK1,
425 false,
426 ABlockLdsExtraM,
427 BBlockTransferThreadClusterLengths_BK0_N_BK1,
428 BBlockTransferThreadClusterArrangeOrder,
429 BBlockTransferSrcAccessOrder,
430 BBlockTransferSrcVectorDim,
431 BBlockTransferSrcScalarPerVector,
432 BBlockTransferDstScalarPerVector_BK1,
433 false,
434 BBlockLdsExtraN,
435 CShuffleMXdlPerWavePerShuffle,
436 CShuffleNXdlPerWavePerShuffle,
437 PostShuffleThreadClusterSize_M_N,
438 PostShuffleScalarPerVector,
439 LoopSched,
440 PipelineVer>;
443
445
448 HDataType,
449 GammaDataType,
450 BetaDataType,
451 AccDataType,
456 HElementwiseOperation,
457 BlockSize,
458 LayernormThreadClusterSize_M_N::At(I0),
459 LayernormThreadClusterSize_M_N::At(I1),
460 LayernormThreadSliceSize_M,
466
467 // Argument
468 struct Argument : public BaseArgument
469 {
470 Argument(const void* p_a_grid,
471 const void* p_b_grid,
472 std::array<const void*, NumDTensor> p_ds_grid,
473 const void* p_gamma_grid,
474 const void* p_beta_grid,
475 void* p_h_grid,
476 index_t MRaw,
477 index_t NRaw,
478 index_t KRaw,
479 index_t StrideA,
480 index_t StrideB,
481 std::array<index_t, NumDTensor> StrideDs,
482 index_t StrideH,
483 double epsilon,
484 AElementwiseOperation a_element_op,
485 BElementwiseOperation b_element_op,
486 CDEElementwiseOperation cde_element_op,
487 HElementwiseOperation h_element_op)
488 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
489 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
490 p_ds_grid_{},
491 p_workspace_e_grid_{nullptr},
492 p_workspace_mean_{nullptr},
493 p_workspace_var_{nullptr},
494 p_workspace_count_{nullptr},
495 p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
496 p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
497 p_h_grid_{static_cast<HDataType*>(p_h_grid)},
502 DeviceOp::MakeEHGridDescriptor_M_N<Sequence<true, true>, MPerBlock, NPerBlock>(
503 MRaw, NRaw, StrideH)},
508 MRaw, NRaw, StrideH)},
521 MRaw, NRaw, StrideH)},
523 GridwiseGemmWelford64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
525 GridwiseGemmWelford64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
527 GridwiseGemmWelford64::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)},
528 a_element_op_{a_element_op},
529 b_element_op_{b_element_op},
530 cde_element_op_{cde_element_op},
531 h_element_op_{h_element_op},
532 MRaw_{MRaw},
533 NRaw_{NRaw},
534 KRaw_{KRaw},
535 gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
536 epsilon_{static_cast<AccDataType>(epsilon)}
537 {
538 // We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1.
541 MRaw, gemm_nblock_);
542
545 MRaw, gemm_nblock_);
546
551 MRaw, gemm_nblock_);
552
558
559 // populate pointer, desc for Ds
560 static_for<0, NumDTensor, 1>{}([&](auto i) {
561 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
562
563 // D pointer
564 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
565
566 // D desc
569 MRaw, NRaw, StrideDs[i]);
570 });
571
572 // populate desc for Ds/E/mean/var/count
573 if(get_warp_size() == 64)
574 {
575 if constexpr(NXdlPerWave64 > 0)
576 {
582 {
586
590
594
598 }
599 }
600 }
601 else
602 {
603 if constexpr(NXdlPerWave32 > 0)
604 {
610 {
614
618
622
626 }
627 }
628 }
629 }
630
631 void Print() const
632 {
633 std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
634 std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
636 [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
637 std::cout << "E[M, N]: " << gemm_e_grid_desc_m_n_ << std::endl;
638 std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl;
639 }
640
641 // private:
642 // pointers
643 const ADataType* p_a_grid_;
644 const BDataType* p_b_grid_;
650 const GammaDataType* p_gamma_grid_;
651 const BetaDataType* p_beta_grid_;
652 HDataType* p_h_grid_;
653
654 // tensor descriptors for problem definition
667
668 // tensor descriptors for block/thread-wise copy
679
680 // block-to-e-tile map
682
683 // element-wise op
684 AElementwiseOperation a_element_op_;
685 BElementwiseOperation b_element_op_;
686 CDEElementwiseOperation cde_element_op_;
687 HElementwiseOperation h_element_op_;
688
693 AccDataType epsilon_;
694 };
695
696 // Invoker
697 struct Invoker : public BaseInvoker
698 {
700 template <typename GridwiseGemmWelford>
701 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
702 {
703 float avg_time = 0;
704
705 if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
710 {
711 throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
712 }
713 if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr ||
714 arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr)
715 throw std::runtime_error("wrong! WorkSpace pointer has not been set");
716
717 index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_);
718
719 const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
720 const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
721 const auto K =
722 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
723
724 auto launch_kernel = [&](auto has_main_k_block_loop) {
725 constexpr bool has_main_loop = has_main_k_block_loop.value;
726
727 const auto kernel_gemm_welford =
729 GridwiseGemmWelford,
730 ADataType, // TODO: distiguish A/B datatype
731 typename GridwiseGemmWelford::DsGridPointer,
732 EMeanVarDataType,
733 AElementwiseOperation,
734 BElementwiseOperation,
735 CDEElementwiseOperation,
736 typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1,
737 typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1,
738 typename GridwiseGemmWelford::
739 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
740 typename GridwiseGemmWelford::
741 EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
742 typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock,
743 typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock,
744 typename GridwiseGemmWelford::DefaultBlock2ETileMap,
745 has_main_loop>;
746
747 const auto kernel_welford_layernorm =
749 EMeanVarDataType,
750 HDataType,
751 GammaDataType,
752 BetaDataType,
753 AccDataType,
758 HElementwiseOperation>;
759
760 avg_time +=
761 launch_and_time_kernel(stream_config,
762 kernel_gemm_welford,
763 dim3(grid_size),
764 dim3(BlockSize),
765 0,
766 arg.p_a_grid_,
767 arg.p_b_grid_,
768 arg.p_ds_grid_,
769 static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
770 static_cast<EMeanVarDataType*>(arg.p_workspace_mean_),
771 static_cast<EMeanVarDataType*>(arg.p_workspace_var_),
772 static_cast<int32_t*>(arg.p_workspace_count_),
773 arg.a_element_op_,
774 arg.b_element_op_,
775 arg.cde_element_op_,
783 arg.NRaw_);
784
785 index_t MBlockClusterLength =
787 index_t NBlockClusterLength =
789 grid_size = MBlockClusterLength * NBlockClusterLength;
790
791 index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
792 arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
793
794 avg_time += launch_and_time_kernel(
795 stream_config,
796 kernel_welford_layernorm,
797 dim3(grid_size),
798 dim3(BlockSize),
799 0,
800 static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
801 static_cast<const EMeanVarDataType*>(arg.p_workspace_mean_),
802 static_cast<const EMeanVarDataType*>(arg.p_workspace_var_),
803 static_cast<const int32_t*>(arg.p_workspace_count_),
804 arg.p_gamma_grid_,
805 arg.p_beta_grid_,
806 arg.p_h_grid_,
813 numMeanVarCountBlockTileIteration_N,
814 NBlockClusterLength,
815 arg.epsilon_,
816 arg.h_element_op_);
817
818 return avg_time;
819 };
820
821 if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K))
822 {
823 return launch_kernel(integral_constant<bool, true>{});
824 }
825 else
826 {
827 return launch_kernel(integral_constant<bool, false>{});
828 }
829 }
833 // polymorphic
834 float Run(const BaseArgument* p_arg,
835 const StreamConfig& stream_config = StreamConfig{}) override
836 {
837 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
838 }
839 };
840
841 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
842 {
843 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
844
845 size_t workspace_size = 0;
846
847 int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
848
849 // workspace for welford intermediate mean
850 workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
851
852 // workspace for welford intermediate variance
853 workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64;
854
855 // workspace for welford intermediate count
856 workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64;
857
859 workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType);
860
861 return (workspace_size);
862 };
863
865 void* p_workspace,
866 const StreamConfig& = StreamConfig{}) const override
867 {
868 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
869
870 pArg_->p_workspace_ = p_workspace;
871
872 int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
873
874 // setup buffer used for intermediate welford mean
875 pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
876
877 index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
878 mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
879
880 // setup buffer used for intermediate welford variance
881 pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
882
883 index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
884 variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
885
886 // setup buffer used for intermediate welford count
887 pArg_->p_workspace_count_ =
888 reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
889
890 index_t count_space_sz = gemm_welford_size * sizeof(int32_t);
891 count_space_sz = math::integer_least_multiple(count_space_sz, 64);
892
894 pArg_->p_workspace_e_grid_ =
895 reinterpret_cast<char*>(pArg_->p_workspace_count_) + count_space_sz;
896 else
897 pArg_->p_workspace_e_grid_ = static_cast<void*>(pArg_->p_h_grid_);
898 };
899
900 static bool IsSupportedArgument(const Argument& arg)
901 {
903 {
904 return false;
905 }
906 // check vector load/store
907 {
910
911 // check vector load of A
912 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
913 {
914 if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
915 {
916 return false;
917 }
918 }
919 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
920 {
921 // FIXME: not rigorous
922 if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
923 {
924 return false;
925 }
926 }
927 else
928 {
929 return false;
930 }
931
932 // check vector load of B
933 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
934 {
935 if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
936 {
937 return false;
938 }
939 }
940 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
941 {
942 // FIXME: not rigorous
943 if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
944 {
945 return false;
946 }
947 }
948 else
949 {
950 return false;
951 }
952
953 // check vector load of Ds
954 // only support RowMajor for now
955 bool all_valid = true;
956
957 static_for<0, NumDTensor, 1>{}([&](auto i) {
958 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
959
960 if constexpr(!is_same_v<DLayout, Row>)
961 {
962 all_valid = false;
963 }
964 });
965
966 if(!all_valid)
967 {
968 return false;
969 }
970
971 // check vector store of E
972 // E and H only support RowMajor for now
974 {
975 if(arg.NRaw_ % PostShuffleScalarPerVector != 0 ||
979 {
980 return false;
981 }
982 }
983 else
984 {
985 return false;
986 }
987 }
988 if(get_warp_size() == 64)
989 {
990 if constexpr(NXdlPerWave64 > 0)
991 {
997 }
998 else
999 {
1000 return false;
1001 }
1002 }
1003 else
1004 {
1005 if constexpr(NXdlPerWave32 > 0)
1006 {
1008 arg.b_grid_desc_n_k_,
1011 arg.block_2_etile_map_);
1012 }
1013 else
1014 {
1015 return false;
1016 }
1017 }
1018 }
1019
1020 // polymorphic
1021 bool IsSupportedArgument(const BaseArgument* p_arg) override
1022 {
1023 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1024 }
1025
1026 static auto MakeArgument(const void* p_a,
1027 const void* p_b,
1028 std::array<const void*, NumDTensor> p_ds,
1029 const void* p_gamma,
1030 const void* p_beta,
1031 void* p_h,
1032 index_t MRaw,
1033 index_t NRaw,
1034 index_t KRaw,
1035 index_t StrideA,
1036 index_t StrideB,
1037 std::array<index_t, NumDTensor> StrideDs,
1038 index_t StrideH,
1039 double epsilon,
1040 AElementwiseOperation a_element_op,
1041 BElementwiseOperation b_element_op,
1042 CDEElementwiseOperation cde_element_op,
1043 HElementwiseOperation h_element_op)
1044 {
1045 return Argument{p_a,
1046 p_b,
1047 p_ds,
1048 p_gamma,
1049 p_beta,
1050 p_h,
1051 MRaw,
1052 NRaw,
1053 KRaw,
1054 StrideA,
1055 StrideB,
1056 StrideDs,
1057 StrideH,
1058 epsilon,
1059 a_element_op,
1060 b_element_op,
1061 cde_element_op,
1062 h_element_op};
1063 }
1064
1065 static auto MakeInvoker() { return Invoker{}; }
1066
1067 // polymorphic
1068 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
1069 const void* p_b,
1070 std::array<const void*, NumDTensor> p_ds,
1071 const void* p_gamma,
1072 const void* p_beta,
1073 void* p_h,
1074 index_t MRaw,
1075 index_t NRaw,
1076 index_t KRaw,
1077 index_t StrideA,
1078 index_t StrideB,
1079 std::array<index_t, NumDTensor> StrideDs,
1080 index_t StrideH,
1081 double epsilon,
1082 AElementwiseOperation a_element_op,
1083 BElementwiseOperation b_element_op,
1084 CDEElementwiseOperation cde_element_op,
1085 HElementwiseOperation h_element_op) override
1086 {
1087 return std::make_unique<Argument>(p_a,
1088 p_b,
1089 p_ds,
1090 p_gamma,
1091 p_beta,
1092 p_h,
1093 MRaw,
1094 NRaw,
1095 KRaw,
1096 StrideA,
1097 StrideB,
1098 StrideDs,
1099 StrideH,
1100 epsilon,
1101 a_element_op,
1102 b_element_op,
1103 cde_element_op,
1104 h_element_op);
1105 }
1106
1107 // polymorphic
1108 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1109 {
1110 return std::make_unique<Invoker>(Invoker{});
1111 }
1112
1113 // polymorphic
1114 std::string GetTypeString() const override
1115 {
1116 auto str = std::stringstream();
1117
1118 std::map<LoopScheduler, std::string> LoopSchedToString{
1119 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
1120
1121 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
1122 {PipelineVersion::v2, "v2"}};
1123
1124 // clang-format off
1125 str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle"
1126 << "<"
1127 << BlockSize << ", "
1128 << MPerBlock << ", "
1129 << NPerBlock << ", "
1130 << KPerBlock << ", "
1131 << AK1 << ", "
1132 << BK1 << ", "
1133 << getGemmSpecializationString(GemmSpec) << ", "
1134 << PostShuffleThreadClusterSize_M_N::At(I0) << ", "
1135 << PostShuffleThreadClusterSize_M_N::At(I1) << ", "
1136 << LayernormThreadClusterSize_M_N::At(I0) << ", "
1137 << LayernormThreadClusterSize_M_N::At(I1) << ", "
1138 << LayernormThreadSliceSize_M
1139 << ">"
1140 << " LoopScheduler: "
1141 << LoopSchedToString[LoopSched] << ", "
1142 << "PipelineVersion: "
1143 << PipelineVersionToString[PipelineVer];
1144 // clang-format on
1145
1146 return str.str();
1147 }
1148}; // struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
1149
1150} // namespace device
1151} // namespace tensor_operation
1152} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition utility/math.hpp:13
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EMeanVarDataType *__restrict__ p_e_grid, EMeanVarDataType *__restrict__ p_welford_mean_grid, EMeanVarDataType *__restrict__ p_welford_var_grid, int32_t *__restrict__ p_welford_count_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock mean_var_grid_desc_mblock_mperblock_nblock, const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock, const Block2ETileMap block_2_etile_map, index_t NRaw)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:42
__global__ void kernel_welford_layernorm2d_second_half(const EMeanVarDataType *__restrict__ p_e_grid, const EMeanVarDataType *__restrict__ p_in_welford_mean_grid, const EMeanVarDataType *__restrict__ p_in_welford_var_grid, const int32_t *__restrict__ p_in_welford_count_grid, const GammaDataType *__restrict__ p_gamma_grid, const BetaDataType *__restrict__ p_beta_grid, HDataType *__restrict__ p_h_grid, const EHGridDesc_M_N e_grid_desc_m_n, const EHGridDesc_M_N h_grid_desc_m_n, const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n, index_t numMeanVarCountBlockTileIteration_N, index_t NBlockClusterLength, ComputeDataType epsilon, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:87
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp:84
Definition gridwise_welford_second_half_layernorm2d.hpp:42
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:469
GridwiseGemmWelford64::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:674
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:684
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:686
index_t gemm_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:692
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:643
void * p_workspace_var_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:648
EHGridDesc_M_N gemm_e_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:658
EHGridDesc_M_N h_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:666
GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:660
GridwiseGemmWelford64::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:669
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, const void *p_gamma_grid, const void *p_beta_grid, void *p_h_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:470
index_t MRaw_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:689
const BetaDataType * p_beta_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:651
void Print() const
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:631
HElementwiseOperation h_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:687
EHGridDesc_M_N layernorm_e_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:659
index_t NRaw_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:690
AccDataType epsilon_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:693
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:657
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:685
GridwiseGemmWelford64::CountGridDescriptor_MBlock_MPerBlock_NBlock gemm_count_grid_desc_mblock_mperblock_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:678
const GammaDataType * p_gamma_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:650
index_t KRaw_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:691
LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:663
GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:661
BGridDesc_N_K b_grid_desc_n_k_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:656
void * p_workspace_count_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:649
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:644
GammaBetaGridDesc_N beta_grid_desc_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:665
GammaBetaGridDesc_N gamma_grid_desc_n_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:664
GridwiseGemmWelford64::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock gemm_mean_var_grid_desc_mblock_mperblock_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:676
Block2ETileMap block_2_etile_map_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:681
GridwiseGemmWelford64::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:645
void * p_workspace_mean_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:647
GridwiseGemmWelford64::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:670
HDataType * p_h_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:652
AGridDesc_M_K a_grid_desc_m_k_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:655
void * p_workspace_e_grid_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:646
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:662
GridwiseGemmWelford64::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:672
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:698
GridwiseGemmWelford32 GridwiseGemm32
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:830
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:701
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:699
GridwiseGemmWelford64 GridwiseGemm64
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:831
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:834
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:250
GridwiseGemmWelfordBase< NXdlPerWave32 > GridwiseGemmWelford32
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:442
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:261
static auto MakeDescriptor_X(index_t X)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:361
static constexpr index_t LayernormGammaSrcVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:266
GridwiseWelfordSecondHalfLayernorm2d< EMeanVarDataType, HDataType, GammaDataType, BetaDataType, AccDataType, EHGridDesc_M_N, LayernormMeanVarGridDesc_M_NBlock, LayernormCountGridDesc_M_NBlock, GammaBetaGridDesc_N, HElementwiseOperation, BlockSize, LayernormThreadClusterSize_M_N::At(I0), LayernormThreadClusterSize_M_N::At(I1), LayernormThreadSliceSize_M, LayernormThreadSliceSize_N, LayernormESrcVectorSize, LayernormHDstVectorSize, LayernormGammaSrcVectorSize, LayernormBetaSrcVectorSize > GridwiseWelfordLayernorm
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:446
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:262
DeviceGemmMultipleDLayernorm_Xdl_CShuffle DeviceOp
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:257
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1021
std::string GetTypeString() const override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1114
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:368
typename GridwiseGemmWelford64::DefaultBlock2ETileMap Block2ETileMap
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:444
Sequence< LayernormThreadClusterSize_M_N::At(0) *LayernormThreadSliceSize_M, LayernormThreadClusterSize_M_N::At(1) *LayernormThreadSliceSize_N > LayernormBlockTileSize_M_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:270
decltype(MakeMeanVarDescriptor_M_N< Sequence< true, true >, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)) LayernormMeanVarGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:378
static constexpr auto I0
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:274
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EHGridDesc_M_N, GemmMeanVarGridDesc_M_NBlock, GemmCountGridDesc_M_NBlock, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, PostShuffleThreadClusterSize_M_N, PostShuffleScalarPerVector, LoopSched, PipelineVer > GridwiseGemmWelfordBase
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:392
static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:343
static auto MakeInvoker()
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1065
decltype(MakeCountDescriptor_M_N< Sequence< true, false >, MPerBlock, NPerBlock >(1, 1)) GemmCountGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:375
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:278
static constexpr index_t LayernormESrcVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:268
static constexpr auto I1
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:275
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:264
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:369
static constexpr index_t LayernormThreadSliceSize_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:269
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:281
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:299
decltype(MakeEHGridDescriptor_M_N< Sequence< true, true >, 1, 1 >(1, 1, 1)) EHGridDesc_M_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:389
static constexpr auto I2
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:276
static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:318
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, const void *p_gamma, const void *p_beta, void *p_h, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1026
decltype(MakeDescriptor_X< LayernormBlockTileSize_M_N::At(1)>(1)) GammaBetaGridDesc_N
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:388
static constexpr index_t LayernormBetaSrcVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:267
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:900
HLayout ELayout
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:258
static auto MakeCountDescriptor_M_N(index_t M, index_t N)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:351
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1108
GridwiseGemmWelfordBase< math::max(NXdlPerWave64, 1)> GridwiseGemmWelford64
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:441
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:864
decltype(MakeCountDescriptor_M_N< Sequence< true, true >, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)) LayernormCountGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:383
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:841
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, const void *p_gamma, const void *p_beta, void *p_h, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op) override
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:1068
static constexpr index_t LayernormHDstVectorSize
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:265
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:367
decltype(MakeMeanVarDescriptor_M_N< Sequence< true, false >, MPerBlock, NPerBlock >(1, 1)) GemmMeanVarGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:372
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp:326
Definition device_gemm_multiple_d_layernorm.hpp:40
Definition matrix_padder.hpp:180