gemm_universal_pipeline_ag_bg_cr_policy.hpp Source File

gemm_universal_pipeline_ag_bg_cr_policy.hpp Source File

Composable Kernel: gemm_universal_pipeline_ag_bg_cr_policy.hpp Source File
gemm_universal_pipeline_ag_bg_cr_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
// Member-detection traits.
//
// has_a_tile_access_pattern<T>::value is true exactly when the expression
// `T::ATileAccessPattern` is well-formed (T declares a member with that name);
// has_b_tile_access_pattern<T>::value does the same for `T::BTileAccessPattern`.
// The base policy's getATileAccessPattern()/getBTileAccessPattern() return
// Derived::ATileAccessPattern / Derived::BTileAccessPattern; presumably these traits
// guard that so a derived policy without the member falls back to the defaults
// (NOTE(review): the guarding `if constexpr` line is not visible here - confirm).

// Primary templates: chosen whenever the corresponding member is absent.
template <typename T, typename = void>
struct has_a_tile_access_pattern : std::bool_constant<false>
{
};

template <typename T, typename = void>
struct has_b_tile_access_pattern : std::bool_constant<false>
{
};

// Partial specializations: std::void_t<...> collapses to `void` (matching the
// primary's default argument) only when the member expression is valid, so SFINAE
// silently discards these specializations for types lacking the member.
template <typename T>
struct has_a_tile_access_pattern<T, std::void_t<decltype(T::ATileAccessPattern)>>
    : std::bool_constant<true>
{
};

