device_gemm_wmma.hpp Source File

device_gemm_wmma.hpp Source File

Composable Kernel: device_gemm_wmma.hpp Source File
device_gemm_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22namespace tensor_operation {
23namespace device {
24
25template <typename ALayout,
26 typename BLayout,
27 typename CLayout,
28 typename ADataType,
29 typename BDataType,
30 typename CDataType,
31 typename AccDataType,
32 typename CShuffleDataType,
33 typename AElementwiseOperation,
34 typename BElementwiseOperation,
35 typename CElementwiseOperation,
36 GemmSpecialization GemmSpec,
37 ck::index_t NumPrefetch,
38 ck::index_t BlockSize,
39 ck::index_t MPerBlock,
40 ck::index_t NPerBlock,
41 ck::index_t KPerBlock,
42 ck::index_t K1,
43 ck::index_t MPerWmma,
44 ck::index_t NPerWmma,
45 ck::index_t MRepeat,
46 ck::index_t NRepeat,
47 typename ABlockTransferThreadClusterLengths_K0_M_K1,
48 typename ABlockTransferThreadClusterArrangeOrder,
49 typename ABlockTransferSrcAccessOrder,
50 ck::index_t ABlockTransferSrcVectorDim,
51 ck::index_t ABlockTransferSrcScalarPerVector,
52 ck::index_t ABlockTransferDstScalarPerVector_K1,
53 bool ABlockLdsAddExtraM,
54 typename BBlockTransferThreadClusterLengths_K0_N_K1,
55 typename BBlockTransferThreadClusterArrangeOrder,
56 typename BBlockTransferSrcAccessOrder,
57 ck::index_t BBlockTransferSrcVectorDim,
58 ck::index_t BBlockTransferSrcScalarPerVector,
59 ck::index_t BBlockTransferDstScalarPerVector_K1,
60 bool BBlockLdsAddExtraN,
61 index_t CShuffleMRepeatPerShuffle,
62 index_t CShuffleNRepeatPerShuffle,
63 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
64 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
67struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
68 BLayout,
69 CLayout,
70 ADataType,
71 BDataType,
72 CDataType,
73 AElementwiseOperation,
74 BElementwiseOperation,
75 CElementwiseOperation>
76{
77 static constexpr auto I0 = Number<0>{};
78 static constexpr auto I1 = Number<1>{};
79 static constexpr auto I2 = Number<2>{};
80 static constexpr auto I3 = Number<3>{};
81 static constexpr auto I4 = Number<4>{};
82 static constexpr auto I5 = Number<5>{};
83 static constexpr auto I6 = Number<6>{};
84 // K1 = Max Vector Access Pixels
85 static constexpr auto K1Number = Number<K1>{};
86
87 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
88 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
89 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
90 static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
91 static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
92
// Automatic LDS selection for A. The flag is false (LDS bypassed) only when the condition
// holds: a single wave along N, AND either a 16-byte K1 vector (MaxVectorLoadA) or
// MRepeat == 1, AND a further clause.
// NOTE(review): that further clause (original line 94) is not visible here; confirm it
// before relying on this description.
93 static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
95 ? false
96 : true;
// Automatic LDS selection for B. This mirrors A with the M/N roles swapped. The trailing
// clause (original line 99) is likewise not visible here.
97 static constexpr auto BEnableLds_auto =
98 (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
100 ? false
101 : true;
102
103 // If true, LDS is used unconditionally
104 static constexpr auto AEnableLds_manu = false;
105 static constexpr auto BEnableLds_manu = false;
106
107 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
108 static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
109
110 static constexpr auto matrix_padder =
111 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
112 // Describe how data is read from global memory
// Builds the grid descriptor for A from raw (M, K) extents and StrideA.
// The raw (M, K) view is padded with matrix_padder.PadADescriptor_M_K and then re-shaped
// according to AEnableLds.
113 static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
114 {
115 const auto a_grid_desc_m_k = [&]() {
117 {
118 const auto a_grid_desc_mraw_kraw =
120
121 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
122 }
124 {
125 const auto a_grid_desc_mraw_kraw =
127
128 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
129 }
130 }();
131
132 const auto M = a_grid_desc_m_k.GetLength(I0);
133 const auto K = a_grid_desc_m_k.GetLength(I1);
// K must split evenly into K1-wide vectors. This is only enforced where assert is active.
134 assert(K % K1 == 0);
135
136 if constexpr(AEnableLds)
137 {
// LDS path: K is split as K0 * K1.
138 const index_t K0 = K / K1;
139
141 a_grid_desc_m_k,
146 }
147 else
148 {
// Non-LDS path: K is viewed as A_KWmma x A_K0PerWmma x A_KRow x K1, with A_KWmma = K / WmmaK.
149 constexpr auto A_KRow = 2;
150 constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
151 const auto A_KWmma = K / WmmaK;
152
153 const auto M0 = M / MPerBlock;
154 // 0 1 0 1 2 3 4 5 6
155 // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
157 a_grid_desc_m_k,
161 make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
164 }
165 }
166
// Builds the grid descriptor for B from raw (K, N) extents and StrideB.
// The raw (N, K) view is padded with matrix_padder.PadBDescriptor_N_K and then re-shaped
// according to BEnableLds.
167 static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
168 {
169 const auto b_grid_desc_n_k = [&]() {
171 {
172 const auto b_grid_desc_nraw_kraw =
174
175 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
176 }
178 {
179 const auto b_grid_desc_nraw_kraw =
181
182 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
183 }
184 }();
185
186 const auto N = b_grid_desc_n_k.GetLength(I0);
187 const auto K = b_grid_desc_n_k.GetLength(I1);
// K must split evenly into K1-wide vectors. This is only enforced where assert is active.
188 assert(K % K1 == 0);
189
190 if constexpr(BEnableLds)
191 {
// LDS path: K is split as K0 * K1.
192 const index_t K0 = K / K1;
193
195 b_grid_desc_n_k,
200 }
201 else
202 {
// Non-LDS path: K is viewed as B_KWmma x B_K0PerWmma x B_KRow x K1, with B_KWmma = K / WmmaK.
203 constexpr auto B_KRow = 2;
204 constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
205 const auto B_KWmma = K / WmmaK;
206
207 const auto N0 = N / NPerBlock;
208 // 0 1 0 1 2 3 4 5 6
209 // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
211 b_grid_desc_n_k,
215 make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
218 }
219 }
220
// Builds the (M, N) grid descriptor for C from raw extents and StrideC, padded with
// matrix_padder.PadCDescriptor_M_N.
221 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
222 {
223 const auto c_grid_desc_mraw_nraw = [&]() {
225 {
// Strides (StrideC, 1): elements adjacent in N are contiguous.
226 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
227 make_tuple(StrideC, I1));
228 }
230 {
// Strides (1, StrideC): elements adjacent in M are contiguous.
231 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
232 make_tuple(I1, StrideC));
233 }
234 }();
235
236 return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
237 }
238
239 // Gridwise descriptor, mapping to whole given provblem.
240 using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
241 using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
242 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
243
244 // GridwiseGemm
// The gridwise GEMM implementation that Invoker::Run passes to kernel_gemm_wmma.
// It is configured with the descriptor types above and this device op's tuning parameters.
246 GridwiseGemm_Wmma<BlockSize,
247 ADataType,
248 BDataType,
249 AccDataType,
250 CShuffleDataType,
251 CDataType,
253 AGridDesc,
254 BGridDesc,
256 AElementwiseOperation,
257 BElementwiseOperation,
258 CElementwiseOperation,
259 MPerBlock,
260 NPerBlock,
261 KPerBlock,
262 MPerWmma,
263 NPerWmma,
264 K1,
265 MRepeat,
266 NRepeat,
267 ABlockTransferThreadClusterLengths_K0_M_K1,
268 ABlockTransferThreadClusterArrangeOrder,
269 ABlockTransferSrcAccessOrder,
270 ABlockTransferSrcVectorDim,
271 ABlockTransferSrcScalarPerVector,
272 ABlockTransferDstScalarPerVector_K1,
273 false, // AThreadTransferSrcResetCoordinateAfterRun,
275 ABlockLdsAddExtraM,
276 BBlockTransferThreadClusterLengths_K0_N_K1,
277 BBlockTransferThreadClusterArrangeOrder,
278 BBlockTransferSrcAccessOrder,
279 BBlockTransferSrcVectorDim,
280 BBlockTransferSrcScalarPerVector,
281 BBlockTransferDstScalarPerVector_K1,
282 false, // BThreadTransferSrcResetCoordinateAfterRun,
284 BBlockLdsAddExtraN,
285 CShuffleMRepeatPerShuffle,
286 CShuffleNRepeatPerShuffle,
287 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
288 CShuffleBlockTransferScalarPerVector_NPerBlock,
289 NumPrefetch,
290 LoopSched,
291 PipelineVer>;
292
293 // Argument
// Host-side bundle of everything Invoker::Run needs to launch the kernel: the three buffer
// pointers, grid descriptors, element-wise operators, and the raw problem sizes.
294 struct Argument : public BaseArgument
295 {
296 Argument(const ADataType* p_a_grid,
297 const BDataType* p_b_grid,
298 CDataType* p_c_grid,
299 index_t M,
300 index_t N,
301 index_t K,
302 index_t StrideA,
303 index_t StrideB,
304 index_t StrideC,
305 index_t M01,
306 index_t N01,
307 AElementwiseOperation a_element_op,
308 BElementwiseOperation b_element_op,
309 CElementwiseOperation c_element_op)
310 : p_a_grid_{p_a_grid},
311 p_b_grid_{p_b_grid},
312 p_c_grid_{p_c_grid},
313 a_grid_desc_{},
318 M01_{M01},
319 N01_{N01},
320 a_element_op_{a_element_op},
321 b_element_op_{b_element_op},
322 c_element_op_{c_element_op},
// The unpadded sizes are kept so IsSupportedArgument can check vector-access divisibility.
323 MRaw_{M},
324 NRaw_{N},
325 KRaw_{K}
326 {
330
333
336 {
340 }
341 }
342
343 // private:
344 const ADataType* p_a_grid_;
345 const BDataType* p_b_grid_;
346 CDataType* p_c_grid_;
355 AElementwiseOperation a_element_op_;
356 BElementwiseOperation b_element_op_;
357 CElementwiseOperation c_element_op_;
358 // for checking vector load/store
362 };
363
364 // Invoker
365 struct Invoker : public BaseInvoker
366 {
368
// Launches kernel_gemm_wmma for `arg` and returns the result of launch_and_time_kernel.
// That result is presumably the measured time under stream_config; TODO confirm.
// Throws std::runtime_error when the configuration is rejected.
369 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
370 {
375 {
376 throw std::runtime_error(
377 "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
378 }
379
380 const index_t grid_size =
381 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
382
// Recover the total K from the A descriptor. Which dimensions hold K depends on the layout
// MakeAGridDescriptor chose: dims 0 and 2 with LDS, dims 0, 3, 4 and 6 without.
383 const auto K = [&]() {
384 if constexpr(AEnableLds)
385 {
386 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
387 }
388 else
389 {
390 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
391 arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
392 }
393 }();
// has_main_k_block_loop is a compile-time bool (integral_constant). The runtime choice
// between the two kernel instantiations is made below.
394 auto launch_kernel = [&](auto has_main_k_block_loop) {
395 const auto kernel = kernel_gemm_wmma<
396 GridwiseGemm,
397 ADataType,
398 BDataType,
399 CDataType,
404 AElementwiseOperation,
405 BElementwiseOperation,
406 CElementwiseOperation,
408 has_main_k_block_loop>;
409
// The literal 0 is launch_and_time_kernel's lds_byte argument.
410 return launch_and_time_kernel(stream_config,
411 kernel,
412 dim3(grid_size),
413 dim3(BlockSize),
414 0,
415 arg.p_a_grid_,
416 arg.p_b_grid_,
417 arg.p_c_grid_,
418 arg.a_grid_desc_,
421 arg.a_element_op_,
422 arg.b_element_op_,
423 arg.c_element_op_,
425 };
426
428 {
429 return launch_kernel(integral_constant<bool, true>{});
430 }
431 else
432 {
433 return launch_kernel(integral_constant<bool, false>{});
434 }
435 }
436
437 // polymorphic
438 float Run(const BaseArgument* p_arg,
439 const StreamConfig& stream_config = StreamConfig{}) override
440 {
441 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
442 }
443 };
444
static constexpr bool IsValidCompilationParameter()
{
    // No compile-time validation of the template parameters exists yet, so every
    // instantiation is reported as valid.
    // TODO: properly implement this check
    constexpr bool all_parameters_accepted = true;
    return all_parameters_accepted;
}
450
// Host-side check of whether this instance can run `arg`.
// Rejects unsupported architecture / accumulator-type combinations; those paths print a
// diagnostic. It also rejects layout/vector-dimension combinations it does not handle, and
// raw sizes that are not multiples of the configured vector access widths.
// NOTE(review): Row/Col are assumed to be the row-/column-major layout tags introduced on
// lines not visible here; confirm.
451 static bool IsSupportedArgument(const Argument& arg)
452 {
454 {
457 {
458 printf("DeviceOp err: AccDataType");
459 return false;
460 }
461 }
462 else
463 {
464 printf("DeviceOp err: Arch");
465 return false;
466 }
467
468 // check vector load/store
469 {
472
473 // check vector load of A
// Each accepted (layout, vector dim) pair requires its vectorized raw extent to be a
// multiple of the source vector width. Any other pairing is rejected outright.
474 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
475 {
476 if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
477 {
478 return false;
479 }
480 }
481 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
482 {
483 // FIXME: not rigorous
484 if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
485 {
486 return false;
487 }
488 }
489 else
490 {
491 return false;
492 }
493
494 // check vector load of B
495 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
496 {
497 if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
498 {
499 return false;
500 }
501 }
502 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
503 {
504 // FIXME: not rigorous
505 if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
506 {
507 return false;
508 }
509 }
510 else
511 {
512 return false;
513 }
514
515 // check vector store of C
516 // only support RowMajor for now
517 if constexpr(is_same_v<CLayout, Row>)
518 {
519 if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
520 {
521 return false;
522 }
523 }
524 else
525 {
526 return false;
527 }
528 }
529
534 }
536 // polymorphic
537 bool IsSupportedArgument(const BaseArgument* p_arg) override
538 {
539 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
540 }
541
542 static auto MakeArgument(const ADataType* p_a,
543 const BDataType* p_b,
544 CDataType* p_c,
545 index_t M,
546 index_t N,
547 index_t K,
548 index_t StrideA,
549 index_t StrideB,
550 index_t StrideC,
551 AElementwiseOperation a_element_op,
552 BElementwiseOperation b_element_op,
553 CElementwiseOperation c_element_op)
554 {
555 return Argument{p_a,
556 p_b,
557 p_c,
558 M,
559 N,
560 K,
561 StrideA,
562 StrideB,
563 StrideC,
564 1,
565 1,
566 a_element_op,
567 b_element_op,
568 c_element_op};
569 }
570
571 static auto MakeInvoker() { return Invoker{}; }
572
573 // polymorphic
574 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
575 const void* p_b,
576 void* p_c,
577 index_t M,
578 index_t N,
579 index_t K,
580 index_t StrideA,
581 index_t StrideB,
582 index_t StrideC,
583 AElementwiseOperation a_element_op,
584 BElementwiseOperation b_element_op,
585 CElementwiseOperation c_element_op) override
586 {
587 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
588 static_cast<const BDataType*>(p_b),
589 static_cast<CDataType*>(p_c),
590 M,
591 N,
592 K,
593 StrideA,
594 StrideB,
595 StrideC,
596 1,
597 1,
598 a_element_op,
599 b_element_op,
600 c_element_op);
601 }
602
603 // polymorphic
604 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
605 {
606 return std::make_unique<Invoker>(Invoker{});
607 }
608
609 // polymorphic
610 std::string GetTypeString() const override
611 {
612 auto str = std::stringstream();
613
614 std::map<LoopScheduler, std::string> LoopSchedToString{
615 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
616
617 std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
618 {PipelineVersion::v2, "v2"}};
619
620 // clang-format off
621 str << "DeviceGemmWmma_CShuffle"
622 << "<"
623 << BlockSize << ", "
624 << MPerBlock << ", "
625 << NPerBlock << ", "
626 << KPerBlock << ", "
627 << K1 << ", "
628 << MPerWmma << ", "
629 << NPerWmma << ", "
630 << MRepeat << ", "
631 << NRepeat
632 << ">"
633 << " AEnableLds: "
634 << AEnableLds << ", "
635 << "BEnableLds: "
636 << BEnableLds << ", "
637 << "NumPrefetch: "
638 << NumPrefetch << ", "
639 << "LoopScheduler: "
640 << LoopSchedToString[LoopSched] << ", "
641 << "PipelineVersion: "
642 << PipelineVersionToString[PipelineVer];
643 // clang-format on
644
645 return str.str();
646 }
647};
648
649} // namespace device
650} // namespace tensor_operation
651} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_wmma(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_wmma.hpp:37
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_wmma.hpp:124
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
Definition device_gemm.hpp:22
CElementwiseOperation c_element_op_
Definition device_gemm_wmma.hpp:357
AElementwiseOperation a_element_op_
Definition device_gemm_wmma.hpp:355
AGridDesc a_grid_desc_
Definition device_gemm_wmma.hpp:347
index_t M01_
Definition device_gemm_wmma.hpp:353
index_t NRaw_
Definition device_gemm_wmma.hpp:360
index_t N01_
Definition device_gemm_wmma.hpp:354
const ADataType * p_a_grid_
Definition device_gemm_wmma.hpp:344
BGridDesc b_grid_desc_k0_n_k1_
Definition device_gemm_wmma.hpp:348
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_gemm_wmma.hpp:352
index_t MRaw_
Definition device_gemm_wmma.hpp:359
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_wmma.hpp:296
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_wmma.hpp:351
index_t KRaw_
Definition device_gemm_wmma.hpp:361
BElementwiseOperation b_element_op_
Definition device_gemm_wmma.hpp:356
CGridDesc_M_N c_grid_desc_m_n_
Definition device_gemm_wmma.hpp:349
CDataType * p_c_grid_
Definition device_gemm_wmma.hpp:346
const BDataType * p_b_grid_
Definition device_gemm_wmma.hpp:345
DeviceGemmWmma_CShuffle::Argument Argument
Definition device_gemm_wmma.hpp:367
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_wmma.hpp:438
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_wmma.hpp:369
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_gemm_wmma.hpp:221
static constexpr auto K1Number
Definition device_gemm_wmma.hpp:85
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_wmma.hpp:113
static auto MakeInvoker()
Definition device_gemm_wmma.hpp:571
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_wmma.hpp:537
static constexpr auto AEnableLds_manu
Definition device_gemm_wmma.hpp:104
static constexpr auto AEnableLds
Definition device_gemm_wmma.hpp:107
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_wmma.hpp:167
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma.hpp:451
static constexpr auto BEnableLds
Definition device_gemm_wmma.hpp:108
static constexpr auto I3
Definition device_gemm_wmma.hpp:80
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_gemm_wmma.hpp:242
static constexpr auto MaxVectorLoadA
Definition device_gemm_wmma.hpp:90
static constexpr auto I1
Definition device_gemm_wmma.hpp:78
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_wmma.hpp:574
static constexpr auto AEnableLds_auto
Definition device_gemm_wmma.hpp:93
std::string GetTypeString() const override
Definition device_gemm_wmma.hpp:610
static constexpr auto I6
Definition device_gemm_wmma.hpp:83
static constexpr auto I0
Definition device_gemm_wmma.hpp:77
static constexpr auto I2
Definition device_gemm_wmma.hpp:79
static constexpr auto I5
Definition device_gemm_wmma.hpp:82
decltype(MakeBGridDescriptor(1, 1, 1)) BGridDesc
Definition device_gemm_wmma.hpp:241
static constexpr auto MWaves
Definition device_gemm_wmma.hpp:87
static constexpr auto NWaves
Definition device_gemm_wmma.hpp:88
static constexpr auto matrix_padder
Definition device_gemm_wmma.hpp:110
static constexpr auto WmmaK
Definition device_gemm_wmma.hpp:89
decltype(MakeAGridDescriptor(1, 1, 1)) AGridDesc
Definition device_gemm_wmma.hpp:240
static constexpr auto BEnableLds_auto
Definition device_gemm_wmma.hpp:97
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_wmma.hpp:445
GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseGemm
Definition device_gemm_wmma.hpp:245
static constexpr auto MaxVectorLoadB
Definition device_gemm_wmma.hpp:91
static constexpr auto BEnableLds_manu
Definition device_gemm_wmma.hpp:105
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_wmma.hpp:542
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_wmma.hpp:604
static constexpr auto I4
Definition device_gemm_wmma.hpp:81
Definition matrix_padder.hpp:180