gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
16
17namespace ck {
18
19// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
20// pipelines in the same kernel function. Blockers:
21// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
22// operate on two LDS chunks.
23// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may
24// not use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
// Kernel entry point that runs GridwiseGemm with a single shared-memory buffer.
//
// Declares one __shared__ byte array of GridwiseGemm::GetSharedMemoryNumberOfByte() bytes and
// passes it, together with the grid pointers and element-wise operators stored in `karg`, to
// GridwiseGemm::Run. The call is only emitted when the device target is gfx9 / gfx11 / gfx12
// and GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() is true; for
// any other target `karg` is just consumed through `ignore`.
//
// When CK_USE_LAUNCH_BOUNDS is set, the kernel is annotated with
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy).
//
// NOTE(review): `TailNum` is used in the Run call below, but the line declaring that template
// parameter is not present in this listing — confirm against the upstream header.
25template <typename GridwiseGemm,
26 bool HasMainKBlockLoop,
27 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
28 index_t MinimumOccupancy = 1,
30__global__ void
31#if CK_USE_LAUNCH_BOUNDS
32__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
33#endif
34 // __attribute__((amdgpu_waves_per_eu(1, 1)))
35 kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
36{
37#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
38 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
39 {
// Single shared-memory allocation handed to Run.
40 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
41
42 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
43 karg.p_as_grid,
44 karg.p_bs_grid,
45 karg.p_ds_grid,
46 karg.p_c_grid,
47 p_shared,
48 karg,
49 karg.a_element_op,
50 karg.b_element_op,
51 karg.c_element_op);
52 }
53#else
// Unsupported target: reference karg so it is not reported as unused.
54 ignore = karg;
55#endif // end of if (defined(__gfx9__))
56}
57
// Kernel entry point that runs GridwiseGemm with two separately declared shared-memory buffers.
//
// Same structure as kernel_gemm_xdl_cshuffle_v3, except that two __shared__ byte arrays (each
// GridwiseGemm::GetSharedMemoryNumberOfByte() bytes) are declared and both are forwarded to
// GridwiseGemm::Run_2Lds. The call is only emitted for gfx9 / gfx11 / gfx12 targets and when
// GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() is true; otherwise
// `karg` is just consumed through `ignore`.
//
// NOTE(review): `TailNum` is used in the Run_2Lds call below, but the line declaring that
// template parameter is not present in this listing — confirm against the upstream header.
58template <typename GridwiseGemm,
59 bool HasMainKBlockLoop,
60 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
61 index_t MinimumOccupancy = 1,
63__global__ void
64#if CK_USE_LAUNCH_BOUNDS
65__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
66#endif
67 // __attribute__((amdgpu_waves_per_eu(1, 1)))
68 kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
69{
70#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
71 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
72 {
73 // Passing two LDS pointers is the key to telling the compiler that ds_read/write
74 // operate on different LDS chunks at the same time without an ordering dependency
75 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
76 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
77
78 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
79 karg.p_as_grid,
80 karg.p_bs_grid,
81 karg.p_ds_grid,
82 karg.p_c_grid,
83 p_shared_0,
84 p_shared_1,
85 karg,
86 karg.a_element_op,
87 karg.b_element_op,
88 karg.c_element_op);
89 }
90#else
// Unsupported target: reference karg so it is not reported as unused.
91 ignore = karg;
92#endif // end of if (defined(__gfx9__))
93}
94
95template <typename ALayout,
96 typename BLayout,
97 typename CLayout,
98 typename AsDataType,
99 typename BsDataType,
100 typename AccDataType,
101 typename CShuffleDataType,
102 typename DsDataType,
103 typename CDataType,
104 typename AElementwiseOperation,
105 typename BElementwiseOperation,
106 typename CElementwiseOperation,
108 index_t BlockSize,
109 index_t MPerBlock,
110 index_t NPerBlock,
111 index_t KPerBlock,
112 index_t AK1Value,
113 index_t BK1Value,
114 index_t MPerXdl,
115 index_t NPerXdl,
116 index_t MXdlPerWave,
117 index_t NXdlPerWave,
118 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
119 typename ABlockTransferThreadClusterArrangeOrder,
120 typename ABlockTransferSrcAccessOrder,
121 index_t ABlockTransferSrcVectorDim,
122 index_t ABlockTransferSrcScalarPerVector,
123 index_t ABlockTransferDstScalarPerVector_AK1,
124 bool AThreadTransferSrcResetCoordinateAfterRun,
125 index_t ABlockLdsExtraM,
126 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
127 typename BBlockTransferThreadClusterArrangeOrder,
128 typename BBlockTransferSrcAccessOrder,
129 index_t BBlockTransferSrcVectorDim,
130 index_t BBlockTransferSrcScalarPerVector,
131 index_t BBlockTransferDstScalarPerVector_BK1,
132 bool BThreadTransferSrcResetCoordinateAfterRun,
133 index_t BBlockLdsExtraN,
134 index_t CShuffleMXdlPerWavePerShuffle,
135 index_t CShuffleNXdlPerWavePerShuffle,
136 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
137 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
140 typename ComputeTypeA = CDataType,
141 typename ComputeTypeB = ComputeTypeA>
143{
144 static constexpr auto I0 = Number<0>{};
145 static constexpr auto I1 = Number<1>{};
146 static constexpr auto I2 = Number<2>{};
147 static constexpr auto I3 = Number<3>{};
148 static constexpr auto I4 = Number<4>{};
149 static constexpr auto I5 = Number<5>{};
150 static constexpr auto I6 = Number<6>{};
151 static constexpr auto I7 = Number<7>{};
152
153 using LDSTypeA = ComputeTypeA;
154 using LDSTypeB = ComputeTypeB;
155
156 // K1 should be Number<...>
157 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
158 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
159 static constexpr auto AK1Number = Number<AK1Value>{};
160 static constexpr auto BK1Number = Number<BK1Value>{};
161
162 static constexpr index_t NumATensor = AsDataType::Size();
163 static constexpr index_t NumBTensor = BsDataType::Size();
164 static constexpr index_t NumDTensor = DsDataType::Size();
165
// Builds a tuple of null pointers, one per A tensor: element i has type `const ADataType_*`,
// where ADataType_ is the i-th entry of AsDataType (cv/ref qualifiers removed).
// The decltype of this function's result defines the AsGridPointer alias.
// NOTE(review): the final generate_tuple argument (the element count) is missing from this
// listing — confirm against the upstream header.
166 static constexpr auto MakeAsGridPointer()
167 {
168 return generate_tuple(
169 [&](auto i) {
170 using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
171
172 return static_cast<const ADataType_*>(nullptr);
173 },
175 }
176
// Builds a tuple of null pointers, one per B tensor: element i has type `const BDataType_*`,
// where BDataType_ is the i-th entry of BsDataType (cv/ref qualifiers removed).
// The decltype of this function's result defines the BsGridPointer alias.
// NOTE(review): the final generate_tuple argument (the element count) is missing from this
// listing — confirm against the upstream header.
177 static constexpr auto MakeBsGridPointer()
178 {
179 return generate_tuple(
180 [&](auto i) {
181 using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
182
183 return static_cast<const BDataType_*>(nullptr);
184 },
186 }
187
// Builds a tuple of null pointers, one per D tensor: element i has type `const DDataType*`,
// where DDataType is the i-th entry of DsDataType (cv/ref qualifiers removed).
// The decltype of this function's result defines the DsGridPointer alias.
// NOTE(review): the final generate_tuple argument (the element count) is missing from this
// listing — confirm against the upstream header.
188 static constexpr auto MakeDsGridPointer()
189 {
190 return generate_tuple(
191 [&](auto i) {
192 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
193
194 return static_cast<const DDataType*>(nullptr);
195 },
197 }
198
199 using AsGridPointer = decltype(MakeAsGridPointer());
200 using BsGridPointer = decltype(MakeBsGridPointer());
201 using DsGridPointer = decltype(MakeDsGridPointer());
202
203 static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
204 static constexpr bool is_single_rate_mfma =
206 lcm_AK1_BK1 <= 4) ||
209 lcm_AK1_BK1 < 32))
210 ? true
211 : false;
212 static constexpr auto is_scale_mfma = false;
213 static constexpr index_t KPack =
215 MfmaSelector<ComputeTypeA,
216 MPerXdl,
217 NPerXdl,
218 ComputeTypeB,
220 is_scale_mfma>::selected_mfma.k_per_blk);
221
223
// Returns the 3-tuple (Block2CTileMap::CalculateGridSize(M, N), 1, KBatch).
// NOTE(review): presumably these are the x/y/z launch-grid extents, with z carrying the
// split-K batch (SplitKBatchOffset indexes batches by blockIdx.z) — confirm at the launch site.
224 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
225 {
226 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
227 }
228
// Returns math::integer_least_multiple(M, MPerBlock), i.e. M padded to the MPerBlock tile size.
229 __host__ static auto CalculateMPadded(index_t M)
230 {
231 return math::integer_least_multiple(M, MPerBlock);
232 }
233
// Returns math::integer_least_multiple(N, NPerBlock), i.e. N padded to the NPerBlock tile size.
234 __host__ static auto CalculateNPadded(index_t N)
235 {
236 return math::integer_least_multiple(N, NPerBlock);
237 }
238
// Returns ceil(K / KPerBlock) * KPerBlock, i.e. K rounded up to a whole number of K tiles.
// NOTE(review): this struct also declares CalculateKPadded(index_t K, index_t K_Batch = 1).
// A host-side call with a single argument matches both overloads equally well and would be
// ambiguous; the only visible call (in Problem's constructor) passes two arguments.
239 __host__ static auto CalculateKPadded(index_t K)
240 {
241 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
242 }
243
// Returns the per-batch AK0 extent: the number of (K_Batch * KPerBlock)-sized steps needed to
// cover K (rounded up), multiplied by the AK0 count of one K tile (KPerBlock / AK1Value).
244 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
245 {
246 const auto k_per_step = K_Batch * KPerBlock;
247 return math::integer_divide_ceil(K, k_per_step) * (KPerBlock / AK1Value);
248 }
249
// Returns the per-batch BK0 extent: the number of (K_Batch * KPerBlock)-sized steps needed to
// cover K (rounded up), multiplied by the BK0 count of one K tile (KPerBlock / BK1Value).
250 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
251 {
252 const auto k_per_step = K_Batch * KPerBlock;
253 return math::integer_divide_ceil(K, k_per_step) * (KPerBlock / BK1Value);
254 }
255
// Returns the per-batch padded K: the number of (K_Batch * KPerBlock)-sized steps needed to
// cover K (rounded up), multiplied by KPerBlock.
256 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
257 {
258 const auto k_per_step = K_Batch * KPerBlock;
259 return math::integer_divide_ceil(K, k_per_step) * KPerBlock;
260 }
261
// Returns the K extent assigned to each split-K batch: K is divided across K_Batch batches in
// units of lcm(AK1, BK1), rounding the unit count up, and the result is expressed in elements.
// SplitKBatchOffset uses this value (karg.KRead) as the per-batch step along K.
262 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
263 {
// Granularity of a batch along K: the least common multiple of the A and B K1 widths.
264 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
265 auto K_t = K_Batch * KReadVec;
// ceil(K / K_t) units per batch, converted back to elements.
266 return (K + K_t - 1) / K_t * KReadVec;
267 }
268
// Returns ceil(M / MPerBlock): the number of MPerBlock-sized tiles needed to cover M.
269 __host__ static auto CalculateMBlock(index_t M)
270 {
271 return math::integer_divide_ceil(M, MPerBlock);
272 }
273
// Returns ceil(N / NPerBlock): the number of NPerBlock-sized tiles needed to cover N.
274 __host__ static auto CalculateNBlock(index_t N)
275 {
276 return math::integer_divide_ceil(N, NPerBlock);
277 }
278
279 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
280 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
281 {
282 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
283 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
284
286 TileDesc_K0_MN_K1{},
292 }
293
294 __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
295 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
296 {
297 const auto a_grid_desc_mraw_kraw = [&]() {
299 {
300 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
301 }
303 {
304 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
305 }
306 }();
307
308 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
309
310 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
311 GemmSpec == GemmSpecialization::MNKPadding)
312 {
313 // pad both M and K
314 const auto a_grid_desc_m_k =
315 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
317 make_right_pad_transform(K, KPad - K)),
320
321 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
322 a_grid_desc_m_k,
327
328 return a_grid_desc_ak0_m_ak1;
329 }
330 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
331 GemmSpec == GemmSpecialization::MNPadding)
332 {
333 // pad M, but not K
334 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
335 a_grid_desc_mraw_kraw,
337 make_right_pad_transform(M, MPad - M)),
340
341 return a_grid_desc_ak0_m_ak1;
342 }
343 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
344 GemmSpec == GemmSpecialization::NKPadding)
345 {
346 // pad K, but not M
347 const auto a_grid_desc_m_k = transform_tensor_descriptor(
348 a_grid_desc_mraw_kraw,
352
353 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
354 a_grid_desc_m_k,
359
360 return a_grid_desc_ak0_m_ak1;
361 }
362 else
363 {
364 // not pad M or K
365 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
366 a_grid_desc_mraw_kraw,
371
372 return a_grid_desc_ak0_m_ak1;
373 }
374 }
375
376 __host__ __device__ static auto
378 const index_t MPad,
379 const index_t K,
380 const index_t KPad,
381 const std::array<index_t, NumATensor>& StrideAs,
382 const index_t AK0)
383 {
384 return generate_tuple(
385 [&](auto i) {
386 return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0);
387 },
389 }
390
391 __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
392 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
393 {
394 const auto b_grid_desc_nraw_kraw = [&]() {
396 {
397 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
398 }
400 {
401 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
402 }
403 }();
404
405 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
406
407 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
408 GemmSpec == GemmSpecialization::MNKPadding)
409 {
410 // pad both N and K
411 const auto b_grid_desc_n_k =
412 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
414 make_right_pad_transform(K, KPad - K)),
417
418 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
419 b_grid_desc_n_k,
424
425 return b_grid_desc_bk0_n_bk1;
426 }
427 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
428 GemmSpec == GemmSpecialization::MNPadding)
429 {
430 // pad N, but not K
431 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
432 b_grid_desc_nraw_kraw,
434 make_right_pad_transform(N, NPad - N)),
437
438 return b_grid_desc_bk0_n_bk1;
439 }
440 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
441 GemmSpec == GemmSpecialization::MKPadding)
442 {
443 // pad K, but not N
444 const auto b_grid_desc_n_k = transform_tensor_descriptor(
445 b_grid_desc_nraw_kraw,
449
450 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
451 b_grid_desc_n_k,
456
457 return b_grid_desc_bk0_n_bk1;
458 }
459 else
460 {
461 // not pad N or K
462 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
463 b_grid_desc_nraw_kraw,
468
469 return b_grid_desc_bk0_n_bk1;
470 }
471 }
472
473 __host__ __device__ static auto
475 const index_t KPad,
476 const index_t N,
477 const index_t NPad,
478 const std::array<index_t, NumBTensor>& StrideBs,
479 const index_t BK0)
480 {
481 return generate_tuple(
482 [&](auto i) {
483 return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0);
484 },
486 }
487
// Forwards the A block descriptor type to MakeGemmMmaTileDescriptor with
// <MXdlPerWave, MWaves, MPerXdl>, where MWaves = MPerBlock / (MXdlPerWave * MPerXdl).
// Only the descriptor's type is used: the argument is unnamed and a default-constructed
// ABlockDesc_AK0_M_AK1 is passed on.
488 template <typename ABlockDesc_AK0_M_AK1>
489 __host__ __device__ static constexpr auto
490 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
491 {
492 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
493
494 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
495 }
496
// Forwards the B block descriptor type to MakeGemmMmaTileDescriptor with
// <NXdlPerWave, NWaves, NPerXdl>, where NWaves = NPerBlock / (NXdlPerWave * NPerXdl).
// Only the descriptor's type is used: the argument is unnamed and a default-constructed
// BBlockDesc_BK0_N_BK1 is passed on.
497 template <typename BBlockDesc_BK0_N_BK1>
498 __host__ __device__ static constexpr auto
499 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
500 {
501 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
502
503 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
504 }
505
506 __host__ __device__ static auto
508 {
509 const auto c_grid_desc_mraw_nraw = [&]() {
511 {
512 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
513 }
515 {
516 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
517 }
518 }();
519
520 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
521
522 if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
523 GemmSpec == GemmSpecialization::MNKPadding)
524 {
525 // pad M and N
526 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
528 make_right_pad_transform(N, NPad - N)),
531 }
532 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
533 GemmSpec == GemmSpecialization::MKPadding)
534 {
535 // pad M, but not N
537 c_grid_desc_mraw_nraw,
541 }
542 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
543 GemmSpec == GemmSpecialization::NKPadding)
544 {
545 // pad N, but not M
547 c_grid_desc_mraw_nraw,
551 }
552 else
553 {
554 // not pad M or N
555 return c_grid_desc_mraw_nraw;
556 }
557 }
558
559 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
560 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
561 {
562 return generate_tuple(
563 [&](auto i) { return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); },
565 }
566
// Host-side description of one GEMM problem: the raw sizes and strides supplied by the caller
// plus values derived from them with the Calculate* helpers of the enclosing struct.
// NOTE(review): several initializer and member-declaration lines are missing from this listing
// (e.g. MPadded, NPadded, MBlock, NBlock are printed below but their declarations are not
// shown) — consult the upstream header for the full member list.
567 struct Problem
568 {
// Stores the caller-provided sizes/strides/KBatch and derives KRead, KPadded, AK0 and BK0
// from K_ and KBatch_.
569 __host__ Problem(index_t M_,
570 index_t N_,
571 index_t K_,
572 std::array<index_t, NumATensor> StrideAs_,
573 std::array<index_t, NumBTensor> StrideBs_,
574 std::array<index_t, NumDTensor> StrideDs_,
575 index_t StrideC_,
576 index_t KBatch_)
577 : M{M_},
578 N{N_},
579 K{K_},
580 StrideAs{StrideAs_},
581 StrideBs{StrideBs_},
582 StrideDs{StrideDs_},
583 StrideC{StrideC_},
584 KBatch{KBatch_},
587 KRead{CalculateKRead(K_, KBatch_)},
588 KPadded{CalculateKPadded(K_, KBatch_)},
589 AK0{CalculateAK0Padded(K_, KBatch_)},
590 BK0{CalculateBK0Padded(K_, KBatch_)},
593 {
594 }
595
// Writes a one-line summary of the raw and derived sizes to std::cout.
596 __host__ void Print() const
597 {
598 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
599 << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
600 << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
601 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
602 << std::endl;
603 }
604
// Raw GEMM sizes as passed to the constructor.
605 index_t M;
606 index_t N;
607 index_t K;
608
// One stride per A / B / D tensor.
609 std::array<index_t, NumATensor> StrideAs;
610 std::array<index_t, NumBTensor> StrideBs;
611 std::array<index_t, NumDTensor> StrideDs;
613
// Derived by CalculateAK0Padded / CalculateBK0Padded in the constructor.
619 index_t AK0;
620 index_t BK0;
623 };
624
625 // Argument
627 {
628 __host__ Argument(std::array<const void*, NumATensor> p_as_grid_,
629 std::array<const void*, NumBTensor> p_bs_grid_,
630 std::array<const void*, NumDTensor> p_ds_grid_,
631 void* p_c_grid_,
632 index_t M_,
633 index_t N_,
634 index_t K_,
635 std::array<index_t, NumATensor> StrideAs_,
636 std::array<index_t, NumBTensor> StrideBs_,
637 std::array<index_t, NumDTensor> StrideDs_,
638 index_t StrideC_,
639 index_t k_batch_,
640 AElementwiseOperation a_element_op_,
641 BElementwiseOperation b_element_op_,
642 CElementwiseOperation c_element_op_)
643 : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideC_, k_batch_},
644 p_as_grid{},
645 p_bs_grid{},
646 p_ds_grid{},
647 p_c_grid{static_cast<CDataType*>(p_c_grid_)},
651
652 {
653 // populate pointer, desc for As
654 static_for<0, NumATensor, 1>{}([&](auto i) {
655 using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
656
657 // A pointer
658 p_as_grid(i) = static_cast<const ADataType_*>(p_as_grid_[i]);
659 });
660
661 // populate pointer, desc for Bs
662 static_for<0, NumBTensor, 1>{}([&](auto i) {
663 using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
664
665 // B pointer
666 p_bs_grid(i) = static_cast<const BDataType_*>(p_bs_grid_[i]);
667 });
668
669 // populate pointer, desc for Ds
670 static_for<0, NumDTensor, 1>{}([&](auto i) {
671 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
672
673 // D pointer
674 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
675 });
676 }
677
681 CDataType* p_c_grid;
682
683 const AElementwiseOperation a_element_op;
684 const BElementwiseOperation b_element_op;
685 const CElementwiseOperation c_element_op;
686 };
687
// Computes, for the current split-K batch (indexed by blockIdx.z), the offsets into A and B
// and rewrites karg.K to the K extent this batch has to process.
// NOTE(review): the conditions selecting between the two A branches and the two B branches,
// and the member declarations, are missing from this listing — presumably they dispatch on the
// A/B layouts; confirm against the upstream header.
// NOTE(review): the strided branches multiply by karg.M / karg.N, whereas the disabled
// SplitKBatchOffsetMultiABD variant in this file multiplies by karg.StrideAs[i] /
// karg.StrideBs[i]. Confirm the two agree when the leading dimension differs from M / N.
688 struct SplitKBatchOffset
689 {
// Takes karg by non-const reference because karg.K is overwritten below.
690 __device__ SplitKBatchOffset(Argument& karg)
691 {
693 {
694 a_k_split_offset = blockIdx.z * karg.KRead;
695 }
697 {
698 a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
699 }
700
702 {
703 b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
704 }
706 {
707 b_k_split_offset = blockIdx.z * karg.KRead;
708 }
709
// Every batch except the last processes exactly KRead elements of K; the last batch takes
// whatever remains after the preceding (KBatch - 1) batches.
710 if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
711 {
712 karg.K = karg.KRead;
713 }
714 else
715 {
716 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
717 }
718 }
719
722 };
723
724#if 0
// Disabled code (enclosed in `#if 0`): per-tensor variant of SplitKBatchOffset.
// For every A and B tensor it computes a split-K offset from blockIdx.z and karg.KRead (scaled
// by the tensor's own stride in the strided-layout branch), stores the offset, and records the
// correspondingly advanced pointer. It then rewrites karg.K exactly as SplitKBatchOffset does.
// NOTE(review): this code refers to AsLayout / BsLayout, while the template parameters visible
// in this listing are named ALayout / BLayout — it may not compile if re-enabled as-is.
725 struct SplitKBatchOffsetMultiABD
726 {
727 __device__ SplitKBatchOffsetMultiABD(AsGridPointer& p_as_grid,
728 BsGridPointer& p_bs_grid,
729 Argument& karg)
730 {
// A tensors: row-major advances by KRead elements per batch, column-major by KRead columns
// of stride StrideAs[i].
731 static_for<0, NumATensor, 1>{}([&](auto i) {
732 using ALayout_ = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
733 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout_>)
734 {
735 as_k_split_offset[i] = blockIdx.z * karg.KRead;
736 }
737 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout_>)
738 {
739 as_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideAs[i];
740 }
741
742 p_as_grid_(i) = p_as_grid[i] + as_k_split_offset[i];
743 });
744
// B tensors: the roles of the two layouts are swapped relative to A.
745 static_for<0, NumBTensor, 1>{}([&](auto i) {
746 using BLayout_ = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
747 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout_>)
748 {
749 bs_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideBs[i];
750 }
751 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout_>)
752 {
753 bs_k_split_offset[i] = blockIdx.z * karg.KRead;
754 }
755
756 p_bs_grid_(i) = p_bs_grid[i] + bs_k_split_offset[i];
757 });
758
// Every batch except the last processes KRead elements of K; the last takes the remainder.
759 if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
760 {
761 karg.K = karg.KRead;
762 }
763 else
764 {
765 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
766 }
767 }
768
// Pointers advanced by the per-tensor offsets, and the offsets themselves.
769 AsGridPointer p_as_grid_;
770 BsGridPointer p_bs_grid_;
771 std::array<index_t, NumATensor> as_k_split_offset;
772 std::array<index_t, NumBTensor> bs_k_split_offset;
773 };
774#endif
775
776 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
777 {
778 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
779 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
780 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
781 // A matrix in LDS memory, dst of blockwise copy
782 if constexpr(ABlockLdsExtraM)
783 {
787 }
788 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
789 // in some cases.
791 {
792 constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1
793 ? 1
794 : 32 * 4 / KPerBlock / sizeof(LDSTypeA);
795 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
797 AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
799
800 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
801 a_lds_block_desc,
807
808 constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
809 a_lds_block_desc_permuted,
815
816 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
817 a_lds_block_desc_ak0_mldslayer_m_ak1,
824
825 return a_lds_block_desc_ak0_m_ak1;
826 }
827 else // ColumnMajor A
828 {
829 // kfold and mpair dimension is not always required.
830 // more dimension in merge_transform increase the difficulty of generating immarg offset
831 // for compiler.
832 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
833 constexpr auto M1 = MPerBlock / M0;
834
835 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
836 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
837 constexpr auto KThreadRead = WaveSize / MPerXdl;
838 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
839
840 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
841 ? 1
842 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
843 constexpr auto KThreadReadPerm =
844 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
845 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
846 : KThreadRead;
847
848 // 1<=mpair<=n0
849 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
850 ? 1
851 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
852 ? M0
853 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
854
855 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
859 Number<kfold * M0 / mpair>{},
861 AK1Number));
862
863 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
864 a_lds_block_desc,
869 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
876
877 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
878 a_lds_block_desc_permuted,
887 Sequence<1>{},
888 Sequence<2>{},
889 Sequence<3>{},
890 Sequence<4>{},
891 Sequence<5>{}),
893 Sequence<2>{},
896 Sequence<6>{},
897 Sequence<7>{}));
898
899 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
900 a_lds_block_desc_unmerged,
903 Number<KThreadWrite / kfold / KThreadReadPerm>{},
911
912 return a_lds_block_desc_ak0_m_ak1;
913 }
914 }
915
916 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
917 {
918 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
919 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
920 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
921 // B matrix in LDS memory, dst of blockwise copy
922 if constexpr(BBlockLdsExtraN)
923 {
927 }
929 {
930 // NLdsLayer * K0 as logical Bank
931 constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1
932 ? 1
933 : 32 * 4 / KPerBlock / sizeof(LDSTypeB);
934 ;
935 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
937 BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
939
940 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
941 b_lds_block_desc,
947
948 constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
949 b_lds_block_desc_permuted,
955
956 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
957 b_lds_block_desc_bk0_nldslayer_n_bk1,
964
965 return b_lds_block_desc_bk0_n_bk1;
966 }
967 else // RowMajor B
968 {
969 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
970 constexpr auto N1 = NPerBlock / N0;
971
972 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
973 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
974 constexpr auto KThreadRead = WaveSize / NPerXdl;
975 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
976
977 constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128)
978 ? 1
979 : 128 / (BK1Number * N0 * sizeof(LDSTypeB));
980 constexpr auto KThreadReadPerm =
981 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
982 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
983 : KThreadRead;
984
985 // 1<=npair<=n0
986 constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128)
987 ? 1
988 : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0
989 ? N0
990 : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB)));
991
992 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
996 Number<kfold * N0 / npair>{},
998 BK1Number));
999
1000 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1001 b_lds_block_desc,
1002 make_tuple(
1006 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1009 make_tuple(
1011 make_tuple(
1013
1014 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1015 b_lds_block_desc_permuted,
1016 make_tuple(
1024 Sequence<1>{},
1025 Sequence<2>{},
1026 Sequence<3>{},
1027 Sequence<4>{},
1028 Sequence<5>{}),
1030 Sequence<2>{},
1033 Sequence<6>{},
1034 Sequence<7>{}));
1035
1036 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1037 b_lds_block_desc_unmerged,
1040 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1041 Number<kfold>{},
1048
1049 return b_lds_block_desc_bk0_n_bk1;
1050 }
1051 }
1052
1054 {
1055 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1056 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1057
1058 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1060 make_tuple(I1,
1062 I1,
1064
1065 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1066 }
1067
1070 BlkGemmPipelineVer,
1071 BlkGemmPipeSched,
1072 BlockSize,
1073 LDSTypeA,
1074 LDSTypeB,
1075 ComputeTypeA,
1076 AccDataType,
1083 ABlockTransferSrcScalarPerVector,
1084 BBlockTransferSrcScalarPerVector,
1085 MPerBlock,
1086 NPerBlock,
1087 KPerBlock,
1088 MPerXdl,
1089 NPerXdl,
1090 MXdlPerWave,
1091 NXdlPerWave,
1092 KPack>())>;
1093
// Returns the number of shared-memory bytes a kernel must declare for this GEMM configuration.
// The result is the larger of
//   (a) the aligned A tile size * sizeof(LDSTypeA) + the aligned B tile size * sizeof(LDSTypeB)
//   (b) the C-shuffle tile size * sizeof(CShuffleDataType).
// NOTE(review): taking the maximum rather than the sum suggests the A/B staging area and the
// C-shuffle area occupy the same allocation at different times — confirm in Run / Run_2Lds.
// NOTE(review): the initializer of the C-shuffle descriptor below is missing from this listing.
1094 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1095 {
1096 // LDS allocation for A and B: be careful of alignment
1097 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1098 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1099
1100 // lds max alignment
1101 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1102
// Element counts of the A and B tiles, each rounded up to a multiple of max_lds_align.
1103 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1104 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1105
1106 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1107 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1108
1109 // LDS allocation for C shuffle in LDS
1110 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1112
1113 constexpr auto c_block_size =
1114 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1115
// Convert element counts to bytes and take the larger requirement.
1116 return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
1117 b_block_space_size_aligned * sizeof(LDSTypeB)),
1118 c_block_size * sizeof(CShuffleDataType));
1119 }
1120
1122
1123 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1124 __host__ static constexpr bool CheckValidity(const Argument& karg)
1125 {
1126 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1127 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1128 "Invalid tuning param!");
1129
1130 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1134 {
1135 if(!(karg.M % MPerBlock == 0))
1136 {
1137 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1138 {
1139 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1140 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1141 << std::endl;
1142 }
1143 return false;
1144 }
1145 }
1146
1147 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1151 {
1152 if(!(karg.N % NPerBlock == 0))
1153 {
1154 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1155 {
1156 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1157 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1158 << std::endl;
1159 }
1160 return false;
1161 }
1162 }
1163
1164 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1168 {
1169
1170 auto K_t = karg.KBatch * KPerBlock;
1171 if(!(karg.K % K_t == 0))
1172 {
1173 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1174 {
1175 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1176 << karg.K << " " << __FILE__ << ":" << __LINE__
1177 << ", in function: " << __func__ << std::endl;
1178 }
1179 return false;
1180 }
1181 }
1182 else
1183 {
1184 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1185 auto K_t = karg.KBatch * KReadVec;
1186 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1187 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1188 {
1189 return false;
1190 }
1191 }
1192
1194 {
1195 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1196 {
1197 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1198 {
1199 std::cout << "Arg K (" << karg.K
1200 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1201 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1202 << __LINE__ << ", in function: " << __func__ << std::endl;
1203 }
1204 return false;
1205 }
1206 }
1207 else
1208 {
1209 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1210 {
1211 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1212 {
1213 std::cout << "Arg M (" << karg.M
1214 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1215 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1216 << __LINE__ << ", in function: " << __func__ << std::endl;
1217 }
1218 return false;
1219 }
1220 }
1221
1223 {
1224 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1225 {
1226 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1227 {
1228 std::cout << "Arg N (" << karg.N
1229 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1230 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1231 << __LINE__ << ", in function: " << __func__ << std::endl;
1232 }
1233 return false;
1234 }
1235 }
1236 else
1237 {
1238 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1239 {
1240 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1241 {
1242 std::cout << "Arg K (" << karg.K
1243 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1244 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1245 << __LINE__ << ", in function: " << __func__ << std::endl;
1246 }
1247 return false;
1248 }
1249 }
1250
1252 {
1253 if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1254 {
1255 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1256 {
1257 std::cout << "Arg N (" << karg.N
1258 << ") value is not a multiple of "
1259 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1260 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1261 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1262 << std::endl;
1263 }
1264 return false;
1265 }
1266 }
1267 else
1268 {
1269 if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1270 {
1271 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1272 {
1273 std::cout << "Arg M (" << karg.M
1274 << ") value is not a multiple of "
1275 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1276 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1277 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1278 << std::endl;
1279 }
1280 return false;
1281 }
1282 }
1283
1284 // check gridwise gemm pipeline
1285 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1286
1287 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1288 {
1289 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1290 {
1291 return false;
1292 }
1293 }
1294
1295 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1296 return true;
1297 }
1298
1299 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1300 {
1301 const index_t num_loop = K / KPerBlock;
1302
1303 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1304 }
1305
1306 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1307 {
1308 const index_t num_loop = K / KPerBlock;
1309
1310 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1311 }
1312
// Builds a descriptor from c_grid_desc_m_n with transform_tensor_descriptor and returns it.
//
// c_grid_desc_m_n - source descriptor that is transformed.
// MBlock, NBlock  - block counts supplied by the caller.
//
// NOTE(review): this listing is missing original line 1314 (the function name / qualifiers)
// and lines 1319-1322 (the transform arguments). Judging only by the result variable's name,
// the transform presumably splits M into (MBlock, MPerBlock) and N into (NBlock, NPerBlock)
// — confirm against the original file.
1313 template <typename CGridDesc>
1315 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1316 {
1317 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1318 c_grid_desc_m_n,
1323
1324 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1325 }
1326
// Produces a tuple with generate_tuple, one element per index i, each built from
// ds_grid_desc_m_n[i] together with MBlock and NBlock.
//
// ds_grid_desc_m_n - indexable collection of descriptors, read with [i].
// MBlock, NBlock   - block counts forwarded unchanged for every element.
//
// NOTE(review): this listing is missing original line 1328 (the function name / qualifiers),
// line 1333 (the callee inside the lambda) and line 1336 (the element-count argument of
// generate_tuple). The callee is presumably the single-descriptor C helper — confirm against
// the original file.
1327 template <typename DsGridDesc>
1329 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
1330 {
1331 return generate_tuple(
1332 [&](auto i) {
1334 ds_grid_desc_m_n[i], MBlock, NBlock);
1335 },
1337 }
1338
// Type returned by MakeDsGridDescriptor_M_N, with cv/ref qualifiers removed. It is deduced in
// an unevaluated context, so the arguments (0, 0, 0, 0, {}) are placeholders only.
1339 using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
1340
1341 // return block_id to C matrix tile idx (m0, n0) mapping
1342 // if arch = gfx942
// NOTE(review): original line 1343 is missing from this listing; it is presumably the active
// Block2CTileMap alias that Run instantiates — confirm against the original file.
1344 // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1345
// Run: device-side GEMM body for the variant that works out of a single LDS allocation
// (p_shared), with multiple A, B and D tensors.
//
// Steps visible in this function:
//   1. build grid descriptors for the As, Bs, Ds and C tensors from `problem`;
//   2. create as_grid_buf / bs_grid_buf / ds_grid_buf (one entry per tensor, each sized with
//      GetElementSpaceSize()) and the C buffer;
//   3. map get_block_1d_id() to a C tile index through Block2CTileMap and return early when
//      ValidCTileIndex rejects it;
//   4. construct the A and B blockwise copies (ThreadGroupTensorSliceTransfer_v7r2) and run
//      BlockwiseGemmPipe for num_k_block_main_loop iterations, accumulating into c_thread_buf;
//   5. shuffle the accumulators through LDS (p_shared is reused as CShuffleDataType storage)
//      and write them out together with the Ds via cde_block_copy_lds_and_global.
//
// Template parameters:
//   HasMainKBlockLoop          - forwarded to BlockwiseGemmPipe::Run.
//   CGlobalMemoryDataOperation - forwarded (as EGlobalMemoryDataOperation) into the template
//                                arguments of cde_block_copy_lds_and_global.
//   TailNum                    - forwarded to BlockwiseGemmPipe::Run; defaults to
//                                TailNumber::Odd.
// Parameters:
//   p_as_grid / p_bs_grid / p_ds_grid - per-tensor pointers, indexed with [i] below.
//   p_c_grid   - output pointer.
//   p_shared   - LDS scratch; cast to LDSTypeA / LDSTypeB for the A/B tiles and later to
//                CShuffleDataType for the C shuffle.
//   problem    - sizes, padded sizes, strides, AK0/BK0 and block counts read below.
//   a_element_op / b_element_op / c_element_op - element-wise operators handed to the copies.
//
// NOTE(review): this text comes from a rendered listing. Several original lines are absent
// (the embedded line numbers skip, e.g. 1378, 1392, 1402, 1414, 1494), so a number of
// statements and template argument lists are incomplete here.
1346 template <bool HasMainKBlockLoop,
1347 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1348 TailNumber TailNum = TailNumber::Odd>
1349 __device__ static void Run(AsGridPointer& p_as_grid,
1350 BsGridPointer& p_bs_grid,
1351 DsGridPointer& p_ds_grid,
1352 CDataType* p_c_grid,
1353 void* p_shared,
1354 const Problem& problem,
1355 const AElementwiseOperation& a_element_op,
1356 const BElementwiseOperation& b_element_op,
1357 const CElementwiseOperation& c_element_op)
1358 {
1359 // std::array<index_t, NumATensor> StrideAs = {problem.StrideA};
1360 // std::array<index_t, NumBTensor> StrideBs = {problem.StrideB};
1361
1362 // AsGridPointer p_as_grid;
1363 // BsGridPointer p_bs_grid;
1364 // DsGridPointer p_ds_grid;
1365
1366 // const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1367 // problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1368 // const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1369 // problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
// Grid descriptors for all A tensors, all B tensors and the C tensor, built from `problem`.
1370 const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
1371 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
1372 const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
1373 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
1374 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1375 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1376
1377 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1379 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1380
1381 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1382 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1383
1384#if 0
1385 static_for<0, NumDTensor, 1>{}([&](auto j) {
1386 ds_grid_desc_m_n(j) = MakeCGridDescriptor_M_N(
1387 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs[j]);
1388 });
1389#endif
1390
1391 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1393 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1394
1395 // const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1396 // p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1397 // const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1398 // p_bs_grid[I0], b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1399
// One buffer per A tensor, sized from the matching descriptor's element space size.
1400 const auto as_grid_buf = generate_tuple(
1401 [&](auto i) {
1403 p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
1404 },
1406
// One buffer per B tensor, sized from the matching descriptor's element space size.
1407 const auto bs_grid_buf = generate_tuple(
1408 [&](auto i) {
1410 p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
1411 },
1413
1415 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1416
// One buffer per D tensor; note the size comes from ds_grid_desc_m_n, not the blocked descriptor.
1417 const auto ds_grid_buf = generate_tuple(
1418 [&](auto i) {
1420 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1421 },
1423
1424 // divide block work by [M, N]
// NOTE(review): the literal 4 is an unexplained third constructor argument — confirm its
// meaning against the Block2CTileMap definition.
1425 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1426
1427 const auto block_work_idx =
1428 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1429
// A block whose tile index is rejected (checked against the MBlock / NBlock lengths of the
// C descriptor) does no work and exits here.
1430 if(!block_2_ctile_map.ValidCTileIndex(
1431 block_work_idx,
1432 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1433 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1434 {
1435 return;
1436 }
1437
1438 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1439 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1440
1441 // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1442 const index_t m_block_data_idx_on_grid =
1443 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1444
1445 const index_t n_block_data_idx_on_grid =
1446 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1447
1448 // lds max alignment
1449 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1450
1451 // A matrix in LDS memory, dst of blockwise copy
1452 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1453
1454 // B matrix in LDS memory, dst of blockwise copy
1455 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1456
// The "#if 0" branch below is the disabled single-A-tensor copy; the "#else" branch is the
// active multi-tensor copy.
1457#if 0
1458 // A matrix blockwise copy
1459 auto a_blockwise_copy =
1461 AElementwiseOperation,
1465 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1466 ABlockTransferThreadClusterArrangeOrder,
1467 ADataType,
1468 ADataType,
1469 decltype(a_grid_desc_ak0_m_ak1),
1470 decltype(a_block_desc_ak0_m_ak1),
1471 ABlockTransferSrcAccessOrder,
1473 ABlockTransferSrcVectorDim,
1474 2,
1475 ABlockTransferSrcScalarPerVector,
1476 ABlockTransferDstScalarPerVector_AK1,
1477 1,
1478 1,
1479 AThreadTransferSrcResetCoordinateAfterRun,
1480 true,
1481 BlockwiseGemmPipe::GlobalBufferNum>(
1482 a_grid_desc_ak0_m_ak1,
1483 make_multi_index(0, m_block_data_idx_on_grid, 0),
1484 a_element_op,
1485 a_block_desc_ak0_m_ak1,
1486 make_multi_index(0, 0, 0),
1488#else
// Every A tensor starts at the same origin: (0, m_block_data_idx_on_grid, 0).
1489 const auto idx_as_block_begin =
1490 generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
1492
1493 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
1495 AsDataType,
1497 decltype(as_grid_desc_ak0_m_ak1),
1498 decltype(tie(a_block_desc_ak0_m_ak1)),
1499 AElementwiseOperation,
1502 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1503 ABlockTransferThreadClusterArrangeOrder,
1504 ABlockTransferSrcAccessOrder,
1506 ABlockTransferSrcVectorDim,
1507 2,
1508 ABlockTransferSrcScalarPerVector,
1509 ABlockTransferDstScalarPerVector_AK1,
1512 BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
1513 idx_as_block_begin,
1514 tie(a_block_desc_ak0_m_ak1),
1515 make_tuple(make_multi_index(0, 0, 0)),
1516 a_element_op};
1517#endif
1518
// Same layout as for A: disabled single-B-tensor copy first, active multi-tensor copy after "#else".
1519#if 0
1520 // B matrix blockwise copy
1521 auto b_blockwise_copy =
1523 BElementwiseOperation,
1527 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1528 BBlockTransferThreadClusterArrangeOrder,
1529 BDataType,
1530 BDataType,
1531 decltype(b_grid_desc_bk0_n_bk1),
1532 decltype(b_block_desc_bk0_n_bk1),
1533 BBlockTransferSrcAccessOrder,
1535 BBlockTransferSrcVectorDim,
1536 2,
1537 BBlockTransferSrcScalarPerVector,
1538 BBlockTransferDstScalarPerVector_BK1,
1539 1,
1540 1,
1541 BThreadTransferSrcResetCoordinateAfterRun,
1542 true,
1543 BlockwiseGemmPipe::GlobalBufferNum>(
1544 b_grid_desc_bk0_n_bk1,
1545 make_multi_index(0, n_block_data_idx_on_grid, 0),
1546 b_element_op,
1547 b_block_desc_bk0_n_bk1,
1548 make_multi_index(0, 0, 0),
1550#else
// Every B tensor starts at the same origin: (0, n_block_data_idx_on_grid, 0).
1551 const auto idx_bs_block_begin =
1552 generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
1554
1555 auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
1557 BsDataType,
1559 decltype(bs_grid_desc_bk0_n_bk1),
1560 decltype(tie(b_block_desc_bk0_n_bk1)),
1561 BElementwiseOperation,
1564 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1565 BBlockTransferThreadClusterArrangeOrder,
1566 BBlockTransferSrcAccessOrder,
1568 BBlockTransferSrcVectorDim,
1569 2,
1570 BBlockTransferSrcScalarPerVector,
1571 BBlockTransferDstScalarPerVector_BK1,
1574 BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
1575 idx_bs_block_begin,
1576 tie(b_block_desc_bk0_n_bk1),
1577 make_tuple(make_multi_index(0, 0, 0)),
1578 b_element_op};
1579
1580#endif
1581
1582 // LDS allocation for A and B: be careful of alignment
1583 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1584 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1585
1586 // Cast after lds
// A tile lives at the start of p_shared.
1588 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1589
// B tile follows the aligned A tile; the offset is converted from LDSTypeA elements to
// LDSTypeB elements via the sizeof ratio.
1591 static_cast<LDSTypeB*>(p_shared) +
1592 a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
1593 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1594
// Per-iteration window step: KPerBlock / AK1Number (resp. BK1Number) along dimension 0.
1595 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1596 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1597
1598 // Blockwise GEMM pipeline
1599 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1600 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1601 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1602
// Iteration count = (length of dim 0 * length of dim 2 of the first A descriptor) / KPerBlock.
1603 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1604 (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
1605 KPerBlock);
1606
1607 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(as_grid_desc_ak0_m_ak1,
1608 a_block_desc_ak0_m_ak1,
1609 a_blockwise_copy,
1610 as_grid_buf,
1611 a_block_buf,
1612 a_block_slice_copy_step,
1613 bs_grid_desc_bk0_n_bk1,
1614 b_block_desc_bk0_n_bk1,
1615 b_blockwise_copy,
1616 bs_grid_buf,
1617 b_block_buf,
1618 b_block_slice_copy_step,
1619 c_thread_buf,
1620 num_k_block_main_loop);
1621
1622 // shuffle C and write out
1623 {
1624 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1625 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1626 "wrong!");
1627
1628 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1629 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1630
1631 // TODO: hacky, fix it!
1632 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1633 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1634
1635 // TODO: hacky, fix it!
1636 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1637 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1638 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1639
1640 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1641 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1642 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1643 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1644 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1645 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1646 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1647 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1648
1649 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1651
// The shuffle buffer starts at p_shared itself, i.e. it reuses the storage the A/B tiles
// occupied during the main loop.
1652 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1653 static_cast<CShuffleDataType*>(p_shared),
1654 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1655
1656 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1657 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1658 make_tuple(
1661 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1662 M1, // M1 = MWave
1663 M2, // M2 * M3 * M4 = MPerXdl
1664 M3,
1665 M4)),
1668 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1669 N1, // N1 = NWave
1670 N2))), // N2 = NPerXdl
1672 make_tuple(
1674
1675 // calculate origin of thread output tensor on global memory
1676 // blockwise GEMM c matrix starting index
1677 const auto c_thread_mtx_on_block =
1678 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1679
1680 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1681 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1682
// Decompose the thread's flat M offset into (M0, M1, M2, M3, M4) coordinates.
1683 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1685 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1688
1689 const auto m_thread_data_on_block_idx =
1690 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1691 make_multi_index(m_thread_data_on_block));
1692
1693 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1698
1699 const auto n_thread_data_on_block_idx =
1700 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1701 make_multi_index(n_thread_data_on_block));
1702
1703 // shuffle: threadwise copy C from VGPR to LDS
1704 auto c_thread_copy_vgpr_to_lds =
1706 CShuffleDataType,
1707 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1708 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1710 Sequence<CShuffleMXdlPerWavePerShuffle,
1711 CShuffleNXdlPerWavePerShuffle,
1712 I1,
1713 I1,
1714 M2,
1715 I1,
1716 M4,
1717 I1>,
1719 7,
1720 1,
1722 1,
1723 true>{
1724 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1726 0,
1727 m_thread_data_on_block_idx[I1],
1728 n_thread_data_on_block_idx[I1],
1729 m_thread_data_on_block_idx[I2],
1730 m_thread_data_on_block_idx[I3],
1731 m_thread_data_on_block_idx[I4],
1732 n_thread_data_on_block_idx[I2]),
1734
// The "#if 0" branch is the disabled C-only write-out; the "#else" branch is the active
// path that also reads the D tensors.
1735#if 0
1736 // shuffle: blockwise copy C from LDS to global
1737 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1738 ThisThreadBlock, // ThreadGroup
1739 CElementwiseOperation, // ElementwiseOperation,
1740 CGlobalMemoryDataOperation, // DstInMemOp,
1741 Sequence<1,
1742 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1743 1,
1744 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1745 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1746 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1747 CShuffleDataType, // typename SrcData,
1748 CDataType, // typename DstData,
1749 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1750 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1751 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1752 3, // index_t VectorDim,
1753 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1754 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1755 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1756 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1757 make_multi_index(0, 0, 0, 0),
1758 c_grid_desc_mblock_mperblock_nblock_nperblock,
1759 make_multi_index(block_m_id, 0, block_n_id, 0),
1760 c_element_op};
1761#else
1762 using EDataType = CDataType;
1763
1764 // tuple of reference to C/Ds tensor descriptors
// Element 0 is the LDS shuffle descriptor; the D descriptors follow it.
1765 const auto c_ds_desc_refs = concat_tuple_of_reference(
1766 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1767 generate_tie([&](auto i) -> const auto& // return type should be reference
1768 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1770
1771 // tuple of reference to C/Ds tensor buffers
1772 const auto c_ds_buf_refs = concat_tuple_of_reference(
1773 tie(c_shuffle_block_buf),
1774 generate_tie([&](auto i) -> const auto& // return type should be reference
1775 { return ds_grid_buf[i]; },
1777
1778 // tuple of starting index of C/Ds blockwise copy
// The shuffle buffer starts at the origin; each D starts at this block's (m, n) tile.
1779 const auto idx_c_ds_block_begin = container_concat(
1780 make_tuple(make_multi_index(0, 0, 0, 0)),
1782 [&](auto) {
1783 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1784 },
1786
// Local aliases that map this kernel's C* names onto the E/CDE names used below.
1787 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1788 c_grid_desc_mblock_mperblock_nblock_nperblock;
1789
1790 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1791 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1792 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1793 const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
1794 CShuffleBlockTransferScalarPerVector_NPerBlock;
1795
1796 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
1798 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1800 decltype(c_ds_desc_refs),
1801 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1802 CElementwiseOperation,
1803 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1804 // support arbitrary type
1805 Sequence<1,
1806 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1807 1,
1808 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1809 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1810 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1811 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1812 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1813 3, // index_t SrcVectorDim,
1814 3, // index_t DstVectorDim,
1815 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
1816 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
1820 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1821 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
1822 {c_ds_desc_refs,
1823 idx_c_ds_block_begin,
1824 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1825 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
1826 c_element_op};
1827
1828#endif
1829
1830 // space filling curve for threadwise C in VGPR
1831 constexpr auto sfc_c_vgpr =
1834 Sequence<CShuffleMXdlPerWavePerShuffle,
1835 CShuffleNXdlPerWavePerShuffle,
1836 1,
1837 1,
1838 M2,
1839 1,
1840 M4,
1841 1>>{};
1842
1843 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1844#if 0
1845 // space filling curve for shuffled blockwise C in global mem
1846 constexpr auto sfc_c_global =
1849 Sequence<1,
1850 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1851 1,
1852 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1853
1854
1855 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1856
1857#else
1858 // space filling curve for shuffled blockwise C/D/E
1859 constexpr auto sfc_cde_block =
1862 Sequence<1,
1863 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1864 1,
1865 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1866
// Both curves must take the same number of steps, since the loop below advances them together.
1867 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1868#endif
1869
// Compile-time loop over the shuffle steps: write one slice VGPR -> LDS, then copy it out.
// NOTE(review): the statements that followed the two "make sure it's safe" comments
// (original lines 1872 and 1882) are missing from this listing; presumably block-level
// synchronisation calls — confirm against the original file.
1870 static_for<0, num_access, 1>{}([&](auto access_id) {
1871 // make sure it's safe to write to LDS
1873
1874 // each thread write its data from VGPR to LDS
1875 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1876 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1877 c_thread_buf,
1878 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1879 c_shuffle_block_buf);
1880
1881 // make sure it's safe to read from LDS
1883
1884#if 0
1885 // each block copy its data from LDS to global
1886 c_shuffle_block_copy_lds_to_global.Run(
1887 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1888 c_shuffle_block_buf,
1889 c_grid_desc_mblock_mperblock_nblock_nperblock,
1890 c_grid_buf);
1891
1892 if constexpr(access_id < num_access - 1)
1893 {
1894 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1895
1896 // move on C
1897 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1898 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1899 }
1900#else
1901 // each block copy its data from LDS to global
1902 cde_block_copy_lds_and_global.Run(
1903 c_ds_desc_refs,
1904 c_ds_buf_refs,
1905 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1906 tie(c_grid_buf));
1907
// No window move after the last step.
1908 if constexpr(access_id < num_access - 1)
1909 {
1910 constexpr auto cde_lds_and_global_step =
1911 sfc_cde_block.GetForwardStep(access_id);
1912
1913 // move on Ds
// Source index i + I1 skips source 0, the LDS shuffle buffer, whose window is not moved.
1914 static_for<0, NumDTensor, 1>{}([&](auto i) {
1915 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1916 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1917 });
1918
1919 // move on E
1920 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1921 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1922 I0,
1923 cde_lds_and_global_step);
1924 }
1925#endif
1926 });
1927 }
1928 }
1929
1930#if 1
1931 template <bool HasMainKBlockLoop,
1932 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1933 TailNumber TailNum = TailNumber::Odd>
1934 __device__ static void Run_2Lds(AsGridPointer& p_as_grid,
1935 BsGridPointer& p_bs_grid,
1936 DsGridPointer& p_ds_grid,
1937 CDataType* p_c_grid,
1938 void* p_shared_0,
1939 void* p_shared_1,
1940 const Problem& problem,
1941 const AElementwiseOperation& a_element_op,
1942 const BElementwiseOperation& b_element_op,
1943 const CElementwiseOperation& c_element_op)
1944 {
1945 // const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1946 // problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1947 // const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1948 // problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1949 const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
1950 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
1951 const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
1952 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
1953 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1954 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1955
1956 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1958 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1959
1960 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1961 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1962
1963 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1965 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1966
1967 // const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1968 // p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1969 // const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1970 // p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1971 const auto as_grid_buf = generate_tuple(
1972 [&](auto i) {
1974 p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
1975 },
1977
1978 const auto bs_grid_buf = generate_tuple(
1979 [&](auto i) {
1981 p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
1982 },
1984
1986 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1987
1988 const auto ds_grid_buf = generate_tuple(
1989 [&](auto i) {
1991 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1992 },
1994
1995 // divide block work by [M, N]
1996 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1997
1998 const auto block_work_idx =
1999 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
2000
2001 if(!block_2_ctile_map.ValidCTileIndex(
2002 block_work_idx,
2003 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
2004 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
2005 {
2006 return;
2007 }
2008
2009 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
2010 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
2011
2012 // HACK: this force m/n_block_data_idx_on_grid into SGPR
2013 const index_t m_block_data_idx_on_grid =
2014 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
2015
2016 const index_t n_block_data_idx_on_grid =
2017 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
2018
2019 // lds max alignment
2020 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
2021
2022 // A matrix in LDS memory, dst of blockwise copy
2023 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2024
2025 // B matrix in LDS memory, dst of blockwise copy
2026 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2027
2028#if 0
2029 // A matrix blockwise copy
2030 auto a_blockwise_copy =
2032 AElementwiseOperation,
2036 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2037 ABlockTransferThreadClusterArrangeOrder,
2038 ADataType,
2039 ADataType,
2040 decltype(a_grid_desc_ak0_m_ak1),
2041 decltype(a_block_desc_ak0_m_ak1),
2042 ABlockTransferSrcAccessOrder,
2044 ABlockTransferSrcVectorDim,
2045 2,
2046 ABlockTransferSrcScalarPerVector,
2047 ABlockTransferDstScalarPerVector_AK1,
2048 1,
2049 1,
2050 AThreadTransferSrcResetCoordinateAfterRun,
2051 true,
2052 BlockwiseGemmPipe::GlobalBufferNum>(
2053 a_grid_desc_ak0_m_ak1,
2054 make_multi_index(0, m_block_data_idx_on_grid, 0),
2055 a_element_op,
2056 a_block_desc_ak0_m_ak1,
2057 make_multi_index(0, 0, 0),
2059#else
2060 const auto idx_as_block_begin =
2061 generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
2063
2064 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
2066 AsDataType,
2068 decltype(as_grid_desc_ak0_m_ak1),
2069 decltype(tie(a_block_desc_ak0_m_ak1)),
2070 AElementwiseOperation,
2073 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2074 ABlockTransferThreadClusterArrangeOrder,
2075 ABlockTransferSrcAccessOrder,
2077 ABlockTransferSrcVectorDim,
2078 2,
2079 ABlockTransferSrcScalarPerVector,
2080 ABlockTransferDstScalarPerVector_AK1,
2083 BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1,
2084 idx_as_block_begin,
2085 tie(a_block_desc_ak0_m_ak1),
2086 make_tuple(make_multi_index(0, 0, 0)),
2087 a_element_op};
2088
2089#endif
2090
2091#if 0
2092 // B matrix blockwise copy
2093 auto b_blockwise_copy =
2095 BElementwiseOperation,
2099 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2100 BBlockTransferThreadClusterArrangeOrder,
2101 BDataType,
2102 BDataType,
2103 decltype(b_grid_desc_bk0_n_bk1),
2104 decltype(b_block_desc_bk0_n_bk1),
2105 BBlockTransferSrcAccessOrder,
2107 BBlockTransferSrcVectorDim,
2108 2,
2109 BBlockTransferSrcScalarPerVector,
2110 BBlockTransferDstScalarPerVector_BK1,
2111 1,
2112 1,
2113 BThreadTransferSrcResetCoordinateAfterRun,
2114 true,
2115 BlockwiseGemmPipe::GlobalBufferNum>(
2116 b_grid_desc_bk0_n_bk1,
2117 make_multi_index(0, n_block_data_idx_on_grid, 0),
2118 b_element_op,
2119 b_block_desc_bk0_n_bk1,
2120 make_multi_index(0, 0, 0),
2122#else
2123 const auto idx_bs_block_begin =
2124 generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
2126
2127 auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
2129 BsDataType,
2131 decltype(bs_grid_desc_bk0_n_bk1),
2132 decltype(tie(b_block_desc_bk0_n_bk1)),
2133 BElementwiseOperation,
2136 BBlockTransferThreadClusterLengths_BK0_N_BK1,
2137 BBlockTransferThreadClusterArrangeOrder,
2138 BBlockTransferSrcAccessOrder,
2140 BBlockTransferSrcVectorDim,
2141 2,
2142 BBlockTransferSrcScalarPerVector,
2143 BBlockTransferDstScalarPerVector_BK1,
2146 BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1,
2147 idx_bs_block_begin,
2148 tie(b_block_desc_bk0_n_bk1),
2149 make_tuple(make_multi_index(0, 0, 0)),
2150 b_element_op};
2151#endif
2152
2153 // LDS allocation for A and B: be careful of alignment
2154 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
2155 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
2156
2157 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2158 static_cast<LDSTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2159
2160 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2161 static_cast<LDSTypeB*>(p_shared_0) +
2162 a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
2163 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2164
2165 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2166 static_cast<LDSTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2167
2168 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2169 static_cast<LDSTypeB*>(p_shared_1) +
2170 a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
2171 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2172
2173 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2174 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2175
2176 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2177 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
2178
2179 // Blockwise GEMM pipeline
2180 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2181 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2182 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2183
2184 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2185 (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
2186 KPerBlock);
2187
2188 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(as_grid_desc_ak0_m_ak1,
2189 a_block_desc_ak0_m_ak1,
2190 a_blockwise_copy,
2191 as_grid_buf,
2192 a_block_bufs,
2193 a_block_slice_copy_step,
2194 bs_grid_desc_bk0_n_bk1,
2195 b_block_desc_bk0_n_bk1,
2196 b_blockwise_copy,
2197 bs_grid_buf,
2198 b_block_bufs,
2199 b_block_slice_copy_step,
2200 c_thread_buf,
2201 num_k_block_main_loop);
2202
2203 // shuffle C and write out
2204 {
2205 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2206 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2207 "wrong!");
2208
2209 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2210 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
2211
2212 // TODO: hacky, fix it!
2213 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2214 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2215
2216 // TODO: hacky, fix it!
2217 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2218 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2219 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2220
2221 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2222 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2223 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2224 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2225 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2226 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2227 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2228 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2229
2230 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2232
2233 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2234 static_cast<CShuffleDataType*>(p_shared_0),
2235 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2236
2237 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2238 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2239 make_tuple(
2242 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2243 M1, // M1 = MWave
2244 M2, // M2 * M3 * M4 = MPerXdl
2245 M3,
2246 M4)),
2249 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2250 N1, // N1 = NWave
2251 N2))), // N2 = NPerXdl
2253 make_tuple(
2255
2256 // calculate origin of thread output tensor on global memory
2257 // blockwise GEMM c matrix starting index
2258 const auto c_thread_mtx_on_block =
2259 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2260
2261 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2262 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2263
2264 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2266 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2269
2270 const auto m_thread_data_on_block_idx =
2271 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2272 make_multi_index(m_thread_data_on_block));
2273
2274 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2279
2280 const auto n_thread_data_on_block_idx =
2281 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2282 make_multi_index(n_thread_data_on_block));
2283
2284 // shuffle: threadwise copy C from VGPR to LDS
2285 auto c_thread_copy_vgpr_to_lds =
2287 CShuffleDataType,
2288 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2289 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2291 Sequence<CShuffleMXdlPerWavePerShuffle,
2292 CShuffleNXdlPerWavePerShuffle,
2293 I1,
2294 I1,
2295 M2,
2296 I1,
2297 M4,
2298 I1>,
2300 7,
2301 1,
2303 1,
2304 true>{
2305 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2307 0,
2308 m_thread_data_on_block_idx[I1],
2309 n_thread_data_on_block_idx[I1],
2310 m_thread_data_on_block_idx[I2],
2311 m_thread_data_on_block_idx[I3],
2312 m_thread_data_on_block_idx[I4],
2313 n_thread_data_on_block_idx[I2]),
2315
2316#if 0
2317 // shuffle: blockwise copy C from LDS to global
2318 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2319 ThisThreadBlock, // ThreadGroup
2320 CElementwiseOperation, // ElementwiseOperation,
2321 CGlobalMemoryDataOperation, // DstInMemOp,
2322 Sequence<1,
2323 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2324 1,
2325 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2326 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2327 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2328 CShuffleDataType, // typename SrcData,
2329 CDataType, // typename DstData,
2330 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2331 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2332 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2333 3, // index_t VectorDim,
2334 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2335 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2336 false> // bool ThreadTransferDstResetCoordinateAfterRun>
2337 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2338 make_multi_index(0, 0, 0, 0),
2339 c_grid_desc_mblock_mperblock_nblock_nperblock,
2340 make_multi_index(block_m_id, 0, block_n_id, 0),
2341 c_element_op};
2342#else
2343 using EDataType = CDataType;
2344
2345 // tuple of reference to C/Ds tensor descriptors
2346 const auto c_ds_desc_refs = concat_tuple_of_reference(
2347 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2348 generate_tie([&](auto i) -> const auto& // return type should be reference
2349 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2351
2352 // tuple of reference to C shuffle (LDS) buffer and Ds grid buffers
2353 const auto c_ds_buf_refs = concat_tuple_of_reference(
2354 tie(c_shuffle_block_buf),
2355 generate_tie([&](auto i) -> const auto& // return type should be reference
2356 { return ds_grid_buf[i]; },
2358
2359 // tuple of starting index of C/Ds blockwise copy
2360 const auto idx_c_ds_block_begin = container_concat(
2361 make_tuple(make_multi_index(0, 0, 0, 0)),
2363 [&](auto) {
2364 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
2365 },
2367
2368 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2369 c_grid_desc_mblock_mperblock_nblock_nperblock;
2370
2371 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
2372 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2373 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2374 const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock =
2375 CShuffleBlockTransferScalarPerVector_NPerBlock;
2376
2377 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
2379 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2381 decltype(c_ds_desc_refs),
2382 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2383 CElementwiseOperation,
2384 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
2385 // support arbitrary type
2386 Sequence<1,
2387 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2388 1,
2389 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2390 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2391 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2392 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2393 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2394 3, // index_t SrcVectorDim,
2395 3, // index_t DstVectorDim,
2396 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
2397 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
2401 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2402 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
2403 {c_ds_desc_refs,
2404 idx_c_ds_block_begin,
2405 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2406 make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
2407 c_element_op};
2408
2409#endif
2410
2411 // space filling curve for threadwise C in VGPR
2412 constexpr auto sfc_c_vgpr =
2415 Sequence<CShuffleMXdlPerWavePerShuffle,
2416 CShuffleNXdlPerWavePerShuffle,
2417 1,
2418 1,
2419 M2,
2420 1,
2421 M4,
2422 1>>{};
2423
2424 // space filling curve for shuffled blockwise C in global mem
2425 constexpr auto sfc_c_global =
2428 Sequence<1,
2429 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2430 1,
2431 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2432
2433 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2434
2435 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2436
2437#if 1
2438 // space filling curve for shuffled blockwise C/D/E
2439 constexpr auto sfc_cde_block =
2442 Sequence<1,
2443 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2444 1,
2445 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2446#endif
2447
2448 static_for<0, num_access, 1>{}([&](auto access_id) {
2449 // make sure it's safe to write to LDS
2451
2452 // each thread write its data from VGPR to LDS
2453 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2454 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2455 c_thread_buf,
2456 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2457 c_shuffle_block_buf);
2458
2459 // make sure it's safe to read from LDS
2461
2462#if 0
2463 // each block copy its data from LDS to global
2464 c_shuffle_block_copy_lds_to_global.Run(
2465 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2466 c_shuffle_block_buf,
2467 c_grid_desc_mblock_mperblock_nblock_nperblock,
2468 c_grid_buf);
2469
2470 if constexpr(access_id < num_access - 1)
2471 {
2472 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2473
2474 // move on C
2475 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2476 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2477 }
2478#else
2479 // each block copy its data from LDS to global
2480 cde_block_copy_lds_and_global.Run(
2481 c_ds_desc_refs,
2482 c_ds_buf_refs,
2483 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2484 tie(c_grid_buf));
2485
2486 if constexpr(access_id < num_access - 1)
2487 {
2488 constexpr auto cde_lds_and_global_step =
2489 sfc_cde_block.GetForwardStep(access_id);
2490
2491 // move on Ds
2492 static_for<0, NumDTensor, 1>{}([&](auto i) {
2493 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2494 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2495 });
2496
2497 // move on E
2498 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2499 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2500 I0,
2501 cde_lds_and_global_step);
2502 }
2503#endif
2504 });
2505 }
2506 }
2507#endif
2508};
2509
2510} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:62
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
unsigned int uint32_t
Definition stdint.h:126
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:716
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:643
AsGridPointer p_as_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:678
BsGridPointer p_bs_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:679
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:760
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:642
__host__ Argument(std::array< const void *, NumATensor > p_as_grid_, std::array< const void *, NumBTensor > p_bs_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_c_grid_, index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:628
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:644
DsGridPointer p_ds_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:680
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:641
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:695
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:702
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:700
std::array< index_t, NumATensor > StrideAs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:609
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:642
CElementwiseOperation c_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:711
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:706
std::array< index_t, NumBTensor > StrideBs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:610
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:708
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:701
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:696
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:611
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:704
__host__ Problem(index_t M_, index_t N_, index_t K_, std::array< index_t, NumATensor > StrideAs_, std::array< index_t, NumBTensor > StrideBs_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:569
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:699
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:707
BElementwiseOperation b_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:710
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:705
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:703
AElementwiseOperation a_element_op_
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:709
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:596
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:765
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:814
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:815
__device__ SplitKBatchOffset(Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp:690
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:451
static constexpr auto BK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:261
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, BlkGemmPipeSched, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1112
static constexpr auto I2
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:250
static constexpr auto I7
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:255
static constexpr auto I5
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:253
static constexpr auto AK1Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:260
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1413
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:1437
static constexpr auto I6
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:254
static constexpr auto I1
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:249
static constexpr auto I0
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:248
static constexpr auto I3
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:251
static constexpr auto I4
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:252
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:369
static constexpr auto BK0Number
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:259
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition thread_group_tensor_slice_transfer_v7r2.hpp:47
Definition threadwise_tensor_slice_transfer.hpp:39
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129