transform_conv_bwd_weight_to_gemm.hpp Source File

transform_conv_bwd_weight_to_gemm.hpp Source File

Composable Kernel: transform_conv_bwd_weight_to_gemm.hpp Source File
tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp
Go to the documentation of this file.
1
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6
13
14namespace ck {
15namespace tensor_operation {
16
17template <index_t NDimSpatial,
18 index_t MPerBlock,
19 index_t NPerBlock,
20 index_t GemmK1Number,
21 index_t K0PerBlock,
22 device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
24{
25 static constexpr auto I0 = Number<0>{};
26 static constexpr auto I1 = Number<1>{};
27
// Output-tensor descriptor helper, enabled only when NDim == 2.
// NOTE(review): embedded source line 30 (function name and first parameter) is absent from this
// listing; the Doxygen index at the end of the page gives it as
// make_out_grid_desc(N, Ho, Wo, K, output_strides).
// Returns a naive tensor descriptor with lengths (N * Ho * Wo, K) and
// strides (output_strides[4], 1).
28 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
29 constexpr static auto
31 const index_t Ho,
32 const index_t Wo,
33 const index_t K,
34 const std::array<index_t, NDimSpatial + 3>& output_strides)
35 {
// Only element [4] of output_strides is read; the local name suggests it is the Wo stride
// (the stride-array layout is not shown here -- TODO confirm).
36 const index_t WoStride = output_strides[4];
// The K stride is the compile-time constant 1 rather than a value taken from output_strides.
37 const auto KStride = Number<1>{};
38 return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
39 make_tuple(WoStride, KStride));
40 }
41
// Input-tensor descriptor helper, enabled only when NDim == 2.
// NOTE(review): embedded source lines 44 (function name and first parameter) and 55 (right-hand
// side of the `if constexpr` comparison) are absent from this listing; the Doxygen index at the
// end of the page gives the signature as make_in_grid_desc(N, Hi, Wi, C, input_strides).
// Returns either a 2-D descriptor (N * Hi * Wi, C) or a 4-D descriptor (N, Hi, Wi, C),
// selected at compile time by ConvBackwardWeightSpecialization.
42 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
43 constexpr static auto
45 const index_t Hi,
46 const index_t Wi,
47 const index_t C,
48 const std::array<index_t, NDimSpatial + 3>& input_strides)
49 {
// Strides are read from fixed positions: [1] -> N, [3] -> Hi, [4] -> Wi, [2] -> C.
50 const index_t NStride = input_strides[1];
51 const index_t HiStride = input_strides[3];
52 const index_t WiStride = input_strides[4];
53 const auto CStride = input_strides[2];
// The value compared against is on the omitted line; presumably Filter1x1Stride1Pad0, the only
// enumerator named in the page index -- TODO confirm against the real header.
54 if constexpr(ConvBackwardWeightSpecialization ==
56 {
// Flattened form: N, Hi and Wi are multiplied into one dimension addressed with WiStride.
57 return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
58 make_tuple(WiStride, CStride));
59 }
60 else
61 {
// General form: every dimension keeps its own length and stride.
62 return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
63 make_tuple(NStride, HiStride, WiStride, CStride));
64 }
65 }
66
// Weight-tensor descriptor helper, enabled only when NDim == 2.
// NOTE(review): embedded source line 69 (function name and first parameter) is absent from this
// listing; the Doxygen index at the end of the page gives it as
// make_wei_grid_desc(K, Y, X, C, weights_strides).
// Returns a naive tensor descriptor with lengths (K, Y * X * C) and
// strides (weights_strides[1], 1).
67 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
68 constexpr static auto
70 const index_t Y,
71 const index_t X,
72 const index_t C,
73 const std::array<index_t, NDimSpatial + 3>& weights_strides)
74 {
// The second dimension uses the compile-time constant stride 1.
75 const auto CStride = Number<1>{};
// Only element [1] of weights_strides is read; it is used as the K stride.
76 const auto KStride = weights_strides[1];
77 return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
78 }
79
// Output-tensor descriptor helper, enabled only when NDim == 3.
// NOTE(review): embedded source line 82 (function name and first parameter) is absent from this
// listing; the Doxygen index at the end of the page gives it as
// make_out_grid_desc(N, Do, Ho, Wo, K, output_strides).
// Returns a naive tensor descriptor with lengths (N * Do * Ho * Wo, K) and
// strides (output_strides[5], 1).
80 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
81 constexpr static auto
83 const index_t Do,
84 const index_t Ho,
85 const index_t Wo,
86 const index_t K,
87 const std::array<index_t, NDimSpatial + 3>& output_strides)
88 {
// Only element [5] of output_strides is read; the local name suggests it is the Wo stride
// (the stride-array layout is not shown here -- TODO confirm).
89 const index_t WoStride = output_strides[5];
// The K stride is the compile-time constant 1 rather than a value taken from output_strides.
90 const auto KStride = Number<1>{};
91 return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
92 make_tuple(WoStride, KStride));
93 }
94
// Input-tensor descriptor helper, enabled only when NDim == 3.
// NOTE(review): embedded source lines 97 (function name and first parameter), 110 (right-hand
// side of the `if constexpr` comparison) and 117 (start of the return statement in the else
// branch) are absent from this listing; the Doxygen index at the end of the page gives the
// signature as make_in_grid_desc(N, Di, Hi, Wi, C, input_strides).
// Produces either a 2-D descriptor (N * Di * Hi * Wi, C) or, in the else branch, lengths
// (N, Di, Hi, Wi, C) with per-dimension strides, selected at compile time by
// ConvBackwardWeightSpecialization.
95 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
96 constexpr static auto
98 const index_t Di,
99 const index_t Hi,
100 const index_t Wi,
101 const index_t C,
102 const std::array<index_t, NDimSpatial + 3>& input_strides)
103 {
// Strides are read from fixed positions: [1] -> N, [3] -> Di, [4] -> Hi, [5] -> Wi, [2] -> C.
104 const index_t NStride = input_strides[1];
105 const index_t DiStride = input_strides[3];
106 const index_t HiStride = input_strides[4];
107 const index_t WiStride = input_strides[5];
108 const auto CStride = input_strides[2];
// The value compared against is on the omitted line; presumably Filter1x1Stride1Pad0, the only
// enumerator named in the page index -- TODO confirm against the real header.
109 if constexpr(ConvBackwardWeightSpecialization ==
111 {
// Flattened form: N, Di, Hi and Wi are multiplied into one dimension addressed with WiStride.
112 return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
113 make_tuple(WiStride, CStride));
114 }
115 else
116 {
// General form: every dimension keeps its own length and stride. The call these two argument
// lines belong to is on the omitted line 117 (presumably make_naive_tensor_descriptor, as in
// the branch above -- TODO confirm).
118 make_tuple(N, Di, Hi, Wi, C),
119 make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
120 }
121 }
122
// Weight-tensor descriptor helper, enabled only when NDim == 3.
// NOTE(review): embedded source line 125 (function name and first parameter) is absent from this
// listing; the Doxygen index at the end of the page gives it as
// make_wei_grid_desc(K, Z, Y, X, C, weights_strides).
// Returns a naive tensor descriptor with lengths (K, Z * Y * X * C) and
// strides (weights_strides[1], 1).
123 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
124 constexpr static auto
126 const index_t Z,
127 const index_t Y,
128 const index_t X,
129 const index_t C,
130 const std::array<index_t, NDimSpatial + 3>& weights_strides)
131 {
// The second dimension uses the compile-time constant stride 1.
132 const auto CStride = Number<1>{};
// Only element [1] of weights_strides is read; it is used as the K stride.
133 const auto KStride = weights_strides[1];
134 return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
135 make_tuple(KStride, CStride));
136 }
137
// 1-D overload (enabled only when NDim == 1).
// NOTE(review): the embedded source line numbers skip in several places, so lines are absent from
// this listing -- including line 139 with the function name (given by the Doxygen index at the
// end of the page as MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N), the right-hand side of the
// `if constexpr` comparison, and the trailing arguments of most transform_tensor_descriptor calls.
// Returns a tuple of three descriptors, in this order:
//   out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  (A, built from the output tensor),
//   in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc   (B, built from the input tensor),
//   wei_gemmm_gemmn_pad_grid_desc                 (C, built from the weight tensor).
// The three stride arrays are accepted but unnamed, i.e. unused by this overload.
138 template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
140 const index_t N,
141 const index_t K,
142 const index_t C,
143 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
144 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
145 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
146 const std::array<index_t, NDimSpatial + 3>& /* input_strides */,
147 const std::array<index_t, NDimSpatial + 3>& /* weights_strides */,
148 const std::array<index_t, NDimSpatial + 3>& /* output_strides */,
149 const std::array<index_t, NDimSpatial>& conv_filter_strides,
150 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
151 const std::array<index_t, NDimSpatial>& input_left_pads,
152 const std::array<index_t, NDimSpatial>& input_right_pads,
153 const index_t batch_k)
154 {
155 using namespace ck;
156
// Unpack the single spatial dimension (index 0 of every per-dimension array).
157 const index_t Wi = input_spatial_lengths[0];
158 const index_t Wo = output_spatial_lengths[0];
159 const index_t X = filter_spatial_lengths[0];
160 const index_t ConvStrideW = conv_filter_strides[0];
161 const index_t ConvDilationW = conv_filter_dilations[0];
162 const index_t InLeftPadW = input_left_pads[0];
163 const index_t InRightPadW = input_right_pads[0];
164
// GEMM problem sizes derived from the convolution sizes.
165 const index_t GemmKTotal = N * Wo;
166 const index_t GemmM = K;
167 const index_t GemmN = C * X;
168
// Amount needed to round GemmM / GemmN up to the next multiple of MPerBlock / NPerBlock
// (0 when already a multiple).
169 const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
170 const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
171
// GemmK0 is ceil(GemmKTotal / (GemmK1Number * K0PerBlock * GemmKBatch)) * K0PerBlock, so
// GemmKPad = GemmKBatch * GemmK0 * GemmK1Number is GemmKTotal rounded up to a size that
// splits evenly into (GemmKBatch, GemmK0, GemmK1Number).
172 const index_t GemmKBatch = batch_k;
173 const index_t GemmK0 =
174 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
175 K0PerBlock;
176 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
177
// The value compared against is on an omitted line; presumably Filter1x1Stride1Pad0, the only
// enumerator named in the page index -- TODO confirm against the real header.
178 if constexpr(ConvBackwardWeightSpecialization ==
180 {
181 // A: output tensor
// (initializer is on an omitted line)
182 const auto out_gemmktotal_gemmm_grid_desc =
184
// Right-pad the K-total dimension from GemmKTotal up to GemmKPad.
185 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
186 out_gemmktotal_gemmm_grid_desc,
187 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
191
// Unmerge padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
192 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
193 out_gemmkpad_gemmm_grid_desc,
194 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
195 make_right_pad_transform(GemmM, PadGemmM)),
198
199 // B: input tensor
// (initializer is on an omitted line)
200 const auto in_gemmktotal_gemmn_grid_desc =
202
203 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
204 in_gemmktotal_gemmn_grid_desc,
205 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
209
210 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
211 in_gemmkpad_gemmn_grid_desc,
212 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
213 make_right_pad_transform(GemmN, PadGemmN)),
216
217 // C: weight tensor
// (initializer is on an omitted line)
218 const auto wei_gemmm_gemmn_grid_desc =
220
221 // Pad GemmM and GemmN
222 const auto wei_gemmm_gemmn_pad_grid_desc =
223 transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
224 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
225 make_right_pad_transform(GemmN, PadGemmN)),
228
229 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
230 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
231 wei_gemmm_gemmn_pad_grid_desc);
232 }
233 else
234 {
// (initializers of the next two descriptors are on omitted lines)
235 const auto out_gemmktotal_gemmm_grid_desc =
237 const auto in_n_wi_c_grid_desc =
239
240 // A: output tensor
241 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
242 out_gemmktotal_gemmm_grid_desc,
243 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
247
248 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
249 out_gemmkpad_gemmm_grid_desc,
250 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
251 make_right_pad_transform(GemmM, PadGemmM)),
254
255 // B: input tensor
// Step 1: apply the left/right input padding along Wi.
256 const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
257 in_n_wi_c_grid_desc,
259 make_pad_transform(Wi, InLeftPadW, InRightPadW),
263
// Step 2: embed the padded W axis as (X, Wo) with coefficients (ConvDilationW, ConvStrideW).
264 const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
265 in_n_wip_c_grid_desc,
268 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
272
// Step 3: reshape to (GemmKTotal, GemmN); the transforms themselves are on omitted lines.
273 const auto in_gemmktotal_gemmn_grid_desc =
274 transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
279
280 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
281 in_gemmktotal_gemmn_grid_desc,
282 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
286
287 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
288 in_gemmkpad_gemmn_grid_desc,
289 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
290 make_right_pad_transform(GemmN, PadGemmN)),
293
294 // C: weight tensor
// (initializer is on an omitted line)
295 const auto wei_gemmm_gemmn_grid_desc =
297
298 // Pad GemmM and GemmN
299 const auto wei_gemmm_gemmn_pad_grid_desc =
300 transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
301 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
302 make_right_pad_transform(GemmN, PadGemmN)),
305
306 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
307 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
308 wei_gemmm_gemmn_pad_grid_desc);
309 }
310 }
311
// 2-D overload (enabled only when NDim == 2).
// NOTE(review): the embedded source line numbers skip in several places, so lines are absent from
// this listing -- including line 313 with the function name (given by the Doxygen index at the
// end of the page as MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N), the right-hand side of the
// `if constexpr` comparison, and the trailing arguments of most transform_tensor_descriptor calls.
// Returns a tuple of three descriptors, in this order:
//   out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  (A, built from the output tensor),
//   in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc   (B, built from the input tensor),
//   wei_gemmm_gemmn_pad_grid_desc                 (built from the weight tensor).
312 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
314 const index_t N,
315 const index_t K,
316 const index_t C,
317 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
318 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
319 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
320 const std::array<index_t, NDimSpatial + 3>& input_strides,
321 const std::array<index_t, NDimSpatial + 3>& weights_strides,
322 const std::array<index_t, NDimSpatial + 3>& output_strides,
323 const std::array<index_t, NDimSpatial>& conv_filter_strides,
324 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
325 const std::array<index_t, NDimSpatial>& input_left_pads,
326 const std::array<index_t, NDimSpatial>& input_right_pads,
327 const index_t batch_k)
328 {
329 using namespace ck;
330
// Unpack per-dimension arrays: index 0 is the H axis, index 1 is the W axis.
331 const index_t Hi = input_spatial_lengths[0];
332 const index_t Wi = input_spatial_lengths[1];
333
334 const index_t Ho = output_spatial_lengths[0];
335 const index_t Wo = output_spatial_lengths[1];
336
337 const index_t Y = filter_spatial_lengths[0];
338 const index_t X = filter_spatial_lengths[1];
339
340 const index_t ConvStrideH = conv_filter_strides[0];
341 const index_t ConvStrideW = conv_filter_strides[1];
342
343 const index_t ConvDilationH = conv_filter_dilations[0];
344 const index_t ConvDilationW = conv_filter_dilations[1];
345
346 const index_t InLeftPadH = input_left_pads[0];
347 const index_t InLeftPadW = input_left_pads[1];
348
349 const index_t InRightPadH = input_right_pads[0];
350 const index_t InRightPadW = input_right_pads[1];
351
// GEMM problem sizes derived from the convolution sizes.
352 const index_t GemmKTotal = N * Ho * Wo;
353 const index_t GemmM = K;
354 const index_t GemmN = C * X * Y;
355
// Amount needed to round GemmM / GemmN up to the next multiple of MPerBlock / NPerBlock
// (0 when already a multiple).
356 const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
357 const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
358
// GemmK0 is ceil(GemmKTotal / (GemmK1Number * K0PerBlock * GemmKBatch)) * K0PerBlock, so
// GemmKPad = GemmKBatch * GemmK0 * GemmK1Number is GemmKTotal rounded up to a size that
// splits evenly into (GemmKBatch, GemmK0, GemmK1Number).
359 const index_t GemmKBatch = batch_k;
360 const index_t GemmK0 =
361 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
362 K0PerBlock;
363 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
364
// Base descriptors come from the NDim-dispatched helper functions of this struct.
365 const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
366 const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
367 const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
368
// The value compared against is on an omitted line; presumably Filter1x1Stride1Pad0, the only
// enumerator named in the page index -- TODO confirm against the real header.
369 if constexpr(ConvBackwardWeightSpecialization ==
371 {
372 // A: output tensor
// Right-pad the K-total dimension from GemmKTotal up to GemmKPad.
373 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
374 out_grid_desc,
375 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
379
// Unmerge padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
380 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
381 out_gemmkpad_gemmm_grid_desc,
382 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
383 make_right_pad_transform(GemmM, PadGemmM)),
386
387 // B: input tensor
// in_grid_desc is used directly here, without the pad/embed steps of the else branch.
388 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
389 in_grid_desc,
390 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
394
395 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
396 in_gemmkpad_gemmn_grid_desc,
397 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
398 make_right_pad_transform(GemmN, PadGemmN)),
401
402 // Pad GemmM and GemmN
403 const auto wei_gemmm_gemmn_pad_grid_desc =
404 transform_tensor_descriptor(wei_grid_desc,
405 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
406 make_right_pad_transform(GemmN, PadGemmN)),
409
410 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
411 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
412 wei_gemmm_gemmn_pad_grid_desc);
413 }
414 else
415 {
416 // A: output tensor
417 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
418 out_grid_desc,
419 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
423
424 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
425 out_gemmkpad_gemmm_grid_desc,
426 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
427 make_right_pad_transform(GemmM, PadGemmM)),
430
431 // B: input tensor
// Step 1: apply the left/right input padding along Hi and Wi.
432 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
433 in_grid_desc,
435 make_pad_transform(Hi, InLeftPadH, InRightPadH),
436 make_pad_transform(Wi, InLeftPadW, InRightPadW),
440
// Step 2: embed padded H as (Y, Ho) and padded W as (X, Wo), each with coefficients
// (dilation, stride).
441 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
442 in_n_hip_wip_c_grid_desc,
445 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
446 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
450
// Step 3: merge (N, Ho, Wo) into one dimension; the other transform of this call is on an
// omitted line.
451 const auto in_gemmktotal_gemmn_grid_desc =
452 transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
454 make_merge_transform(make_tuple(N, Ho, Wo))),
457
458 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
459 in_gemmktotal_gemmn_grid_desc,
460 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
464
465 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
466 in_gemmkpad_gemmn_grid_desc,
467 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
468 make_right_pad_transform(GemmN, PadGemmN)),
471
472 // Pad GemmM and GemmN
473 const auto wei_gemmm_gemmn_pad_grid_desc =
474 transform_tensor_descriptor(wei_grid_desc,
475 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
476 make_right_pad_transform(GemmN, PadGemmN)),
479
480 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
481 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
482 wei_gemmm_gemmn_pad_grid_desc);
483 }
484 }
485
// 3-D overload (enabled only when NDim == 3).
// NOTE(review): the embedded source line numbers skip in several places, so lines are absent from
// this listing -- including line 487 with the function name (presumably
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N, matching the identical parameter list of the
// 2-D overload -- TODO confirm), the right-hand side of the `if constexpr` comparison, and the
// trailing arguments of most transform_tensor_descriptor calls.
// Returns a tuple of three descriptors, in this order:
//   out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc  (A, built from the output tensor),
//   in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc   (B, built from the input tensor),
//   wei_gemmm_gemmn_pad_grid_desc                 (built from the weight tensor).
486 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
488 const index_t N,
489 const index_t K,
490 const index_t C,
491 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
492 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
493 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
494 const std::array<index_t, NDimSpatial + 3>& input_strides,
495 const std::array<index_t, NDimSpatial + 3>& weights_strides,
496 const std::array<index_t, NDimSpatial + 3>& output_strides,
497 const std::array<index_t, NDimSpatial>& conv_filter_strides,
498 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
499 const std::array<index_t, NDimSpatial>& input_left_pads,
500 const std::array<index_t, NDimSpatial>& input_right_pads,
501 const index_t batch_k)
502 {
503 using namespace ck;
504
// Unpack per-dimension arrays: index 0 is the D axis, index 1 the H axis, index 2 the W axis.
505 const index_t Di = input_spatial_lengths[0];
506 const index_t Hi = input_spatial_lengths[1];
507 const index_t Wi = input_spatial_lengths[2];
508
509 const index_t Do = output_spatial_lengths[0];
510 const index_t Ho = output_spatial_lengths[1];
511 const index_t Wo = output_spatial_lengths[2];
512
513 const index_t Z = filter_spatial_lengths[0];
514 const index_t Y = filter_spatial_lengths[1];
515 const index_t X = filter_spatial_lengths[2];
516
517 const index_t ConvStrideD = conv_filter_strides[0];
518 const index_t ConvStrideH = conv_filter_strides[1];
519 const index_t ConvStrideW = conv_filter_strides[2];
520
521 const index_t ConvDilationD = conv_filter_dilations[0];
522 const index_t ConvDilationH = conv_filter_dilations[1];
523 const index_t ConvDilationW = conv_filter_dilations[2];
524
525 const index_t InLeftPadD = input_left_pads[0];
526 const index_t InLeftPadH = input_left_pads[1];
527 const index_t InLeftPadW = input_left_pads[2];
528
529 const index_t InRightPadD = input_right_pads[0];
530 const index_t InRightPadH = input_right_pads[1];
531 const index_t InRightPadW = input_right_pads[2];
532
// GEMM problem sizes derived from the convolution sizes.
533 const index_t GemmKTotal = N * Do * Ho * Wo;
534 const index_t GemmM = K;
535 const index_t GemmN = C * Z * X * Y;
536
// Amount needed to round GemmM / GemmN up to the next multiple of MPerBlock / NPerBlock
// (0 when already a multiple).
537 const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
538 const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
539
// GemmK0 is ceil(GemmKTotal / (GemmK1Number * K0PerBlock * GemmKBatch)) * K0PerBlock, so
// GemmKPad = GemmKBatch * GemmK0 * GemmK1Number is GemmKTotal rounded up to a size that
// splits evenly into (GemmKBatch, GemmK0, GemmK1Number).
540 const index_t GemmKBatch = batch_k;
541 const index_t GemmK0 =
542 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
543 K0PerBlock;
544 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
545
// Base descriptors come from the NDim-dispatched helper functions of this struct.
546 const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
547 const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
548 const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
549
// The value compared against is on an omitted line; presumably Filter1x1Stride1Pad0, the only
// enumerator named in the page index -- TODO confirm against the real header.
550 if constexpr(ConvBackwardWeightSpecialization ==
552 {
553 // A: output tensor
// Right-pad the K-total dimension from GemmKTotal up to GemmKPad.
554 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
555 out_grid_desc,
556 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
560
// Unmerge padded K into (GemmKBatch, GemmK0, GemmK1Number) and right-pad GemmM by PadGemmM.
561 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
562 out_gemmkpad_gemmm_grid_desc,
563 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
564 make_right_pad_transform(GemmM, PadGemmM)),
567
568 // B: input tensor
// in_grid_desc is used directly here, without the pad/embed steps of the else branch.
569 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
570 in_grid_desc,
571 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
575
576 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
577 in_gemmkpad_gemmn_grid_desc,
578 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
579 make_right_pad_transform(GemmN, PadGemmN)),
582
583 // Pad GemmM and GemmN
584 const auto wei_gemmm_gemmn_pad_grid_desc =
585 transform_tensor_descriptor(wei_grid_desc,
586 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
587 make_right_pad_transform(GemmN, PadGemmN)),
590
591 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
592 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
593 wei_gemmm_gemmn_pad_grid_desc);
594 }
595 else
596 {
597 // A: output tensor
598 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
599 out_grid_desc,
600 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
604
605 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
606 out_gemmkpad_gemmm_grid_desc,
607 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
608 make_right_pad_transform(GemmM, PadGemmM)),
611
612 // B: input tensor
// Step 1: apply the left/right input padding along Di, Hi and Wi.
613 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
614 in_grid_desc,
616 make_pad_transform(Di, InLeftPadD, InRightPadD),
617 make_pad_transform(Hi, InLeftPadH, InRightPadH),
618 make_pad_transform(Wi, InLeftPadW, InRightPadW),
624
// Step 2: embed padded D as (Z, Do), padded H as (Y, Ho) and padded W as (X, Wo), each with
// coefficients (dilation, stride).
625 const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
626 in_n_dip_hip_wip_c_grid_desc,
629 make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
630 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
631 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
639 Sequence<7>{}));
640
// Step 3: merge (N, Do, Ho, Wo) into one dimension; the other transform of this call is on an
// omitted line.
641 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
642 in_n_z_do_y_ho_x_wo_c_grid_desc,
644 make_merge_transform(make_tuple(N, Do, Ho, Wo))),
647
648 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
649 in_gemmktotal_gemmn_grid_desc,
650 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
654
655 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
656 in_gemmkpad_gemmn_grid_desc,
657 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
658 make_right_pad_transform(GemmN, PadGemmN)),
661
662 // Pad GemmM and GemmN
663 const auto wei_gemmm_gemmn_pad_grid_desc =
664 transform_tensor_descriptor(wei_grid_desc,
665 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
666 make_right_pad_transform(GemmN, PadGemmN)),
669
670 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
671 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
672 wei_gemmm_gemmn_pad_grid_desc);
673 }
674 } // function end
675};
676
677} // namespace tensor_operation
678} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition utility/sequence.hpp:43
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:24
static constexpr auto I0
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:25
static constexpr auto make_out_grid_desc(const index_t N, const index_t Do, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:82
static constexpr auto make_in_grid_desc(const index_t N, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:44
static constexpr auto I1
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:26
static constexpr auto make_in_grid_desc(const index_t N, const index_t Di, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:97
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &input_strides, const std::array< index_t, NDimSpatial+3 > &weights_strides, const std::array< index_t, NDimSpatial+3 > &output_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const index_t batch_k)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:313
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const index_t batch_k)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:139
static constexpr auto make_wei_grid_desc(const index_t K, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:69
static constexpr auto make_wei_grid_desc(const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:125
static constexpr auto make_out_grid_desc(const index_t N, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp:30