template <typename T>
struct has_b_tile_access_pattern<T, std::void_t<decltype(T::BTileAccessPattern)>>
    : std::bool_constant<true>
{
};
31
32template <typename Derived>
34{
35#if defined(__gfx950__)
36 template <typename Problem>
37 static constexpr bool is_a_load_tr =
38 std::is_same_v<remove_cvref_t<typename Problem::ALayout>, tensor_layout::gemm::ColumnMajor>;
39 template <typename Problem>
40 static constexpr bool is_b_load_tr =
41 std::is_same_v<remove_cvref_t<typename Problem::BLayout>, tensor_layout::gemm::RowMajor>;
42#else
43 template <typename Problem>
44 static constexpr bool is_a_load_tr = false;
45 template <typename Problem>
46 static constexpr bool is_b_load_tr = false;
47#endif
48
49 static constexpr auto I0 = number<0>{};
50 static constexpr auto I1 = number<1>{};
51 static constexpr auto I2 = number<2>{};
52
53 // Default tile access patterns
56
57 static constexpr auto getATileAccessPattern()
58 {
60 return Derived::ATileAccessPattern;
61 else
63 }
64
65 static constexpr auto getBTileAccessPattern()
66 {
68 return Derived::BTileAccessPattern;
69 else
71 }
72
73 template <typename Problem>
75 {
76
78 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
79 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
80
81 if constexpr(is_a_load_tr<Problem>)
82 {
83 // TODO: better lds descriptor for performance
84 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
88 number<1>{});
89 return a_lds_block_desc_0;
90 }
91 else
92 {
93 constexpr index_t KPack = GetSmemPackA<Problem>();
94
95 constexpr auto DataTypeSize = sizeof(ADataType);
96 constexpr auto MLdsLayer =
97 max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
98
99 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
101 number<MPerBlock / MLdsLayer>{},
102 number<KPack>{}),
105 number<1>{});
106
107 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
108 a_lds_block_desc_0,
110 number<KPerBlock / KPack * MLdsLayer>{})),
114
115 constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
116 a_lds_block_desc_permuted,
118 make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
123
124 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
125 a_lds_block_desc_xk0_mnldslayer_mn_xk1,
132
133 return a_lds_block_desc;
134 }
135 }
136
143 template <typename Problem>
145 {
147
148 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
149 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
150
151#if 1
152 if constexpr(is_b_load_tr<Problem>)
153 {
154 // TODO: better lds descriptor for performance
155 constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
159 number<1>{});
160 return b_lds_block_desc_0;
161 }
162 else
163 // else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
164 {
165 constexpr index_t KPack = GetSmemPackB<Problem>();
166 constexpr auto BK0 = number<KPerBlock / KPack>{};
167 constexpr auto DataTypeSize = sizeof(BDataType);
168 constexpr auto NLdsLayer =
169 max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
170
171 constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
173 BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
176 number<1>{});
177
178 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
179 b_lds_block_desc_0,
181 BK0 * number<NLdsLayer>{})),
185
186 constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
187 b_lds_block_desc_permuted,
193
194 constexpr auto b_lds_block_desc = transform_tensor_descriptor(
195 b_lds_block_desc_bk0_nldslayer_n_bk1,
201 return b_lds_block_desc;
202 }
203#else
204 else // B is Row Major
205 {
206 constexpr index_t BlockSize = Problem::kBlockSize;
207 constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
208 using TileEncodingPattern =
210 KPerBlock,
211 NPerBlock,
212 VecLoadSize,
214
215 constexpr auto BK0 = number<TileEncodingPattern::X1>{};
216 constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
217 // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
218 constexpr auto N0 = TileEncodingPattern::X0;
219 constexpr auto N1 = NPerBlock / N0;
220
221 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
222 constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
223
224 // constexpr auto KThreadWrite =
225 // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
226 constexpr auto KThreadWrite = TileEncodingPattern::Y2;
227 constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
228 constexpr auto KThreadRead = 64 / NPerXdl;
229 constexpr auto K0PerThreadRead = BK0 / KThreadRead;
230
231 constexpr auto kfold =
232 (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType));
233 constexpr auto KThreadReadPerm =
234 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
235 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
236 : KThreadRead;
237
238 // 1<=npair<=n0
239 constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128)
240 ? 1
241 : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0
242 ? N0
243 : 128 / (BK1 * NPerXdl * sizeof(BDataType)));
244
245 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
249 number<kfold * N0 / npair>{},
251 BK1));
252
253 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
254 b_lds_block_desc,
259 make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
266
267 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
268 b_lds_block_desc_permuted,
277 sequence<1>{},
278 sequence<2>{},
279 sequence<3>{},
280 sequence<4>{},
281 sequence<5>{}),
283 sequence<2>{},
286 sequence<6>{},
287 sequence<7>{}));
288
289 // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
290 // b_lds_block_desc_unmerged,
291 // make_tuple(make_merge_transform_v3_division_mod(
292 // make_tuple(number<KThreadReadPerm>{},
293 // number<KThreadWrite / kfold / KThreadReadPerm>{},
294 // number<kfold>{},
295 // number<K0PerThreadWrite>{})),
296 // make_merge_transform_v3_division_mod(
297 // make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{})),
298 // make_pass_through_transform(BK1)),
299 // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}),
300 // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
301
302 constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor(
303 b_lds_block_desc_unmerged,
306 number<KThreadWrite / kfold / KThreadReadPerm>{},
309 BK1)),
314
315 // return b_lds_block_desc_bk0_n_bk1;
316 return b_lds_block_desc_kn;
317
318 // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor(
319 // make_tuple(BK0, number<NPerBlock>{}, number<KPack>{}),
320 // make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
321 // number<KPack>{},
322 // number<1>{});
323
324 // constexpr auto b_lds_block_desc = transform_tensor_descriptor(
325 // b_lds_block_desc_bk0_n_bk1,
326 // make_tuple(make_pass_through_transform(number<NPerBlock>{}),
327 // make_merge_transform_v3_division_mod(make_tuple(BK0,
328 // number<KPack>{}))),
329 // make_tuple(sequence<1>{}, sequence<0, 2>{}),
330 // make_tuple(sequence<0>{}, sequence<1>{}));
331
332 // return b_lds_block_desc;
333 }
334#endif
335 }
336
346 template <typename Problem,
347 typename DataType,
348 index_t MNPerBlock,
349 index_t XPerTile,
350 bool IsWave32Host>
352 {
353 constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize;
354 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
355 constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
356 constexpr index_t PackedSize =
358
359 // Assume DataType is even!
360 if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
361 elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
362 PackedSize == 2)
363 {
364 return (PackedSize * 32 / sizeof(DataType));
365 }
366 else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
367 elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
368 {
369 return (PackedSize * 16 / sizeof(DataType));
370 }
371 else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
372 elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
373 {
374 return (PackedSize * 8 / sizeof(DataType));
375 }
376 else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
377 XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
378 elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
379 {
380 return (PackedSize * 4 / sizeof(DataType));
381 }
382 else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
383 XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
384 elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
385 {
386 return (PackedSize * 2 / sizeof(DataType));
387 }
388 else
389 {
390 return PackedSize;
391 }
392 }
393
394 template <typename Problem, bool IsWave32Host = false>
395 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
396 {
399 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
400 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
401
402 using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
403 using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
404
405 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
406 {
407 return GetGlobalVectorLoadSize<Problem,
408 ADataType,
409 MPerBlock,
410 KPerBlock,
411 IsWave32Host>();
412 }
413 else
414 {
415 return GetGlobalVectorLoadSize<Problem,
416 ADataType,
417 MPerBlock,
418 MPerBlock,
419 IsWave32Host>();
420 }
421 }
422
423 template <typename Problem, bool IsWave32Host = false>
424 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
425 {
428 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
429 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
430
431 using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
432 using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
433
434 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
435 {
436 return GetGlobalVectorLoadSize<Problem,
437 BDataType,
438 NPerBlock,
439 NPerBlock,
440 IsWave32Host>();
441 }
442 else
443 {
444 return GetGlobalVectorLoadSize<Problem,
445 BDataType,
446 NPerBlock,
447 KPerBlock,
448 IsWave32Host>();
449 }
450 }
451
464 template <typename Problem>
465 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
466 {
468 using WG = typename BlockGemm::WarpGemm;
469
470 constexpr bool TransposeC = Problem::TransposeC;
471 using CLayout = typename Problem::CLayout;
472 using CWarpDstr = typename WG::CWarpDstr;
473
474 // N is contiguous dimension
475 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
476 {
477 if constexpr(TransposeC)
478 {
479 // In this case each thread has multiple consecutive elements in
480 // N dimension, however consecutive threads' elements have stride.
481 constexpr index_t NDimY = CWarpDstr::NDimY;
482 constexpr auto c_warp_y_lengths =
483 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
484 static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
485 c_warp_y_lengths.get(number<NDimY - 1>{}));
486 return c_warp_y_lengths.get(number<NDimY - 1>{});
487 }
488 else
489 {
490 // In this case each thread has just a single item in Ndim
491 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
492 }
493 }
494 // M is contiguous dimension
495 else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
496 {
497 if constexpr(TransposeC)
498 {
499 // In this case each thread has just a single item in Mdim
500 return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
501 }
502 else
503 {
504 // In this case each thread has multiple consecutive elements in
505 // M dimension, however consecutive threads' elements have stride.
506 constexpr index_t NDimY = CWarpDstr::NDimY;
507 constexpr auto c_warp_y_lengths =
508 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
509 static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
510 c_warp_y_lengths.get(number<NDimY - 1>{}));
511 return c_warp_y_lengths.get(number<NDimY - 1>{});
512 }
513 }
514 else
515 {
516 static_assert(false, "Unsupported CLayout!");
517 }
518 }
519
520 template <typename Problem>
521 CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
522 {
523 return Problem::TransposeC;
524 }
525
526 template <typename Problem>
528 {
529 constexpr index_t BlockSize = Problem::kBlockSize;
530 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
531 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
532 constexpr index_t VecLoadSize =
533 Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
534 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
535
536 using ALayout = remove_cvref_t<
537 std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
538 // Tile: MPerBlock X KPerBlock
539 if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
540 {
541 using TileEncodingPattern =
543 MPerBlock,
544 KPerBlock,
545 VecLoadSize,
547 NumWaveGroups>;
548 return TileEncodingPattern::make_2d_static_tile_distribution();
549 }
550 // Tile: KPerBlock X MPerBlock
551 else
552 {
553 using TileEncodingPattern =
555 KPerBlock,
556 MPerBlock,
557 VecLoadSize,
559 NumWaveGroups>;
560 return TileEncodingPattern::make_2d_static_tile_distribution();
561 }
562 }
563
564 template <typename Problem>
566 {
567 constexpr index_t BlockSize = Problem::kBlockSize;
568 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
569 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
570 constexpr index_t VecLoadSize =
571 Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
572 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
573
574 using BLayout = remove_cvref_t<
575 std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
576 // Tile: KPerBlock X NPerBlock
577 if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
578 {
579 using TileEncodingPattern =
581 KPerBlock,
582 NPerBlock,
583 VecLoadSize,
585 NumWaveGroups>;
586 return TileEncodingPattern::make_2d_static_tile_distribution();
587 }
588 // Tile: NPerBlock X KPerBlock
589 else
590 {
591 using TileEncodingPattern =
593 NPerBlock,
594 KPerBlock,
595 VecLoadSize,
597 NumWaveGroups>;
598 return TileEncodingPattern::make_2d_static_tile_distribution();
599 }
600 }
601
602 template <typename Problem>
604 {
605 using ALayout = remove_cvref_t<
606 std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
607 static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
608 constexpr index_t BlockSize = Problem::kBlockSize;
609 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
610 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
611 constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
612 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
613
614 using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
615 KPerBlock,
616 MPerBlock,
617 VecLoadSize,
619 NumWaveGroups>;
620 return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
621 }
622
623 template <typename Problem>
625 {
626 using BLayout = remove_cvref_t<
627 std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
628 static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
629 constexpr index_t BlockSize = Problem::kBlockSize;
630 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
631 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
632 constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
633 constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
634
635 using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
636 KPerBlock,
637 NPerBlock,
638 VecLoadSize,
640 NumWaveGroups>;
641 return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
642 }
643
644 template <typename Problem>
645 CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
646 {
648 constexpr index_t KPack = BlockGemm::Traits::KPack;
649 return KPack;
650 }
651
652 template <typename Problem>
653 CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
654 {
656 constexpr index_t KPack = BlockGemm::Traits::KPack;
657 return KPack;
658 }
659
660 template <typename Problem>
662 {
663 constexpr index_t smem_size_a =
664 integer_least_multiple(sizeof(typename Problem::ADataType) *
665 Problem::BlockGemmShape::kM * Problem::BlockGemmShape::kK,
666 16);
667 return smem_size_a;
668 }
669
670 template <typename Problem>
672 {
673 constexpr index_t smem_size_b =
674 integer_least_multiple(sizeof(typename Problem::BDataType) *
675 Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
676 16);
677 return smem_size_b;
678 }
679
680 template <typename Problem>
682 {
683 constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
684 constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
685
686 return smem_size_a + smem_size_b;
687 }
688};
689
690// UniversalGemm Policy
692 : public UniversalGemmBasePolicy<UniversalGemmPipelineAgBgCrPolicy>
693{
694 template <typename Problem>
695 CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
696 {
697 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
698 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
699
700 constexpr index_t vector_size =
701 DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
702 constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
703 constexpr auto wg_attr_num_access =
705 : vector_size == thread_elements ? WGAttrNumAccessEnum::Single
706 : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
707 : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
709
710 using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
711 typename Problem::ComputeDataType,
712 typename Problem::CDataType,
713 WarpTile::at(I0),
714 WarpTile::at(I1),
715 WarpTile::at(I2),
716 Problem::TransposeC,
717 false,
718 Problem::UseStructuredSparsity,
719 wg_attr_num_access>;
720
721 using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
722 typename Problem::BDataType,
723 typename Problem::CDataType,
724 BlockWarps,
725 WarpGemm>;
727 }
728};
729
730} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
@ Invalid
Definition warp_gemm_attribute_mfma.hpp:17
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
@ Quad
Definition warp_gemm_attribute_mfma.hpp:16
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed(const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:371
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
@ thread_raked
Thread raked pattern.
Definition static_encoding_pattern.hpp:94
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
constexpr int DS_READ_TR_SIZE()
Definition load_tile_transpose.hpp:20
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
STL namespace.
Definition block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp:18
Definition block_universal_gemm_as_bs_cr.hpp:21
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:34
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledARegTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:603
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static CK_TILE_DEVICE constexpr index_t GetSmemSizeB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:671
static constexpr auto getATileAccessPattern()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:57
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:645
static constexpr auto DefaultATileAccessPattern
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:54
static constexpr auto DefaultBTileAccessPattern
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:55
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:653
static CK_TILE_HOST_DEVICE constexpr auto MakeADramTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:527
static CK_TILE_DEVICE constexpr auto MakeBLdsBlockDescriptor()
Create LDS block descriptor for B tensor.
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:144
static constexpr bool is_a_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:44
static constexpr bool is_b_load_tr
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:46
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeC()
Get the vector store size for C tensor.
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:465
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:424
static CK_TILE_HOST_DEVICE constexpr auto GetGlobalVectorLoadSize()
Get the maximum global memory vector load size.
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:351
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBRegTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:624
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto getBTileAccessPattern()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:65
static CK_TILE_DEVICE constexpr index_t GetSmemSizeA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:661
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:395
static CK_TILE_HOST_DEVICE constexpr auto MakeBDramTileDistribution()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:565
static CK_TILE_HOST_DEVICE constexpr auto IsTransposeC()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:521
static CK_TILE_DEVICE constexpr auto MakeALdsBlockDescriptor()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:74
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
static CK_TILE_DEVICE constexpr index_t GetSmemSize()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:681
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:693
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:695
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:14
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:24
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile/ops/common/tensor_layout.hpp:22
Definition tile/ops/common/tensor_layout.hpp:17
Class creating 2D static tile distribution with different load/store patterns.
Definition static_encoding_pattern.hpp:130