device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp Source File

device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp Source File#

Composable Kernel: device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp Source File
device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <numeric>
8#include <sstream>
9
11#include "ck/utility/env.hpp"
28
32
33namespace ck {
34namespace tensor_operation {
35namespace device {
36
// Device kernel that resolves the per-work-group pointer offsets and forwards to
// GridwiseGemm::Run with a single shared-memory buffer.
//
// - g_idx is blockIdx.z * NumGroupsToMerge and k_idx is blockIdx.y * num_k_per_block.
// - The A/B/E offsets are read from compute_ptr_offset_of_batch at g_idx and added to
//   karg.p_a_grid / karg.p_b_grid / karg.p_c_grid before the call.
// - The body is only compiled for gfx9/gfx11/gfx12 targets, and only when
//   GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() holds;
//   on other targets the kernel reduces to `ignore = karg`.
//
// NOTE(review): `TailNum` is used in the Run<> call below, but its template-parameter
// declaration and the line carrying the kernel name are not present in this listing
// (original lines 46 and 51 are absent) - confirm against the real header.
37template <typename GridwiseGemm,
38 typename AGridDesc_AK0_M_K1,
39 typename BGridDesc_BK0_N_K1,
40 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
41 typename ComputePtrOffsetOfBatch,
42 index_t NumGroupsToMerge,
43 bool HasMainKBlockLoop,
44 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
45 index_t MinimumOccupancy = 1,
47__global__ void
48#if CK_USE_LAUNCH_BOUNDS
49__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
50#endif
52 typename GridwiseGemm::Argument karg,
53 [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
54 [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
55 [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
56 c_grid_desc_mblock_mperblock_nblock_nperblock,
57 [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
58 [[maybe_unused]] const index_t num_k_per_block)
59{
60#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
61 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
62 {
// Group index and K start index of this work-group, derived from the grid z and y
// coordinates.
63 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
64 const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
65
// Element offsets of this group inside the A, B and E buffers.
66 const long_index_t a_batch_offset = amd_wave_read_first_lane(
67 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
68 const long_index_t b_batch_offset = amd_wave_read_first_lane(
69 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
70 const long_index_t e_batch_offset = amd_wave_read_first_lane(
71 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
72
// Shared-memory scratch buffer whose size is dictated by the gridwise GEMM.
73 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
74
75 GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
76 BGridDesc_BK0_N_K1,
77 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
78 HasMainKBlockLoop,
79 CGlobalMemoryDataOperation,
80 TailNum>(karg.p_a_grid + a_batch_offset,
81 karg.p_b_grid + b_batch_offset,
82 karg.p_c_grid + e_batch_offset,
83 p_shared,
84 karg,
85 a_grid_desc_ak0_m_ak1,
86 b_grid_desc_bk0_n_bk1,
87 c_grid_desc_mblock_mperblock_nblock_nperblock,
88 k_idx);
89 }
90#else
// Unsupported target: reference karg so the parameter is not reported as unused.
91 ignore = karg;
92#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
93}
94
// Device kernel that resolves the per-work-group pointer offsets and forwards to
// GridwiseGemm::Run_2Lds with two separate shared-memory buffers.
//
// - g_idx is blockIdx.z * NumGroupsToMerge and k_idx is blockIdx.y * num_k_per_block.
// - The A/B/E offsets are read from compute_ptr_offset_of_batch at g_idx and added to
//   karg.p_a_grid / karg.p_b_grid / karg.p_c_grid before the call.
// - The body is only compiled for gfx9/gfx11/gfx12 targets, and only when
//   GridwiseGemm::IsValidCompilationParameter<CGlobalMemoryDataOperation>() holds;
//   on other targets the kernel reduces to `ignore = karg`.
//
// NOTE(review): `TailNum` is used in the Run_2Lds<> call below, but its template-parameter
// declaration and the line carrying the kernel name are not present in this listing
// (original lines 104 and 109 are absent) - confirm against the real header.
95template <typename GridwiseGemm,
96 typename AGridDesc_AK0_M_K1,
97 typename BGridDesc_BK0_N_K1,
98 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
99 typename ComputePtrOffsetOfBatch,
100 index_t NumGroupsToMerge,
101 bool HasMainKBlockLoop,
102 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
103 index_t MinimumOccupancy = 1,
105__global__ void
106#if CK_USE_LAUNCH_BOUNDS
107__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
108#endif
110 typename GridwiseGemm::Argument karg,
111 [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
112 [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
113 [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
114 c_grid_desc_mblock_mperblock_nblock_nperblock,
115 [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
116 [[maybe_unused]] const index_t num_k_per_block)
117{
118#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
119 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
120 {
121 // Offset the base pointers for each work-group.
122 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
123 const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
124
// Element offsets of this group inside the A, B and E buffers.
125 const long_index_t a_batch_offset = amd_wave_read_first_lane(
126 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
127 const long_index_t b_batch_offset = amd_wave_read_first_lane(
128 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
129 const long_index_t e_batch_offset = amd_wave_read_first_lane(
130 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
131
132 // Passing two LDS pointers is the key to telling the compiler that ds_read/write
133 // operate on different LDS chunks at the same time without an ordering dependency.
134 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
135 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
136
137 GridwiseGemm::template Run_2Lds<AGridDesc_AK0_M_K1,
138 BGridDesc_BK0_N_K1,
139 CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
140 HasMainKBlockLoop,
141 CGlobalMemoryDataOperation,
142 TailNum>(karg.p_a_grid + a_batch_offset,
143 karg.p_b_grid + b_batch_offset,
144 karg.p_c_grid + e_batch_offset,
145 p_shared_0,
146 p_shared_1,
147 karg,
148 a_grid_desc_ak0_m_ak1,
149 b_grid_desc_bk0_n_bk1,
150 c_grid_desc_mblock_mperblock_nblock_nperblock,
151 k_idx);
152 }
153#else
// Unsupported target: reference karg so the parameter is not reported as unused.
154 ignore = karg;
155#endif // end of if (defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__))
156}
157
158template <ck::index_t NDimSpatial,
159 typename InLayout,
160 typename WeiLayout,
161 typename OutLayout,
162 typename InDataType,
163 typename WeiDataType,
164 typename OutDataType,
165 typename AccDataType,
166 typename InElementwiseOperation,
167 typename WeiElementwiseOperation,
168 typename OutElementwiseOperation,
169 ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
170 ck::index_t BlockSize,
171 ck::index_t MPerBlock,
172 ck::index_t NPerBlock,
173 ck::index_t KPerBlock,
174 ck::index_t K1,
175 ck::index_t MPerXDL,
176 ck::index_t NPerXDL,
177 ck::index_t MXdlPerWave,
178 ck::index_t NXdlPerWave,
179 typename ABlockTransferThreadClusterLengths_K0_M_K1,
180 typename ABlockTransferThreadClusterArrangeOrder,
181 typename ABlockTransferSrcAccessOrder,
182 ck::index_t ABlockTransferSrcVectorDim,
183 ck::index_t ABlockTransferSrcScalarPerVector,
184 ck::index_t ABlockTransferDstScalarPerVector_K1,
185 bool ABlockLdsAddExtraM,
186 typename BBlockTransferThreadClusterLengths_K0_N_K1,
187 typename BBlockTransferThreadClusterArrangeOrder,
188 typename BBlockTransferSrcAccessOrder,
189 ck::index_t BBlockTransferSrcVectorDim,
190 ck::index_t BBlockTransferSrcScalarPerVector,
191 ck::index_t BBlockTransferDstScalarPerVector_K1,
192 bool BBlockLdsAddExtraN,
193 index_t CShuffleMXdlPerWavePerShuffle,
194 index_t CShuffleNXdlPerWavePerShuffle,
195 typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
196 index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
199 index_t NumGroupsToMerge = 1,
200 typename ComputeTypeA = InDataType,
201 typename ComputeTypeB = ComputeTypeA,
202 index_t TransposeTransferSrcScalarPerVector = 1,
203 index_t TransposeTransferDstScalarPerVector = 1>
205 : public DeviceGroupedConvBwdWeight<NDimSpatial,
206 InLayout,
207 WeiLayout,
208 OutLayout,
209 InDataType,
210 WeiDataType,
211 OutDataType,
212 InElementwiseOperation,
213 WeiElementwiseOperation,
214 OutElementwiseOperation,
215 ComputeTypeA,
216 ComputeTypeB>
217{
221
224 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
225 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
226
227 using ADataType = OutDataType;
228 using BDataType = InDataType;
229 using EDataType = WeiDataType;
230
231 // If NGCHW then ADataType must be equal to BDataType
235
236 using AElementwiseOperation = OutElementwiseOperation;
237 using BElementwiseOperation = InElementwiseOperation;
238 using CDEElementwiseOperation = WeiElementwiseOperation;
239
240 // TODO make A/B datatype different
241 using ABDataType = InDataType;
242
243 static constexpr auto I0 = Number<0>{};
244 static constexpr auto I1 = Number<1>{};
245 static constexpr auto I2 = Number<2>{};
246 static constexpr auto I3 = Number<3>{};
247 static constexpr auto I4 = Number<4>{};
248 static constexpr auto I5 = Number<5>{};
249
250 static constexpr auto K1Number = Number<K1>{};
251
252 static constexpr auto conv_to_gemm_transformer_v2 =
254 MPerBlock,
255 NPerBlock,
256 K1Number,
257 KPerBlock / K1Number,
258 NumGroupsToMerge,
259 ConvBackwardWeightSpecialization>{};
260
261 static constexpr auto conv_to_gemm_transformer_v1 =
263 MPerBlock,
264 NPerBlock,
265 K1Number,
266 KPerBlock / K1Number,
267 ConvBackwardWeightSpecialization>{};
268
270 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
272 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
273
274 static constexpr auto conv_ngchw_to_nhwgc_transformer =
276 WeiLayout,
277 OutLayout,
278 NDimSpatial,
279 MPerBlock / ClusterLengthMPerBlock,
280 NPerBlock / ClusterLengthNPerBlock>{};
281
283
// 2D overload (enabled when NDim == 2): calls MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>
// with every dimension, length, stride, conv parameter and the batch count set to 1.
// NOTE(review): since all inputs are placeholders, this presumably exists only so the
// descriptor types can be deduced with decltype - confirm. The line naming the object the
// factory is called on (original line 292) is absent from this listing.
284 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
285 static auto GetABCGridDesc()
286 {
287 const ck::index_t dim = 1;
288 const ck::index_t batch = 1;
289 const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
290 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
291 const std::array<ck::index_t, NDimSpatial> params{1, 1};
293 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim,
294 dim,
295 dim,
296 lengths,
297 lengths,
298 lengths,
299 strides,
300 strides,
301 strides,
302 params,
303 params,
304 params,
305 params,
306 batch);
307 }
308
// 3D overload (enabled when NDim == 3): calls MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>
// with every dimension, length, stride, conv parameter and the batch count set to 1.
// NOTE(review): since all inputs are placeholders, this presumably exists only so the
// descriptor types can be deduced with decltype - confirm. The line naming the object the
// factory is called on (original line 317) is absent from this listing.
309 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
310 static auto GetABCGridDesc()
311 {
312 const ck::index_t dim = 1;
313 const ck::index_t batch = 1;
314 const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
315 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
316 const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
318 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
319 dim,
320 dim,
321 lengths,
322 lengths,
323 lengths,
324 strides,
325 strides,
326 strides,
327 params,
328 params,
329 params,
330 params,
331 batch);
332 }
333
// 2D helper (enabled when NDim == 2): same all-ones call to
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2> as above, but keeps only element [I2]
// of the result (the code below stores the same index into ce_grid_desc_m_n_, i.e. the
// C/E M-N descriptor).
// NOTE(review): the function's signature line (original line 335) and the line naming the
// object the factory is called on (original line 342) are absent from this listing, so the
// function name and the transformer used cannot be confirmed from here.
334 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
336 {
337 const ck::index_t dim = 1;
338 const ck::index_t batch = 1;
339 const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
340 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
341 const std::array<ck::index_t, NDimSpatial> params{1, 1};
343 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim,
344 dim,
345 dim,
346 lengths,
347 lengths,
348 lengths,
349 strides,
350 strides,
351 strides,
352 params,
353 params,
354 params,
355 params,
356 batch)[I2];
357 }
358
// 3D helper (enabled when NDim == 3): same all-ones call to
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3> as above, but keeps only element [I2]
// of the result (the code below stores the same index into ce_grid_desc_m_n_, i.e. the
// C/E M-N descriptor).
// NOTE(review): the function's signature line (original line 360) and the line naming the
// object the factory is called on (original line 367) are absent from this listing, so the
// function name and the transformer used cannot be confirmed from here.
359 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
361 {
362 const ck::index_t dim = 1;
363 const ck::index_t batch = 1;
364 const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
365 const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
366 const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
368 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
369 dim,
370 dim,
371 lengths,
372 lengths,
373 lengths,
374 strides,
375 strides,
376 strides,
377 params,
378 params,
379 params,
380 params,
381 batch)[I2];
382 }
383
386 .template MakeNGCHWTransposeDesc<NDimSpatial>({}, {}))>;
389 .template MakeNHWGCTransposeDesc<NDimSpatial>({}, {}))>;
392 .template MakeGKCYXTransposeDesc<NDimSpatial>({}, {}))>;
395 .template MakeGKYXCTransposeDesc<NDimSpatial>({}, {}))>;
396
398
404
405 template <index_t NXdlPerWave_>
410 ADataType,
411 BDataType,
412 AccDataType,
413 AccDataType,
414 AccDataType,
418 GemmSpec,
419 BlockSize,
420 MPerBlock,
421 NPerBlock,
422 KPerBlock,
423 K1,
424 K1,
425 MPerXDL,
426 NPerXDL,
427 MXdlPerWave,
428 NXdlPerWave_,
429 ABlockTransferThreadClusterLengths_K0_M_K1,
430 ABlockTransferThreadClusterArrangeOrder,
431 ABlockTransferSrcAccessOrder,
432 ABlockTransferSrcVectorDim,
433 ABlockTransferSrcScalarPerVector,
434 ABlockTransferDstScalarPerVector_K1,
435 false,
436 ABlockLdsAddExtraM,
437 BBlockTransferThreadClusterLengths_K0_N_K1,
438 BBlockTransferThreadClusterArrangeOrder,
439 BBlockTransferSrcAccessOrder,
440 BBlockTransferSrcVectorDim,
441 BBlockTransferSrcScalarPerVector,
442 BBlockTransferDstScalarPerVector_K1,
443 false,
444 BBlockLdsAddExtraN,
445 CShuffleMXdlPerWavePerShuffle,
446 CShuffleNXdlPerWavePerShuffle,
447 CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
448 CBlockTransferScalarPerVector_NWaveNPerXdl,
449 BlkGemmPipeSched,
450 BlkGemmPipelineVer,
451 ComputeTypeA,
452 ComputeTypeB>;
455
457
465 BlockSize,
466 MPerBlock,
467 NPerBlock,
468 MPerBlock / ClusterLengthMPerBlock,
469 NPerBlock / ClusterLengthNPerBlock,
473 I1,
474 I1>;
475 // NPerBlock is used for the first dim which is store dimension
476 // (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector).
477 // CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so
478 // it is more flexible to use this dim for store dimension with such scalar
479 // per vector.
487 BlockSize,
488 MPerBlock,
489 NPerBlock,
490 MPerBlock / ClusterLengthMPerBlock,
491 NPerBlock / ClusterLengthNPerBlock,
495 I1,
496 I0>;
497
505 BlockSize,
506 MPerBlock,
507 NPerBlock,
508 MPerBlock / ClusterLengthMPerBlock,
509 NPerBlock / ClusterLengthNPerBlock,
513 I1,
514 I0>;
515
516 // Argument
519 CGridDesc_M_N{}, 1, 1));
520
// Body of the occupancy helper (an instance named `active_workgroups_per_cu` of type
// ActiveWorkgroupsPerCU is created as a function-local static in Argument's constructor).
// NOTE(review): the struct header line, the kernel symbols passed to the HIP occupancy
// query, the assignments inside the constructor and the member declaration are all absent
// from this listing (original lines 521, 524, 535, 552, 568, 575, 582, 586).
522 {
// Queries HIP for the number of blocks of the GEMM kernel that can be resident per
// multiprocessor with BlockSize threads and no dynamic shared memory, and returns that
// number clamped to at least 1.
523 template <typename GridwiseGemm>
525 {
526 constexpr int dynamic_smem_size = 0;
// Same rule as in Invoker::RunGemmV3: 1 for the Intrawave scheduler, 2 otherwise.
527 constexpr index_t minimum_occupancy =
528 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
529 int max_occupancy = 0;
530
// Pipeline v4 is queried with a different kernel instantiation than the other versions
// (presumably the two-LDS kernel - the symbol line is missing here, confirm).
531 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
532 {
533 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
534 &max_occupancy,
536 GridwiseGemm,
540 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
541 NumGroupsToMerge,
542 true,
544 minimum_occupancy>,
545 BlockSize,
546 dynamic_smem_size));
547 }
548 else
549 {
550 hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
551 &max_occupancy,
553 GridwiseGemm,
557 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
558 NumGroupsToMerge,
559 true,
561 minimum_occupancy>,
562 BlockSize,
563 dynamic_smem_size));
564 }
// Never report an occupancy below 1.
565 return std::max(1, max_occupancy);
566 }
567
// Constructor body: starts from max_occupancy_ = 1, then picks the wave64 or wave32
// configuration based on get_warp_size(); the guarded statements (presumably the
// occupancy query for the matching GridwiseGemm) are missing from the listing.
569 {
570 max_occupancy_ = 1;
571 if(get_warp_size() == 64)
572 {
573 if constexpr(NXdlPerWave64 > 0)
574 {
576 }
577 }
578 else
579 {
580 if constexpr(NXdlPerWave32 > 0)
581 {
583 }
584 }
585 }
587 };
588
// Host-side argument bundle for the backward-weight convolution.
//
// The constructor maps convolution tensors onto GEMM operands (A = output gradient,
// B = input, E = weight), chooses the split-K factor, builds the A/B/C grid descriptors
// from the (transposed) strides and fills the per-group batch strides.
//
// NOTE(review): conv_filter_strides_, input_left_pads_ and input_right_pads_ are
// reference members bound directly to the constructor's reference parameters; they dangle
// if the caller's arrays do not outlive this object - confirm callers keep them alive.
// NOTE(review): many member declarations, initializers and statement heads are absent
// from this listing (gaps in the original line numbers), so this block is not complete.
589 struct Argument : public BaseArgument, public ArgumentSplitK
590 {
// Parameters: raw grid pointers, lengths/strides of input (b_*), weight (e_*) and
// output (a_*) in G-N-C/K-spatial order, convolution strides/dilations/pads, M01/N01,
// the three element-wise operators, and split_k (negative means "deduce automatically").
591 Argument(const InDataType* p_in_grid,
592 WeiDataType* p_wei_grid,
593 const OutDataType* p_out_grid,
594 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
595 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
596 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
597 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
598 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
599 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
600 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
601 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
602 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
603 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
604 const ck::index_t M01,
605 const ck::index_t N01,
606 InElementwiseOperation in_element_op,
607 WeiElementwiseOperation wei_element_op,
608 OutElementwiseOperation out_element_op,
609 ck::index_t split_k)
// GEMM A is the output gradient, GEMM B is the input, GEMM E is the weight.
610 : p_a_grid_{p_out_grid},
611 p_b_grid_{p_in_grid},
612 p_e_grid_{p_wei_grid},
618 M01_{M01},
619 N01_{N01},
620 a_element_op_{out_element_op},
621 b_element_op_{in_element_op},
622 cde_element_op_{wei_element_op},
// G, N and C come from the input lengths; K comes from the weight lengths.
623 Conv_G_{b_g_n_c_wis_lengths[0]},
624 Conv_N_{b_g_n_c_wis_lengths[1]},
625 Conv_K_{e_g_k_c_xs_lengths[1]},
626 Conv_C_{b_g_n_c_wis_lengths[2]},
630 conv_filter_strides_{conv_filter_strides},
631 input_left_pads_{input_left_pads},
632 input_right_pads_{input_right_pads}
633 {
// Function-local static: constructed once and shared by all Argument instances.
634 static ActiveWorkgroupsPerCU active_workgroups_per_cu;
635
// Product of the weight lengths times sizeof(AccDataType); the left-hand side of this
// statement is missing from the listing (presumably c_space_size_bytes, which
// Invoker::RunGemmV3 uses as the memset size - confirm).
638 e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
639 sizeof(AccDataType);
640
// Copy only the spatial extents (everything after the first three entries); the
// destination iterators are on lines missing from the listing.
641 constexpr index_t spatial_offset = 3;
642 std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
643 end(b_g_n_c_wis_lengths),
645 std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
646 end(e_g_k_c_xs_lengths),
648 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset,
649 end(a_g_n_k_wos_lengths),
651
// Strides as produced by the NGCHW->NHWGC transformer for output, input and weight.
652 std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
653 conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
654 a_g_n_k_wos_strides);
655 std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
656 conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths,
657 b_g_n_c_wis_strides);
658 std::array<index_t, NDimSpatial + 3> e_g_k_c_xs_strides_transposed =
659 conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths,
660 e_g_k_c_xs_strides);
661
// A negative split_k requests automatic deduction of k_batch_.
662 if(split_k < 0)
663 {
664 ck::index_t gemmM, gemmN, gemmK;
665 std::tie(gemmM, gemmN, gemmK) =
666 get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
667
// Number of work-groups without split-K: M/N tiles times the merged group count.
668 const auto grid_size = calculate_mn_grid_size<MPerBlock, NPerBlock>(gemmM, gemmN) *
669 Conv_G_ / NumGroupsToMerge;
671 grid_size);
672
673 // Ensure that k_batch_ does not exceed the maximum value
674 // for the GEMM pipeline.
675 const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
// Clamp into [1, k_batch_max]; the outer max keeps the result at 1 when k_batch_max is 0.
676 k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
677
678 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
679 {
680 std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max
681 << std::endl;
682 std::cout << "[SPLIT-K AUTODEDUCE] Final k_batch value: " << k_batch_
683 << std::endl;
684 }
685 }
686 else
687 {
// Caller supplied an explicit split-K factor; use it unchanged.
688 k_batch_ = split_k;
689 }
690
// GEMM descriptors built from the transposed strides.
691 const auto descs =
693 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
694 Conv_N_,
695 Conv_K_,
696 Conv_C_,
700 b_g_n_c_wis_strides_transposed,
701 e_g_k_c_xs_strides_transposed,
702 a_g_n_k_wos_strides_transposed,
703 conv_filter_strides,
704 conv_filter_dilations,
705 input_left_pads,
706 input_right_pads,
707 k_batch_);
708
709 a_grid_desc_k0_m_k1_ = descs[I0];
710 b_grid_desc_k0_n_k1_ = descs[I1];
711 ce_grid_desc_m_n_ = descs[I2];
712
// Second descriptor built from the original (non-transposed) strides; only the C
// element [I2] is kept. The assignment target is on a line missing from the listing.
715 .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
716 Conv_N_,
717 Conv_K_,
718 Conv_C_,
722 b_g_n_c_wis_strides,
723 e_g_k_c_xs_strides,
724 a_g_n_k_wos_strides,
725 conv_filter_strides,
726 conv_filter_dilations,
727 input_left_pads,
728 input_right_pads,
729 k_batch_)[I2];
730
731 const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
732 const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
733
734 // A/B/C Batch Stride
735 compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
736 compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
737 compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0];
743
// Transpose descriptors for A, B and E; the guarding condition and the assignment
// targets are on lines missing from the listing.
746 {
748 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
749 a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
751 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
752 a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
753
755 conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc<NDimSpatial>(
756 b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
758 conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc<NDimSpatial>(
759 b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
760
762 conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc<NDimSpatial>(
763 e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
765 conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc<NDimSpatial>(
766 e_g_k_c_xs_lengths, e_g_k_c_xs_strides);
767
769 a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
770
772 b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
773 }
774
779 e_in_transpose_desc_.GetLength(I1)}
781 ce_grid_desc_m_n_.GetLength(I1)};
782 }
783
// Size getter for the A workspace (signature line missing): bytes of the A transpose
// descriptor, rounded up to a multiple of 128.
785 {
786 // Align to 128B
788 sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) *
789 128;
790 }
791
// Size getter for the B workspace (signature line missing): bytes of the B transpose
// descriptor, with no rounding applied.
793 {
794 return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
795 }
796
// Size getter for the E workspace (signature line missing): AccDataType bytes of the
// C/E descriptor times Conv_G_, rounded up to a multiple of 128.
798 {
799 // Align to 128B
800 return math::integer_divide_ceil(sizeof(AccDataType) *
801 ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_,
802 128) *
803 128;
804 }
805
806 std::size_t GetWorkspaceSizeBytes() const
807 {
808 // 1. We need to transpose A and B for NGCHW and NGKHW layouts
809 // 2. If C format is GKCYX then transpose during second stage.
810 // If C format is GKYXC then just perform second stage.
811 // Due to the fact that E workspace is always needed, we
812 // allocate them as the first part of the workspace.
813 // [EWorkspace, AWorkspace, BWorkspace]
816 {
819 }
820 else
821 {
823 }
824 }
825
829
835
839
844
845 // for computing batch offset
846 ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
847
850
// Element-wise operators stored under their GEMM roles (A = out, B = in, CDE = wei).
851 OutElementwiseOperation a_element_op_;
852 InElementwiseOperation b_element_op_;
853 WeiElementwiseOperation cde_element_op_;
854
855 // for checking IsSupportedArgument()
860 std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
861 std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
862 std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
// NOTE(review): reference members - they alias the caller's arrays rather than copy them.
863 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
864 const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
865 const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
867 };
868
869 // Invoker
870 struct Invoker : public BaseInvoker
871 {
873
// Debug helper: prints the lengths of the A (K0, M, K1), B (K0, N, K1) and C/E (M, N)
// grid descriptors held by `arg` to std::cout, one descriptor per line.
874 void ShowInfo(const Argument& arg)
875 {
876 std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
877 << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
878 << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
879
880 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
881 << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
882 << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
883
884 std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", "
885 << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
886 }
887
// Launches the first-stage GEMM into the workspace buffer and returns `ave_time`.
//
// Steps visible in this listing:
//  1. Derive GemmM/GemmN/GemmK from the A and B grid descriptors.
//  2. Build GridwiseGemm::Argument with the workspace as the C pointer and arg.k_batch_
//     as the split-K factor, then compute the launch grid.
//  3. Define `clear_workspace` (zeroes the workspace when k_batch_ > 1) and `Run`
//     (launches a kernel, optionally with cache flushing / rotating buffers).
//  4. Select a kernel instantiation from has_main_k_block_loop, BlkGemmPipelineVer,
//     KBatch > 1 and the K-loop tail number, and pass it to `Run`.
//
// NOTE(review): every `const auto kernel = ...` head, several template arguments, the
// launch calls inside `Run` and the statements that update `ave_time` are on lines absent
// from this listing, so the exact kernels and timing accumulation cannot be confirmed here.
888 template <typename GridwiseGemm>
889 float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
890 {
891 const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
892 const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
// K is the product of the K0 and K1 extents of the A descriptor.
893 const index_t GemmK =
894 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
895
// The GEMM writes AccDataType results into the workspace, not into the weight tensor.
896 AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
897
898 const ADataType* p_a_grid = arg.p_a_grid_;
899 const BDataType* p_b_grid = arg.p_b_grid_;
900
// Conditional redirection of the A/B pointers; the condition and most of the pointer
// arithmetic are on lines missing from the listing.
903 {
906 p_b_grid =
909 sizeof(BDataType);
910 }
911
912 // nullptr for output, will be set after workspace set
// NOTE(review): p_c_grid (the workspace pointer) is passed below, not nullptr, so the
// comment above looks stale - confirm.
913 typename GridwiseGemm::Argument gemm_arg{
914 p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
915
916 index_t gdx, gdy, gdz;
917 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
918 gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);
919
920 float ave_time = 0;
921
// K handled by one split: ceil(K / (KBatch * KPerBlock)) * KPerBlock.
922 index_t k_grain = gemm_arg.KBatch * KPerBlock;
923 index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock);
924
925 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
926
// K0 extent assigned to each split; the kernels multiply it by blockIdx.y.
927 const auto num_k_per_block =
928 arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
929
// Zero the workspace on the stream before a split-K run (k_batch_ > 1); nothing is
// cleared when k_batch_ is 1.
930 const auto clear_workspace = [&]() {
931 if(arg.k_batch_ > 1)
932 {
933 hip_check_error(hipMemsetAsync(
934 gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
935 }
936 };
937
// Launch helper shared by all dispatch branches below.
938 const auto Run = [&](const auto& kernel) {
939 if(stream_config.flush_cache)
940 {
// Work on a copy of the argument so the rotating-memory wrapper can operate on it.
941 typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
942 ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
943 gemm_arg_,
944 stream_config.rotating_count,
945 gemm_arg_.M * gemm_arg_.K * sizeof(ADataType),
946 gemm_arg_.K * gemm_arg_.N * sizeof(BDataType));
947 rotating_mem.Print();
948
// Preprocess step handed to the launch call: advance the rotating buffers and
// clear the workspace.
949 auto run_flush_cache = [&]() {
950 // flush icache
952 // rotating mem
953 rotating_mem.Next();
954 clear_workspace();
955 };
956
958 stream_config,
959 run_flush_cache,
960 kernel,
961 dim3(gdx, gdy, gdz),
962 dim3(BlockSize),
963 0,
964 gemm_arg_,
969 num_k_per_block);
970 }
971 else
972 {
// No cache flushing: only the workspace clear is passed as the preprocess step.
974 stream_config,
975 clear_workspace,
976 kernel,
977 dim3(gdx, gdy, gdz),
978 dim3(BlockSize),
979 0,
980 gemm_arg,
985 num_k_per_block);
986 }
987 };
988
989 constexpr index_t minimum_occupancy =
990 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
991
// Kernel dispatch. Each branch instantiates a kernel (head line missing) and runs it.
992 if(has_main_k_block_loop)
993 {
994 // Tail number always full
995 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
996 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
997 {
998 if(gemm_arg.KBatch > 1)
999 {
1001 GridwiseGemm,
1006 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1007 NumGroupsToMerge,
1008 true,
1010 minimum_occupancy>;
1011 Run(kernel);
1012 }
1013 else
1014 {
1016 GridwiseGemm,
1021 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1022 NumGroupsToMerge,
1023 true,
1025 minimum_occupancy>;
1026 Run(kernel);
1027 }
1028 }
1029 // Tail number could be One to Seven
1030 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
1031 {
1032 if(gemm_arg.KBatch > 1)
1033 {
1034 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
1035 {
1037 GridwiseGemm,
1042 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1043 NumGroupsToMerge,
1044 true,
1046 minimum_occupancy,
1048 Run(kernel);
1049 }
1050 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1052 {
1054 GridwiseGemm,
1059 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1060 NumGroupsToMerge,
1061 true,
1063 minimum_occupancy,
1065 Run(kernel);
1066 }
1067
// The remaining tail numbers are only compiled in when the pipeline has enough
// prefetch stages for them to occur.
1068 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
1069 {
1070 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
1071 {
1073 GridwiseGemm,
1078 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1079 NumGroupsToMerge,
1080 true,
1082 minimum_occupancy,
1084 Run(kernel);
1085 }
1086 }
1087
1088 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
1089 {
1090 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1092 {
1094 GridwiseGemm,
1099 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1100 NumGroupsToMerge,
1101 true,
1103 minimum_occupancy,
1105 Run(kernel);
1106 }
1107 }
1108
1109 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
1110 {
1111 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1113 {
1115 GridwiseGemm,
1120 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1121 NumGroupsToMerge,
1122 true,
1124 minimum_occupancy,
1126 Run(kernel);
1127 }
1128 }
1129
1130 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
1131 {
1132 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1134 {
1136 GridwiseGemm,
1141 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1142 NumGroupsToMerge,
1143 true,
1145 minimum_occupancy,
1147 Run(kernel);
1148 }
1149 }
1150
1151 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
1152 {
1153 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
1154 {
1156 GridwiseGemm,
1161 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1162 NumGroupsToMerge,
1163 true,
1165 minimum_occupancy,
1167 Run(kernel);
1168 }
1169 }
1170
1171 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
1172 {
1173 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1175 {
1177 GridwiseGemm,
1182 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1183 NumGroupsToMerge,
1184 true,
1186 minimum_occupancy,
1188 Run(kernel);
1189 }
1190 }
1191 }
1192 else
1193 {
// KBatch == 1: same tail-number ladder as above with a different kernel
// instantiation (the differing template argument is on a missing line).
1194 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
1195 {
1197 GridwiseGemm,
1202 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1203 NumGroupsToMerge,
1204 true,
1206 minimum_occupancy,
1208 Run(kernel);
1209 }
1210 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1212 {
1214 GridwiseGemm,
1219 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1220 NumGroupsToMerge,
1221 true,
1223 minimum_occupancy,
1225 Run(kernel);
1226 }
1227
1228 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
1229 {
1230 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
1231 {
1233 GridwiseGemm,
1238 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1239 NumGroupsToMerge,
1240 true,
1242 minimum_occupancy,
1244 Run(kernel);
1245 }
1246 }
1247
1248 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
1249 {
1250 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1252 {
1254 GridwiseGemm,
1259 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1260 NumGroupsToMerge,
1261 true,
1263 minimum_occupancy,
1265 Run(kernel);
1266 }
1267 }
1268
1269 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
1270 {
1271 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1273 {
1275 GridwiseGemm,
1280 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1281 NumGroupsToMerge,
1282 true,
1284 minimum_occupancy,
1286 Run(kernel);
1287 }
1288 }
1289
1290 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
1291 {
1292 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1294 {
1296 GridwiseGemm,
1301 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1302 NumGroupsToMerge,
1303 true,
1305 minimum_occupancy,
1307 Run(kernel);
1308 }
1309 }
1310
1311 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
1312 {
1313 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
1314 {
1316 GridwiseGemm,
1321 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1322 NumGroupsToMerge,
1323 true,
1325 minimum_occupancy,
1327 Run(kernel);
1328 }
1329 }
1330
1331 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
1332 {
1333 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
1335 {
1337 GridwiseGemm,
1342 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1343 NumGroupsToMerge,
1344 true,
1346 minimum_occupancy,
1348 Run(kernel);
1349 }
1350 }
1351 }
1352 }
1353 // Tail number could be Odd or Even
1354 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
1355 {
1356 if(gemm_arg.KBatch > 1)
1357 {
1358 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1359 {
1361 GridwiseGemm,
1366 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1367 NumGroupsToMerge,
1368 true,
1370 minimum_occupancy,
1372 Run(kernel);
1373 }
1374 else
1375 {
1377 GridwiseGemm,
1382 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1383 NumGroupsToMerge,
1384 true,
1386 minimum_occupancy,
1388 Run(kernel);
1389 }
1390 }
1391 else
1392 {
1393 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1394 {
1396 GridwiseGemm,
1401 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1402 NumGroupsToMerge,
1403 true,
1405 minimum_occupancy,
1407 Run(kernel);
1408 }
1409 else
1410 {
1412 GridwiseGemm,
1417 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1418 NumGroupsToMerge,
1419 true,
1421 minimum_occupancy,
1423 Run(kernel);
1424 }
1425 }
1426 }
1427 else
1428 {
// Any other pipeline version: Odd/Even tail selection, as in the v4 branch.
1429 if(gemm_arg.KBatch > 1)
1430 {
1431 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1432 {
1434 GridwiseGemm,
1439 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1440 NumGroupsToMerge,
1441 true,
1443 minimum_occupancy,
1445 Run(kernel);
1446 }
1447 else
1448 {
1450 GridwiseGemm,
1455 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1456 NumGroupsToMerge,
1457 true,
1459 minimum_occupancy,
1461 Run(kernel);
1462 }
1463 }
1464 else
1465 {
1466 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
1467 {
1469 GridwiseGemm,
1474 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1475 NumGroupsToMerge,
1476 true,
1478 minimum_occupancy,
1480 Run(kernel);
1481 }
1482 else
1483 {
1485 GridwiseGemm,
1490 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1491 NumGroupsToMerge,
1492 true,
1494 minimum_occupancy,
1496 Run(kernel);
1497 }
1498 }
1499 }
1500 }
1501 else
1502 {
// No main K loop: only pipeline v1 has a branch here; for other versions nothing is
// launched on this path in the code visible in this listing.
1503 // Tail number always 1
1504 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
1505 {
1506 if(gemm_arg.KBatch > 1)
1507 {
1509 GridwiseGemm,
1514 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1515 NumGroupsToMerge,
1516 false,
1518 minimum_occupancy>;
1519 Run(kernel);
1520 }
1521 else
1522 {
1524 GridwiseGemm,
1529 ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
1530 NumGroupsToMerge,
1531 false,
1533 minimum_occupancy>;
1534 Run(kernel);
1535 }
1536 }
1537 }
1538
1539 return ave_time;
1540 }
1541
// Runs the whole operation for one GridwiseGemm variant and returns the sum of the times
// reported by each launch: an optional transpose launch, the GEMM stage (RunGemmV3) and a
// final elementwise launch that reads the workspace and writes arg.p_e_grid_.
// NOTE(review): several lines of this function (branch conditions, kernel template
// arguments, descriptor arguments) are not visible in this listing; the comments below only
// describe what the visible lines show.
1542 template <typename GridwiseGemm>
1543 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1544 {
1545 float avg_time = 0.f;
// Final stage, defined here but invoked last: the workspace is viewed as AccDataType and
// passed as the kernel input, arg.p_e_grid_ is passed as the kernel output.
1546 auto launch_elementwise_kernel = [&]() {
1547 const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
1548
// Single-entry stride array built from BatchStrideC_; used for both input and output below.
1549 std::array<index_t, I1> in_out_batch_strides = {
1550 static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
1551
// NOTE(review): the condition that selects between the two launch variants is not visible.
1554 {
1555 const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
1557
1565
1566 return launch_and_time_kernel(stream_config,
1567 kernel,
1568 dim3(grid_size),
1569 dim3(BlockSize),
1570 0,
1573 make_tuple(p_c_grid),
1574 make_tuple(arg.p_e_grid_),
1576 arg.cde_element_op_);
1577 }
1578 else
1579 {
// This variant scales the grid size by arg.Conv_G_ and additionally hands Conv_G_ and the
// batch strides to the kernel.
1580 const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
1582 arg.Conv_G_;
1583
1584 const auto kernel =
1592 I1,
1593 I1>;
1594
1595 return launch_and_time_kernel(stream_config,
1596 kernel,
1597 dim3(grid_size),
1598 dim3(BlockSize),
1599 0,
1602 make_tuple(p_c_grid),
1603 make_tuple(arg.p_e_grid_),
1605 arg.cde_element_op_,
1606 arg.Conv_G_,
1607 in_out_batch_strides,
1608 in_out_batch_strides);
1609 }
1610 };
1611
// Transpose launch: a single launch of grid_size_a + grid_size_b blocks takes arg.p_a_grid_
// and arg.p_b_grid_ as inputs and writes to two pointers derived from arg.p_workspace_.
// NOTE(review): the condition guarding this block is not visible in this listing.
1614 {
1615 const index_t grid_size_a =
1616 arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
1618 const index_t grid_size_b =
1619 arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
1621
1622 ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
1624 BDataType* p_b_out_grid =
1627 sizeof(BDataType);
1628
1629 // Different data type for A and B is not supported
1642 element_wise::PassThrough>;
1643
1644 avg_time += launch_and_time_kernel(stream_config,
1645 kernel_transpose,
1646 dim3(grid_size_a + grid_size_b),
1647 dim3(BlockSize),
1648 0,
1653 make_tuple(arg.p_a_grid_),
1654 make_tuple(arg.p_b_grid_),
1655 make_tuple(p_a_out_grid),
1656 make_tuple(p_b_out_grid),
1659 element_wise::PassThrough{},
1660 grid_size_a);
1661 }
1662
// Add the GEMM stage first, then the elementwise stage, and report the total.
1663 avg_time += RunGemmV3<GridwiseGemm>(arg, stream_config);
1664 avg_time += launch_elementwise_kernel();
1665 return avg_time;
1666 }
1667
1669
1670 float Run(const BaseArgument* p_arg,
1671 const StreamConfig& stream_config = StreamConfig{}) override
1672 {
1673 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1674 }
1675 };
1676
// Compile-time validity hook for this device operation.
// Placeholder: every configuration is currently reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
1682
// Host-side check of whether this instance can run the problem described by `arg`.
// Returns false as soon as one requirement fails and true only when every check passes.
// NOTE(review): several condition lines of this function are not visible in this listing;
// comments next to those spots describe only what the visible lines show.
1683 static bool IsSupportedArgument(const Argument& arg)
1684 {
// GEMM extents are read from the A/B grid descriptors: M and N from dimension I1,
// K as the product of dimensions I0 and I2 of the A descriptor.
1685 const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
1686 const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
1687 const index_t GemmK =
1688 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
1689
// NOTE(review): the condition guarding this block is not visible. Inside it,
// is_tf32_supported() must hold, and the logged message points to a TF32 check that the
// A and B compute types match.
1691 {
1692 if(!is_tf32_supported())
1693 {
1694 return false;
1695 }
1697 {
1698 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1699 {
1700 std::cout << "ComputeDataType for A and B should be same while using TF32"
1701 << std::endl;
1702 }
1703 return false;
1704 }
1705 }
1706
// Pick the GridwiseGemm64 or GridwiseGemm32 variant by warp size. A variant whose
// NXdlPerWave value is not positive is rejected. For pipeline versions other than v1 the
// K-loop count must be strictly greater than the pipeline's PrefetchStages.
1707 if(get_warp_size() == 64)
1708 {
1709 if constexpr(NXdlPerWave64 > 0)
1710 {
// Pointer-free argument built only to read AK0 for the K-loop count.
1711 typename GridwiseGemm64::Argument gemm_arg{
1712 nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
1713
1714 const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1);
1715 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1716 {
1717 if(num_k_loop <= GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages)
1718 {
1719 return false;
1720 }
1721 }
1722 }
1723 else
1724 {
1725 return false;
1726 }
1727 }
1728 else
1729 {
1730 if constexpr(NXdlPerWave32 > 0)
1731 {
// Same probe as above, using the 32-wide variant.
1732 typename GridwiseGemm32::Argument gemm_arg{
1733 nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
1734
1735 const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1);
1736 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1737 {
1738 if(num_k_loop <= GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages)
1739 {
1740 return false;
1741 }
1742 }
1743 }
1744 else
1745 {
1746 return false;
1747 }
1748 }
1749
1750 // Check this here, it allows to use other instances from factory even
1751 // if workspace is not allocated
1752 if(!arg.p_workspace_)
1753 {
1754 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1755 {
1756 std::cout << "Warning: Workspace for "
1757 "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
1758 "allocated, use SetWorkSpacePointer."
1759 << std::endl;
1760 }
1761 return false;
1762 }
// NOTE(review): the condition for this rejection is not visible in this listing.
1764 {
1765 return false;
1766 }
// Only NDimSpatial == 2 and NDimSpatial == 3 are accepted; the per-dimension conditions
// inside each branch are not visible in this listing.
1767 if constexpr(NDimSpatial == 2)
1768 {
1771 {
1772 return false;
1773 }
1774 }
1775 else if constexpr(NDimSpatial == 3)
1776 {
1779 {
1780 return false;
1781 }
1782 }
1783 else
1784 {
1785 return false;
1786 }
1787
1788 if constexpr(ConvBackwardWeightSpecialization ==
1790 {
1791 // check if it's 1x1, stride=1 pad = 0 conv
1792 for(int i = 0; i < NDimSpatial; i++)
1793 {
1794 if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1795 arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1796 {
1797 return false;
1798 }
1799 }
1800 }
1801
// Group merging requires C == 1 and K == 1 per group and a group count divisible by
// NumGroupsToMerge, in addition to the single-block constraint below.
1802 if constexpr(NumGroupsToMerge > 1)
1803 {
1804 // support only if whole M and N can be processed on one block
1805 if(!(GemmM <= MPerBlock && GemmN <= NPerBlock))
1806 {
1807 return false;
1808 }
1809 if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1))
1810 {
1811 return false;
1812 }
1813 if(arg.Conv_G_ % NumGroupsToMerge != 0)
1814 {
1815 return false;
1816 }
1817 }
1818
// XC_access_allowed holds only for a single group (Conv_G_ == 1) with zero left/right
// padding on the last spatial dimension and C * X divisible by the B source vector width.
1819 const bool is_w_pad_zero = arg.input_left_pads_[NDimSpatial - 1] == 0 &&
1820 arg.input_right_pads_[NDimSpatial - 1] == 0;
1821 const auto X = arg.filter_spatial_lengths_[NDimSpatial - 1];
1822 const bool XC_access_allowed = arg.Conv_G_ == 1 &&
1823 (arg.Conv_C_ * X) % BBlockTransferSrcScalarPerVector == 0 &&
1824 is_w_pad_zero;
1825
// When the regular divisibility requirements fail, the problem is still accepted if BOTH
// fallbacks hold: K == 1 with BatchStrideA_ == 1, and C == 1 with BatchStrideB_ == 1,
// each only when NumGroupsToMerge > 1.
1826 if(!((arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 || XC_access_allowed) &&
1827 arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
1828 {
1829 if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1 &&
1830 NumGroupsToMerge > 1))
1831 {
1832 return false;
1833 }
1834 if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 &&
1835 NumGroupsToMerge > 1))
1836 {
1837 return false;
1838 }
1839 }
1840
1841 // vector load A/B matrix from global memory
1842 if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1))
1843 {
1844 return false;
1845 }
1846
1847 // vector store C matrix into global memory
1848 if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
1849 {
1850 return false;
1851 }
1852
// NOTE(review): the condition guarding this block is not visible. The block validates the
// transpose descriptors: G*C and G*K against TransposeTransferDstScalarPerVector, the
// products of the input/output spatial lengths against TransposeTransferSrcScalarPerVector,
// and the transposed A/B element space sizes against a 2 GiB byte limit.
1855 {
1856 if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
1857 {
1858 return false;
1859 }
1860
1861 if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
1862 {
1863 return false;
1864 }
1865
1866 const index_t input_spatial_acum = ck::accumulate_n<index_t>(
1867 arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
1868 const index_t output_spatial_acum = ck::accumulate_n<index_t>(
1869 arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
1870
1871 if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
1872 {
1873 return false;
1874 }
1875
1876 if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
1877 {
1878 return false;
1879 }
1880
1881 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1882 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
1883 arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB))
1884 {
1885 return false;
1886 }
1887 }
1888
// The A, B and E element spaces must each occupy at most 2 GiB (1 << 31 bytes).
1889 constexpr long_index_t TwoGB = (long_index_t{1} << 31);
1890 if(!(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
1891 arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
1892 arg.ce_grid_desc_m_n_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
1893 {
1894 return false;
1895 }
1896
1897 return true;
1898 }
1899
1900 bool IsSupportedArgument(const BaseArgument* p_arg) override
1901 {
1902 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1903 }
1904
1905 static auto
1906 MakeArgument(const InDataType* p_in_grid,
1907 WeiDataType* p_wei_grid,
1908 const OutDataType* p_out_grid,
1909 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
1910 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1911 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
1912 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1913 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
1914 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1915 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1916 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1917 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1918 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1919 InElementwiseOperation in_element_op,
1920 WeiElementwiseOperation wei_element_op,
1921 OutElementwiseOperation out_element_op,
1922 const ck::index_t split_k)
1923 {
1924 return Argument{p_in_grid,
1925 p_wei_grid,
1926 p_out_grid,
1927 b_g_n_c_wis_lengths, // input
1928 b_g_n_c_wis_strides,
1929 e_g_k_c_xs_lengths, // weight
1930 e_g_k_c_xs_strides,
1931 a_g_n_k_wos_lengths, // output
1932 a_g_n_k_wos_strides,
1933 conv_filter_strides,
1934 conv_filter_dilations,
1935 input_left_pads,
1936 input_right_pads,
1937 1,
1938 1,
1939 in_element_op,
1940 wei_element_op,
1941 out_element_op,
1942 split_k};
1943 }
1944
1945 static auto MakeInvoker() { return Invoker{}; }
1946
1947 std::unique_ptr<BaseArgument>
1948 MakeArgumentPointer(const void* p_in_grid,
1949 void* p_wei_grid,
1950 const void* p_out_grid,
1951 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
1952 const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
1953 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
1954 const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
1955 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
1956 const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
1957 const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
1958 const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
1959 const std::array<ck::index_t, NDimSpatial>& input_left_pads,
1960 const std::array<ck::index_t, NDimSpatial>& input_right_pads,
1961 InElementwiseOperation in_element_op,
1962 WeiElementwiseOperation wei_element_op,
1963 OutElementwiseOperation out_element_op,
1964 const ck::index_t split_k) override
1965 {
1966 return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
1967 static_cast<WeiDataType*>(p_wei_grid),
1968 static_cast<const OutDataType*>(p_out_grid),
1969 b_g_n_c_wis_lengths, // input
1970 b_g_n_c_wis_strides,
1971 e_g_k_c_xs_lengths, // weight
1972 e_g_k_c_xs_strides,
1973 a_g_n_k_wos_lengths, // output
1974 a_g_n_k_wos_strides,
1975 conv_filter_strides,
1976 conv_filter_dilations,
1977 input_left_pads,
1978 input_right_pads,
1979 1,
1980 1,
1981 in_element_op,
1982 wei_element_op,
1983 out_element_op,
1984 split_k);
1985 }
1986
1987 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1988 {
1989 return std::make_unique<Invoker>(Invoker{});
1990 }
1991
// Returns a human-readable description of this instance: the operation name followed by a
// comma-separated list of its tuning parameters enclosed in '<' and '>'.
// NOTE(review): the initialisers of the two lookup maps and the condition guarding the
// transpose-parameter suffix are not visible in this listing.
1992 std::string GetTypeString() const override
1993 {
1994 auto str = std::stringstream();
1995
// Enum-to-name tables used to print the pipeline scheduler and pipeline version.
1996 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
1999
2000 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
2006
2007 // clang-format off
2008 str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"
2009 << "<"
2010 << BlockSize << ", "
2011 << MPerBlock << ", "
2012 << NPerBlock << ", "
2013 << KPerBlock << ", "
2014 << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
2015 << K1 << ", "
2016 << MXdlPerWave << ", "
2017 << NXdlPerWave << ", "
2018 << ABlockTransferSrcScalarPerVector << ", "
2019 << ABlockTransferDstScalarPerVector_K1 << ", "
2020 << BBlockTransferSrcScalarPerVector << ", "
2021 << BBlockTransferDstScalarPerVector_K1 << ", "
2022 << CShuffleMXdlPerWavePerShuffle << ", "
2023 << CShuffleNXdlPerWavePerShuffle << ", "
2024 << CBlockTransferScalarPerVector_NWaveNPerXdl << ", "
2025 << "BlkGemmPipelineScheduler: "
2026 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
2027 << "BlkGemmPipelineVersion: "
2028 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
2029 << NumGroupsToMerge;
2030
// Optional suffix with the transpose vector widths; its guard is not visible here.
2033 str << ", TransposeTransferSrcScalarPerVector: "
2034 << TransposeTransferSrcScalarPerVector <<", "
2035 << "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
2036 }
2037
2038
2039 str << ">";
2040 // clang-format on
2041
2042 return str.str();
2043 }
2044
2045 size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
2046 {
2047 auto arg = dynamic_cast<const Argument*>(p_arg);
2048 if(arg)
2049 {
2050 return arg->GetWorkspaceSizeBytes();
2051 }
2052 else
2053 throw std::runtime_error(
2054 "The argument pointer is not an object of "
2055 "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
2056 }
2057
2059 void* p_workspace,
2060 const StreamConfig& = StreamConfig{}) const override
2061 {
// Stores `p_workspace` in the argument's p_workspace_ member. The StreamConfig parameter
// is unnamed and unused. Throws std::runtime_error when the dynamic_cast to Argument fails.
2062 auto p_arg_ = dynamic_cast<Argument*>(p_arg);
2063 if(p_arg_)
2064 {
2065 p_arg_->p_workspace_ = p_workspace;
2066 }
2067 else
2068 throw std::runtime_error(
2069 "The argument pointer is not an object of "
2070 "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
2071 }
2072};
2073
2074} // namespace device
2075} // namespace tensor_operation
2076} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
__global__ void kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:109
__global__ void kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const index_t num_k_per_block)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:51
auto get_bwd_weight_gemm_sizes(const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths)
Definition split_k_utils.hpp:55
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
constexpr bool is_NHWGC_GKYXC_NHWGK()
Definition device_grouped_conv_utils.hpp:40
ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
Definition split_k_utils.hpp:30
GemmSpecialization
Definition gemm_specialization.hpp:11
@ Default
Definition gemm_specialization.hpp:13
constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
Definition device_grouped_conv_utils.hpp:80
constexpr bool is_NGCDHW_NGKDHW()
Definition device_grouped_conv_utils.hpp:112
constexpr bool is_NGCHW_GKCYX_NGKHW()
Definition device_grouped_conv_utils.hpp:64
std::string getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization &s)
Definition convolution_backward_weight_specialization.hpp:21
ck::index_t calculate_mn_grid_size(ck::index_t gemmM, ck::index_t gemmN)
Definition split_k_utils.hpp:84
constexpr bool is_NGCDHW_GKCZYX_NGKDHW()
Definition device_grouped_conv_utils.hpp:104
constexpr bool is_NGCHW_NGKHW()
Definition device_grouped_conv_utils.hpp:72
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition gridwise_elementwise_2d.hpp:221
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) -> decltype(std::accumulate(first, std::next(first, count), init, op))
Definition library/utility/numeric.hpp:11
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition gridwise_elementwise_2d.hpp:61
int64_t long_index_t
Definition ck.hpp:300
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition gridwise_elementwise_2d.hpp:29
Definition ck/stream_config.hpp:10
Definition block_to_ctile_map.hpp:261
Definition gridwise_elementwise_2d.hpp:278
Definition gridwise_gemm_xdl_cshuffle_conv_v3.hpp:66
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:24
Transform conv bwd weight to gemm v2.
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:33
Definition transform_conv_ngchw_to_nhwgc.hpp:31
index_t k_batch_
Definition split_k_arg.hpp:12
Definition device_base.hpp:197
void * p_workspace_
Definition device_base.hpp:204
Definition device_grouped_conv_bwd_weight.hpp:29
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:522
int GetMaxOccupancy()
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:524
int max_occupancy_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:586
ActiveWorkgroupsPerCU()
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:568
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:590
long_index_t c_space_size_bytes
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:866
NGCHWTransposeDescType b_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:840
std::array< ck::index_t, NDimSpatial > input_spatial_lengths_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:860
const index_t Conv_C_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:859
InElementwiseOperation b_element_op_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:852
OutElementwiseOperation a_element_op_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:851
WeiElementwiseOperation cde_element_op_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:853
const std::array< ck::index_t, NDimSpatial > & input_left_pads_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:864
Block2TileMapElementwise elementwise_block_2_ctile_map_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:836
ComputePtrOffsetOfStridedBatch< I1, I1, I0 > compute_ptr_offset_of_batch_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:846
NHWGCTransposeDescType b_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:841
std::size_t GetWorkspaceETensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:797
CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:833
const BDataType * p_b_grid_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:827
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:837
const index_t Conv_N_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:857
std::array< ck::index_t, NDimSpatial > filter_spatial_lengths_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:861
Argument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, const ck::index_t M01, const ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:591
std::size_t GetWorkspaceSizeBytes() const
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:806
std::size_t GetWorkspaceATensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:784
const index_t Conv_K_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:858
const ADataType * p_a_grid_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:826
const std::array< ck::index_t, NDimSpatial > & conv_filter_strides_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:863
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:834
GKYXCTransposeDescType e_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:842
std::array< ck::index_t, NDimSpatial > output_spatial_lengths_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:862
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:831
const std::array< ck::index_t, NDimSpatial > & input_right_pads_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:865
NHWGCTransposeDescType a_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:841
EDataType * p_e_grid_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:828
index_t M01_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:848
std::size_t GetWorkspaceBTensorSizeBytes() const
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:792
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_b_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:838
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:830
GKCYXTransposeDescType e_out_transpose_desc_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:843
NGCHWTransposeDescType a_in_transpose_desc_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:840
const index_t Conv_G_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:856
CGridDesc_M_N ce_grid_desc_m_n_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:832
index_t N01_
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:849
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:871
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1670
void ShowInfo(const Argument &arg)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:874
float RunGemmV3(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:889
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1543
DeviceOp::Argument Argument
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:872
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:217
InElementwiseOperation BElementwiseOperation
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:237
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1900
static constexpr auto I4
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:247
void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:2058
InDataType ABDataType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:241
static constexpr auto I5
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:248
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKCYXTransposeDesc< NDimSpatial >({}, {}))> GKCYXTransposeDescType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:390
static constexpr auto I1
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:244
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:224
std::string GetTypeString() const override
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1992
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:454
OutDataType ADataType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:227
static constexpr GemmSpecialization GemmSpec
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:282
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeGKYXCTransposeDesc< NDimSpatial >({}, {}))> GKYXCTransposeDescType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:393
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1683
static auto MakeInvoker()
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1945
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:2045
static constexpr bool IsValidCompilationParameter()
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1677
WeiDataType EDataType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:229
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:397
static constexpr auto I3
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:246
static auto MakeArgument(const InDataType *p_in_grid, WeiDataType *p_wei_grid, const OutDataType *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k)
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1906
GridwiseElementwise< Tuple< CElementwiseGridDesc_M_N >, Tuple< CElementwiseGridDesc_M_N >, Tuple< const AccDataType * >, Tuple< EDataType * >, Block2TileMapElementwise, CDEElementwiseOperation, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 0, 1 >, Sequence< CBlockTransferScalarPerVector_NWaveNPerXdl >, Sequence< CBlockTransferScalarPerVector_NWaveNPerXdl >, I1, I1 > GridwiseElementwiseCast
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:458
WeiElementwiseOperation CDEElementwiseOperation
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:238
OutElementwiseOperation AElementwiseOperation
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:236
InDataType BDataType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:228
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMapElementwise
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:456
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:401
static constexpr auto I0
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:243
static constexpr auto K1Number
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:250
GridwiseElementwise< Tuple< GKYXCTransposeDescType >, Tuple< GKCYXTransposeDescType >, Tuple< const AccDataType * >, Tuple< EDataType * >, Block2TileMapElementwise, CDEElementwiseOperation, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 0, 1 >, Sequence< CBlockTransferScalarPerVector_NWaveNPerXdl >, Sequence< 1 >, I1, I0 > GridwiseElementwiseWeightTransposeCast
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:480
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:453
GridwiseGemm_xdl_cshuffle_conv_v3< tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, AccDataType, AccDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, K1, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferScalarPerVector_NWaveNPerXdl, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmBase
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:406
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle DeviceOp
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:222
static constexpr auto I2
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:245
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNHWGCTransposeDesc< NDimSpatial >({}, {}))> NHWGCTransposeDescType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:387
remove_cvref_t< decltype(conv_ngchw_to_nhwgc_transformer .template MakeNGCHWTransposeDesc< NDimSpatial >({}, {}))> NGCHWTransposeDescType
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:384
static constexpr index_t ClusterLengthMPerBlock
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:269
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:225
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:400
static constexpr index_t ClusterLengthNPerBlock
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:271
static auto GetElementwiseCGridDesc()
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:335
static constexpr auto conv_to_gemm_transformer_v1
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:261
decltype(GridwiseGemm64::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)) CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:517
static auto GetABCGridDesc()
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:285
static constexpr auto conv_ngchw_to_nhwgc_transformer
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:274
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:399
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_grid, void *p_wei_grid, const void *p_out_grid, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, const ck::index_t split_k) override
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1948
GridwiseElementwise< Tuple< NGCHWTransposeDescType >, Tuple< NHWGCTransposeDescType >, Tuple< const ADataType * >, Tuple< ADataType * >, Block2TileMapElementwise, element_wise::PassThrough, BlockSize, MPerBlock, NPerBlock, MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock, Sequence< 1, 0 >, Sequence< TransposeTransferSrcScalarPerVector >, Sequence< TransposeTransferDstScalarPerVector >, I1, I0 > GridwiseElementwiseTranspose
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:498
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:1987
static constexpr auto conv_to_gemm_transformer_v2
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:252
remove_cvref_t< decltype(GetElementwiseCGridDesc< NDimSpatial >())> CElementwiseGridDesc_M_N
Definition device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp:402
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
#define CK_ENV(name)
Definition utility/env.hpp:129