warp_gemm_attribute_mfma_impl.hpp Source File

warp_gemm_attribute_mfma_impl.hpp Source File#

Composable Kernel: warp_gemm_attribute_mfma_impl.hpp Source File
warp_gemm_attribute_mfma_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10// TODO: refactor warp-gemm
11// Currently there is a discrepancy between vav/vva when C/D has to be transposed:
12// e.g. if we want A:agpr, B:vgpr, we have to use Raw_vva in WGAttrCtlEnum,
13// because the _impl code swaps the A/B pointers (and that is not known here).
/// Controls how a warp-gemm MFMA is issued.
///
/// The Raw_* values make DISPATCH_MFMA_CTRL_ emit the instruction as inline asm; the three
/// letters name the register-file constraint used for C, A and B in that order
/// ("v" or "a", exactly as passed to the asm operand constraints).
enum class WGAttrCtlEnum
{
    // Not matched by any DISPATCH_MFMA_CTRL_ branch, so callers fall through to their
    // trailing `else` block. Every template in this file uses it as the default argument.
    // NOTE(review): this enumerator was absent from the listing; the value 0 is inferred
    // from the numbering gap before Raw_vvv = 1 -- confirm against the upstream header.
    Default_ = 0,
    Raw_vvv  = 1, // c-vgpr, a-vgpr, b-vgpr
    Raw_vaa  = 2, // c-vgpr, a-agpr, b-agpr
    Raw_vav  = 3, // c-vgpr, a-agpr, b-vgpr
    Raw_vva  = 4, // c-vgpr, a-vgpr, b-agpr
    Raw_avv  = 5, // c-agpr, a-vgpr, b-vgpr
    // raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
24
// DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
//
// Emits a single MFMA instruction as a volatile inline-asm statement.
//   mfma_ : string literal with the instruction mnemonic; it is concatenated with the
//           operand template " %0, %1, %2, %3".
//   dmod_ : constraint string for operand %0 (c_vec as the destination, e.g. "+v").
//   amod_ : constraint string for operand %1 (a_vec).
//   bmod_ : constraint string for operand %2 (b_vec).
//   cmod_ : constraint string for operand %3 (c_vec again, as the accumulator source).
//
// The expansion site must have `post_nop_`, `c_vec`, `a_vec` and `b_vec` in scope.
// If `post_nop_` is true, the asm text also carries a trailing "; yyy" (presumably an
// assembler comment used as a marker -- confirm) and an "s_nop 3" after the MFMA;
// otherwise only the MFMA line is emitted. The clobber list is empty in both variants.
25#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
26 if constexpr(post_nop_) \
27 { \
28 asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
29 "s_nop 3" \
30 : dmod_(c_vec) \
31 : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
32 :); \
33 } \
34 else \
35 { \
36 asm volatile(mfma_ " %0, %1, %2, %3\n" \
37 : dmod_(c_vec) \
38 : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
39 :); \
40 }
41
// DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
//
// Chooses the inline-asm operand constraints from a WGAttrCtlEnum value and forwards
// them to DISPATCH_MFMA_:
//   Raw_vvv -> D "+v", A "v", B "v", C "v"
//   Raw_vaa -> D "+v", A "a", B "a", C "v"
//   Raw_vav -> D "+v", A "a", B "v", C "v"
//   Raw_vva -> D "+v", A "v", B "a", C "v"
//   Raw_avv -> D "+a", A "v", B "v", C "a"
//
// The expansion is an `if constexpr / else if constexpr` chain WITHOUT a final `else`.
// Call sites are expected to follow the macro with `else { ... }` holding the fallback
// for every other ctrl_ value. A brace block placed after the macro without `else` is
// not part of the chain and runs in addition to whichever branch was selected.
42#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
43 if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
44 { \
45 DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
46 } \
47 else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
48 { \
49 DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
50 } \
51 else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
52 { \
53 DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
54 } \
55 else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
56 { \
57 DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
58 } \
59 else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
60 { \
61 DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
62 }
63
64// F32
65template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
67{
68 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
69
70 using ADataType = float;
71 using BDataType = float;
72 using CDataType = float;
73
77
78 static constexpr index_t kM = 16;
79 static constexpr index_t kN = 16;
80 static constexpr index_t kK = 4;
81
82 static constexpr index_t kAMBlock = 1;
83 static constexpr index_t kBNBlock = 1;
84
85 static constexpr index_t kAMLane = 16;
86 static constexpr index_t kBNLane = 16;
87 static constexpr index_t kABKLane = 4;
88 static constexpr index_t kABKPerLane = 1;
89
90 static constexpr index_t kCMLane = 4;
91 static constexpr index_t kCNLane = 16;
92 static constexpr index_t kCM0PerLane = 1;
93 static constexpr index_t kCM1PerLane = 4;
94
95 // c_vec += a_vec * b_vec
96 template <bool post_nop_ = false>
98 const AVecType& a_vec,
99 const BVecType& b_vec,
100 bool_constant<post_nop_> = {}) const
101 {
102 DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x4f32", Ctrl)
103 else
104 {
105#if defined(__gfx9__)
106 c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
107#else
108 ck_tile::ignore = c_vec;
109 ck_tile::ignore = a_vec;
110 ck_tile::ignore = b_vec;
111#endif
112 }
113 }
114
115 // c_vec = a_vec * b_vec
116 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
117 {
118#if defined(__gfx9__)
119 return bit_cast<CVecType>(
120 __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
121#else
122 ck_tile::ignore = a_vec;
123 ck_tile::ignore = b_vec;
124 return CVecType{0.f};
125#endif
126 }
127};
128
129template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
131{
132 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
133
134 using ADataType = float;
135 using BDataType = float;
136 using CDataType = float;
137
141
142 static constexpr index_t kM = 32;
143 static constexpr index_t kN = 32;
144 static constexpr index_t kK = 2;
145
146 static constexpr index_t kAMBlock = 1;
147 static constexpr index_t kBNBlock = 1;
148
149 static constexpr index_t kAMLane = 32;
150 static constexpr index_t kBNLane = 32;
151 static constexpr index_t kABKLane = 2;
152 static constexpr index_t kABKPerLane = 1;
153
154 static constexpr index_t kCMLane = 2;
155 static constexpr index_t kCNLane = 32;
156 static constexpr index_t kCM0PerLane = 4;
157 static constexpr index_t kCM1PerLane = 4;
158
159 // c_vec += a_vec * b_vec
160 template <bool post_nop_ = false>
162 const AVecType& a_vec,
163 const BVecType& b_vec,
164 bool_constant<post_nop_> = {}) const
165 {
166 DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x2f32", Ctrl)
167 else
168 {
169#if defined(__gfx9__)
170 c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
171#else
172 ck_tile::ignore = c_vec;
173 ck_tile::ignore = a_vec;
174 ck_tile::ignore = b_vec;
175#endif
176 }
177 }
178
179 // c_vec = a_vec * b_vec
180 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
181 {
182#if defined(__gfx9__)
183 return bit_cast<CVecType>(
184 __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
185#else
186 ck_tile::ignore = a_vec;
187 ck_tile::ignore = b_vec;
188 return CVecType{0.f};
189#endif
190 }
191};
192
193// V_MFMA_F32_16x16x32_BF16
194template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
196{
197 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
200 using CDataType = float;
201
205
206 static constexpr index_t kM = 16;
207 static constexpr index_t kN = 16;
208 static constexpr index_t kK = 32;
209
210 static constexpr index_t kAMBlock = 1;
211 static constexpr index_t kBNBlock = 1;
212
213 static constexpr index_t kAMLane = 16;
214 static constexpr index_t kBNLane = 16;
215 static constexpr index_t kABKLane = 4;
216 static constexpr index_t kABKPerLane = 8;
217
218 static constexpr index_t kCMLane = 4;
219 static constexpr index_t kCNLane = 16;
220 static constexpr index_t kCM0PerLane = 1;
221 static constexpr index_t kCM1PerLane = 4;
222
223 // c_vec += a_vec * b_vec
224 template <bool post_nop_ = false>
226 const AVecType& a_vec,
227 const BVecType& b_vec,
228 bool_constant<post_nop_> = {}) const
229 {
230 DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl)
231 else
232 {
233#if defined(__gfx950__)
234 c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
235#else
236 ck_tile::ignore = c_vec;
237 ck_tile::ignore = a_vec;
238 ck_tile::ignore = b_vec;
239#endif
240 }
241 }
242
243 // c_vec = a_vec * b_vec
244 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
245 {
246#if defined(__gfx950__)
247 return bit_cast<CVecType>(
248 __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
249#else
250 ck_tile::ignore = a_vec;
251 ck_tile::ignore = b_vec;
252 return CVecType{0.f};
253#endif
254 }
255};
256// FP16
257template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
259{
260 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
263 using CDataType = float;
264
268
269 static constexpr index_t kM = 32;
270 static constexpr index_t kN = 32;
271 static constexpr index_t kK = 8;
272
273 static constexpr index_t kAMBlock = 1;
274 static constexpr index_t kBNBlock = 1;
275
276 static constexpr index_t kAMLane = 32;
277 static constexpr index_t kBNLane = 32;
278 static constexpr index_t kABKLane = 2;
279 static constexpr index_t kABKPerLane = 4;
280
281 static constexpr index_t kCMLane = 2;
282 static constexpr index_t kCNLane = 32;
283 static constexpr index_t kCM0PerLane = 4;
284 static constexpr index_t kCM1PerLane = 4;
285
286 // c_vec += a_vec * b_vec
287 template <bool post_nop_ = false>
289 const AVecType& a_vec,
290 const BVecType& b_vec,
291 bool_constant<post_nop_> = {}) const
292 {
293 DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
294 else
295 {
296#if defined(__gfx9__)
297 c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
298#else
299 ck_tile::ignore = c_vec;
300 ck_tile::ignore = a_vec;
301 ck_tile::ignore = b_vec;
302#endif
303 }
304 }
305
306 // c_vec = a_vec * b_vec
307 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
308 {
309#if defined(__gfx9__)
310 return bit_cast<CVecType>(
311 __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
312#else
313 ck_tile::ignore = a_vec;
314 ck_tile::ignore = b_vec;
315 return CVecType{0.f};
316#endif
317 }
318};
319
320template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
322{
323 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
326 using CDataType = float;
327
331
332 static constexpr index_t kM = 16;
333 static constexpr index_t kN = 16;
334 static constexpr index_t kK = 16;
335
336 static constexpr index_t kAMBlock = 1;
337 static constexpr index_t kBNBlock = 1;
338
339 static constexpr index_t kAMLane = 16;
340 static constexpr index_t kBNLane = 16;
341 static constexpr index_t kABKLane = 4;
342 static constexpr index_t kABKPerLane = 4;
343
344 static constexpr index_t kCMLane = 4;
345 static constexpr index_t kCNLane = 16;
346 static constexpr index_t kCM0PerLane = 1;
347 static constexpr index_t kCM1PerLane = 4;
348
349 // c_vec += a_vec * b_vec
350 template <bool post_nop_ = false>
352 const AVecType& a_vec,
353 const BVecType& b_vec,
354 bool_constant<post_nop_> = {}) const
355 {
356 DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
357 else
358 {
359#if defined(__gfx9__)
360 c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
361#else
362 ck_tile::ignore = c_vec;
363 ck_tile::ignore = a_vec;
364 ck_tile::ignore = b_vec;
365#endif
366 }
367 }
368
369 // c_vec = a_vec * b_vec
370 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
371 {
372#if defined(__gfx9__)
373 return bit_cast<CVecType>(
374 __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
375#else
376 ck_tile::ignore = a_vec;
377 ck_tile::ignore = b_vec;
378 return CVecType{0.f};
379#endif
380 }
381};
382
383template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
385{
386 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
389 using CDataType = float;
390
394
395 static constexpr index_t kM = 16;
396 static constexpr index_t kN = 16;
397 static constexpr index_t kK = 32;
398
399 static constexpr index_t kAMBlock = 1;
400 static constexpr index_t kBNBlock = 1;
401
402 static constexpr index_t kAMLane = 16;
403 static constexpr index_t kBNLane = 16;
404 static constexpr index_t kABKLane = 4;
405 static constexpr index_t kABKPerLane = 8;
406
407 static constexpr index_t kCMLane = 4;
408 static constexpr index_t kCNLane = 16;
409 static constexpr index_t kCM0PerLane = 1;
410 static constexpr index_t kCM1PerLane = 4;
411
412 // c_vec += a_vec * b_vec
413 template <bool post_nop_ = false>
415 const AVecType& a_vec,
416 const BVecType& b_vec,
417 bool_constant<post_nop_> = {}) const
418 {
419 DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl)
420 else
421 {
422#if defined(__gfx950__)
423 c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0);
424#else
425 ck_tile::ignore = c_vec;
426 ck_tile::ignore = a_vec;
427 ck_tile::ignore = b_vec;
428#endif
429 }
430 }
431
432 // c_vec = a_vec * b_vec
433 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
434 {
435#if defined(__gfx950__)
436 return bit_cast<CVecType>(
437 __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
438#else
439 ck_tile::ignore = a_vec;
440 ck_tile::ignore = b_vec;
441 return CVecType{0.f};
442#endif
443 }
444};
445
446template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
448{
449 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
452 using CDataType = float;
453
457
458 static constexpr index_t kM = 4;
459 static constexpr index_t kN = 64;
460 static constexpr index_t kK = 4;
461
462 static constexpr index_t kAMBlock = 1;
463 static constexpr index_t kBNBlock = 16;
464
465 // we only write down single block (4 threads) thread mapping here
466 static constexpr index_t kAMLane = 4;
467 static constexpr index_t kBNLane = 4;
468 static constexpr index_t kABKLane = 1;
469 static constexpr index_t kABKPerLane = 4;
470
471 static constexpr index_t kCMLane = 1;
472 static constexpr index_t kCNLane = 4;
473 static constexpr index_t kCM0PerLane = 1;
474 static constexpr index_t kCM1PerLane = 4;
475
476 // c_vec += a_vec * b_vec
477 template <bool post_nop_ = false>
479 const AVecType& a_vec,
480 const BVecType& b_vec,
481 bool_constant<post_nop_> = {}) const
482 {
483 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
484 else
485 {
486#if defined(__gfx9__)
487 c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
488#else
489 ignore = c_vec;
490 ignore = a_vec;
491 ignore = b_vec;
492#endif
493 }
494 }
495
496 // c_vec = a_vec * b_vec
497 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
498 {
499#if defined(__gfx9__)
500 return bit_cast<CVecType>(
501 __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
502#else
503 ignore = a_vec;
504 ignore = b_vec;
505 return CVecType{0.f};
506#endif
507 }
508};
509
510template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
512{
513 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
516 using CDataType = float;
517
521
522 static constexpr index_t kM = 64;
523 static constexpr index_t kN = 4;
524 static constexpr index_t kK = 4;
525
526 static constexpr index_t kAMBlock = 16;
527 static constexpr index_t kBNBlock = 1;
528
529 // we only write down single block (4 threads) thread mapping here
530 static constexpr index_t kAMLane = 4;
531 static constexpr index_t kBNLane = 4;
532 static constexpr index_t kABKLane = 1;
533 static constexpr index_t kABKPerLane = 4;
534
535 static constexpr index_t kCMLane = 1;
536 static constexpr index_t kCNLane = 4;
537 static constexpr index_t kCM0PerLane = 1;
538 static constexpr index_t kCM1PerLane = 4;
539
540 // c_vec += a_vec * b_vec
541 template <bool post_nop_ = false>
543 const AVecType& a_vec,
544 const BVecType& b_vec,
545 bool_constant<post_nop_> = {}) const
546 {
547 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
548 else
549 {
550#if defined(__gfx9__)
551 c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
552#else
553 ignore = c_vec;
554 ignore = a_vec;
555 ignore = b_vec;
556#endif
557 }
558 }
559
560 // c_vec = a_vec * b_vec
561 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
562 {
563#if defined(__gfx9__)
564 return bit_cast<CVecType>(
565 __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
566#else
567 ignore = a_vec;
568 ignore = b_vec;
569 return CVecType{0.f};
570#endif
571 }
572};
573
574// Bf16
575template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
577{
578 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
581 using CDataType = float;
582
586
587 static constexpr index_t kM = 32;
588 static constexpr index_t kN = 32;
589 static constexpr index_t kK = 8;
590
591 static constexpr index_t kAMBlock = 1;
592 static constexpr index_t kBNBlock = 1;
593
594 static constexpr index_t kAMLane = 32;
595 static constexpr index_t kBNLane = 32;
596 static constexpr index_t kABKLane = 2;
597 static constexpr index_t kABKPerLane = 4;
598
599 static constexpr index_t kCMLane = 2;
600 static constexpr index_t kCNLane = 32;
601 static constexpr index_t kCM0PerLane = 4;
602 static constexpr index_t kCM1PerLane = 4;
603
604 // c_vec += a_vec * b_vec
605 template <bool post_nop_ = false>
607 const AVecType& a_vec,
608 const BVecType& b_vec,
609 bool_constant<post_nop_> = {}) const
610 {
611 DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
612 else
613 {
614#if defined(__gfx90a__) || defined(__gfx94__)
615 c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
616#elif defined(__gfx908__)
617 static_for<0, 2, 1>{}([&](auto k) {
618 c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
619 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
620 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
621 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
622 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
623 c_vec,
624 0,
625 0,
626 0);
627 });
628#else
629 ck_tile::ignore = c_vec;
630 ck_tile::ignore = a_vec;
631 ck_tile::ignore = b_vec;
632#endif
633 }
634 }
635
636 // c_vec = a_vec * b_vec
637 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
638 {
639#if defined(__gfx90a__) || defined(__gfx94__)
640 return bit_cast<CVecType>(
641 __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
642#elif defined(__gfx908__)
643 CVecType c_vec{0.f};
644 static_for<0, 2, 1>{}([&](auto k) {
645 c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
646 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
647 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
648 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
649 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
650 c_vec,
651 0,
652 0,
653 0);
654 });
655 return c_vec;
656#else
657 ck_tile::ignore = a_vec;
658 ck_tile::ignore = b_vec;
659 return CVecType{0.f};
660#endif
661 }
662};
663
664template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
666{
667 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
670 using CDataType = float;
671
675
676 static constexpr index_t kM = 16;
677 static constexpr index_t kN = 16;
678 static constexpr index_t kK = 16;
679
680 static constexpr index_t kAMBlock = 1;
681 static constexpr index_t kBNBlock = 1;
682
683 static constexpr index_t kAMLane = 16;
684 static constexpr index_t kBNLane = 16;
685 static constexpr index_t kABKLane = 4;
686 static constexpr index_t kABKPerLane = 4;
687
688 static constexpr index_t kCMLane = 4;
689 static constexpr index_t kCNLane = 16;
690 static constexpr index_t kCM0PerLane = 1;
691 static constexpr index_t kCM1PerLane = 4;
692
693 // c_vec += a_vec * b_vec
694 template <bool post_nop_ = false>
696 const AVecType& a_vec,
697 const BVecType& b_vec,
698 bool_constant<post_nop_> = {}) const
699 {
700 DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
701 {
702#if defined(__gfx90a__) || defined(__gfx94__)
703 c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
704#elif defined(__gfx908__)
705 static_for<0, 2, 1>{}([&](auto k) {
706 c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
707 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
708 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
709 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
710 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
711 c_vec,
712 0,
713 0,
714 0);
715 });
716#else
717 ck_tile::ignore = c_vec;
718 ck_tile::ignore = a_vec;
719 ck_tile::ignore = b_vec;
720#endif
721 }
722 }
723
724 // c_vec = a_vec * b_vec
725 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
726 {
727#if defined(__gfx90a__) || defined(__gfx94__)
728 return bit_cast<CVecType>(
729 __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
730#elif defined(__gfx908__)
731 CVecType c_vec{0.f};
732 static_for<0, 2, 1>{}([&](auto k) {
733 c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
734 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
735 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
736 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
737 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
738 c_vec,
739 0,
740 0,
741 0);
742 });
743 return c_vec;
744#else
745 ck_tile::ignore = a_vec;
746 ck_tile::ignore = b_vec;
747 return CVecType{0.f};
748#endif
749 }
750};
751
752template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
754{
755 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
758 using CDataType = float;
759
763
764 static constexpr index_t kM = 4;
765 static constexpr index_t kN = 64;
766 static constexpr index_t kK = 4;
767
768 static constexpr index_t kAMBlock = 1;
769 static constexpr index_t kBNBlock = 16;
770
771 // we only write down single block (4 threads) thread mapping here
772 static constexpr index_t kAMLane = 4;
773 static constexpr index_t kBNLane = 4;
774 static constexpr index_t kABKLane = 1;
775 static constexpr index_t kABKPerLane = 4;
776
777 static constexpr index_t kCMLane = 1;
778 static constexpr index_t kCNLane = 4;
779 static constexpr index_t kCM0PerLane = 1;
780 static constexpr index_t kCM1PerLane = 4;
781
782 // c_vec += a_vec * b_vec
783 template <bool post_nop_ = false>
785 const AVecType& a_vec,
786 const BVecType& b_vec,
787 bool_constant<post_nop_> = {}) const
788 {
789 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
790 else
791 {
792#if defined(__gfx90a__) || defined(__gfx94__)
793 c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
794#elif defined(__gfx908__)
795 static_for<0, 2, 1>{}([&](auto k) {
796 c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
797 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
798 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
799 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
800 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
801 c_vec,
802 0,
803 0,
804 0);
805 });
806#else
807 ignore = c_vec;
808 ignore = a_vec;
809 ignore = b_vec;
810#endif
811 }
812 }
813
814 // c_vec = a_vec * b_vec
815 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
816 {
817#if defined(__gfx90a__) || defined(__gfx94__)
818 return bit_cast<CVecType>(
819 __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
820#elif defined(__gfx908__)
821 CVecType c_vec{0.f};
822 static_for<0, 2, 1>{}([&](auto k) {
823 c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
824 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
825 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
826 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
827 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
828 c_vec,
829 0,
830 0,
831 0);
832 });
833 return c_vec;
834#else
835 ignore = a_vec;
836 ignore = b_vec;
837 return CVecType{0.f};
838#endif
839 }
840};
841
842template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
844{
845 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
848 using CDataType = float;
849
853
854 static constexpr index_t kM = 64;
855 static constexpr index_t kN = 4;
856 static constexpr index_t kK = 4;
857
858 static constexpr index_t kAMBlock = 16;
859 static constexpr index_t kBNBlock = 1;
860
861 // we only write down single block (4 threads) thread mapping here
862 static constexpr index_t kAMLane = 4;
863 static constexpr index_t kBNLane = 4;
864 static constexpr index_t kABKLane = 1;
865 static constexpr index_t kABKPerLane = 4;
866
867 static constexpr index_t kCMLane = 1;
868 static constexpr index_t kCNLane = 4;
869 static constexpr index_t kCM0PerLane = 1;
870 static constexpr index_t kCM1PerLane = 4;
871
872 // c_vec += a_vec * b_vec
873 template <bool post_nop_ = false>
875 const AVecType& a_vec,
876 const BVecType& b_vec,
877 bool_constant<post_nop_> = {}) const
878 {
879 DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
880 else
881 {
882#if defined(__gfx90a__) || defined(__gfx94__)
883 c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
884#elif defined(__gfx908__)
885 static_for<0, 2, 1>{}([&](auto k) {
886 c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
887 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
888 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
889 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
890 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
891 c_vec,
892 0,
893 0,
894 0);
895 });
896#else
897 ignore = c_vec;
898 ignore = a_vec;
899 ignore = b_vec;
900#endif
901 }
902 }
903
904 // c_vec = a_vec * b_vec
905 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
906 {
907#if defined(__gfx90a__) || defined(__gfx94__)
908 return bit_cast<CVecType>(
909 __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
910#elif defined(__gfx908__)
911 CVecType c_vec{0.f};
912 static_for<0, 2, 1>{}([&](auto k) {
913 c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
914 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
915 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
916 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
917 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
918 c_vec,
919 0,
920 0,
921 0);
922 });
923 return c_vec;
924#else
925 ignore = a_vec;
926 ignore = b_vec;
927 return CVecType{0.f};
928#endif
929 }
930};
931
932// gfx950
933template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
935{
936 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
939 using CDataType = float;
940
944
945 static constexpr index_t kM = 32;
946 static constexpr index_t kN = 32;
947 static constexpr index_t kK = 16;
948
949 static constexpr index_t kAMBlock = 1;
950 static constexpr index_t kBNBlock = 1;
951
952 static constexpr index_t kAMLane = 32;
953 static constexpr index_t kBNLane = 32;
954 static constexpr index_t kABKLane = 2;
955 static constexpr index_t kABKPerLane = 8;
956
957 static constexpr index_t kCMLane = 2;
958 static constexpr index_t kCNLane = 32;
959 static constexpr index_t kCM0PerLane = 4;
960 static constexpr index_t kCM1PerLane = 4;
961
962 // c_vec += a_vec * b_vec
963 template <bool post_nop_ = false>
965 const AVecType& a_vec,
966 const BVecType& b_vec,
967 bool_constant<post_nop_> = {}) const
968 {
969 DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl)
970 else
971 {
972#if defined(__gfx950__)
973 c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0);
974#elif defined(__gfx90a__) || defined(__gfx94__)
975 static_for<0, 2, 1>{}([&](auto k) {
976 c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
977 reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
978 .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
979 reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
980 .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
981 c_vec,
982 0,
983 0,
984 0);
985 });
986#elif defined(__gfx908__)
987 static_for<0, 4, 1>{}([&](auto k) {
988 c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
989 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
990 .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
991 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
992 .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
993 c_vec,
994 0,
995 0,
996 0);
997 });
998#else
999 ck_tile::ignore = c_vec;
1000 ck_tile::ignore = a_vec;
1001 ck_tile::ignore = b_vec;
1002#endif
1003 }
1004 }
1005
1006 // c_vec = a_vec * b_vec
1007 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1008 {
1009#if defined(__gfx950__)
1010 return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1011#elif defined(__gfx90a__) || defined(__gfx94__)
1012 CVecType c_vec{0.f};
1013 static_for<0, 2, 1>{}([&](auto k) {
1014 c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
1015 reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1016 .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1017 reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1018 .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1019 c_vec,
1020 0,
1021 0,
1022 0);
1023 });
1024 return c_vec;
1025#elif defined(__gfx908__)
1026 CVecType c_vec{0.f};
1027 static_for<0, 4, 1>{}([&](auto k) {
1028 c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
1029 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1030 .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1031 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1032 .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1033 c_vec,
1034 0,
1035 0,
1036 0);
1037 });
1038 return c_vec;
1039#else
1040 ck_tile::ignore = a_vec;
1041 ck_tile::ignore = b_vec;
1042 return CVecType{0.f};
1043#endif
1044 }
1045};
1046
1047template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1049{
1050 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1053 using CDataType = float;
1054
1058
1059 static constexpr index_t kM = 32;
1060 static constexpr index_t kN = 32;
1061 static constexpr index_t kK = 16;
1062
1063 static constexpr index_t kAMBlock = 1;
1064 static constexpr index_t kBNBlock = 1;
1065
1066 static constexpr index_t kAMLane = 32;
1067 static constexpr index_t kBNLane = 32;
1068 static constexpr index_t kABKLane = 2;
1069 static constexpr index_t kABKPerLane = 8;
1070
1071 static constexpr index_t kCMLane = 2;
1072 static constexpr index_t kCNLane = 32;
1073 static constexpr index_t kCM0PerLane = 4;
1074 static constexpr index_t kCM1PerLane = 4;
1075
1076 // c_vec += a_vec * b_vec
1077 template <bool post_nop_ = false>
1079 const AVecType& a_vec,
1080 const BVecType& b_vec,
1081 bool_constant<post_nop_> = {}) const
1082 {
1083 DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl)
1084 else
1085 {
1086#if defined(__gfx950__)
1087 c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
1088#elif defined(__gfx90a__) || defined(__gfx94__)
1089 static_for<0, 2, 1>{}([&](auto k) {
1090 c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1091 reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1092 .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1093 reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1094 .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1095 c_vec,
1096 0,
1097 0,
1098 0);
1099 });
1100#elif defined(__gfx908__)
1101 static_for<0, 4, 1>{}([&](auto k) {
1102 c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1103 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1104 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1105 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1106 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1107 c_vec,
1108 0,
1109 0,
1110 0);
1111 });
1112#else
1113 ck_tile::ignore = c_vec;
1114 ck_tile::ignore = a_vec;
1115 ck_tile::ignore = b_vec;
1116#endif
1117 }
1118 }
1119
1120 // c_vec = a_vec * b_vec
1121 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1122 {
1123#if defined(__gfx950__)
1124 return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1125#elif defined(__gfx90a__) || defined(__gfx94__)
1126 CVecType c_vec{0.f};
1127 static_for<0, 2, 1>{}([&](auto k) {
1128 c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1129 reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1130 .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1131 reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1132 .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1133 c_vec,
1134 0,
1135 0,
1136 0);
1137 });
1138 return c_vec;
1139#elif defined(__gfx908__)
1140 CVecType c_vec{0.f};
1141 static_for<0, 4, 1>{}([&](auto k) {
1142 c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1143 reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1144 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1145 reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1146 .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1147 c_vec,
1148 0,
1149 0,
1150 0);
1151 });
1152 return c_vec;
1153#else
1154 ck_tile::ignore = a_vec;
1155 ck_tile::ignore = b_vec;
1156 return CVecType{0.f};
1157#endif
1158 }
1159};
1160
1161// FP8
1162template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1164{
1165 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1166 using ADataType = AType_;
1167 using BDataType = BType_;
1168 using CDataType = float;
1169
1173
1174 static constexpr index_t kM = 16;
1175 static constexpr index_t kN = 16;
1176 static constexpr index_t kK = 32;
1177
1178 static constexpr index_t kAMBlock = 1;
1179 static constexpr index_t kBNBlock = 1;
1180
1181 static constexpr index_t kAMLane = 16;
1182 static constexpr index_t kBNLane = 16;
1183 static constexpr index_t kABKLane = 4;
1184 static constexpr index_t kABKPerLane = 8;
1185
1186 static constexpr index_t kCMLane = 4;
1187 static constexpr index_t kCNLane = 16;
1188 static constexpr index_t kCM0PerLane = 1;
1189 static constexpr index_t kCM1PerLane = 4;
1190
1191 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place.
// NOTE(review): the listing skips original line 1193, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// Dispatch is two-level and resolved at compile time:
//   1. Ctrl selects the inline-asm register-constraint set (Raw_vvv / Raw_vaa / Raw_vav /
//      Raw_vva); any other Ctrl value takes the final else branch and uses the builtin.
//   2. The (ADataType, BDataType) pair over {fp8_t, bf8_t} selects the instruction variant.
// post_nop_ is consumed inside DISPATCH_MFMA_, which emits "s_nop 3" after the MFMA when
// it is true.
// The mnemonics passed to DISPATCH_MFMA_ are pasted verbatim into the asm string, so they
// carry the "v_" prefix, the same spelling as the "v_mfma_i32_*" strings the int8
// implementations in this file hand to DISPATCH_MFMA_CTRL_.
1192 template <bool post_nop_ = false>
1194 const AVecType& a_vec,
1195 const BVecType& b_vec,
1196 bool_constant<post_nop_> = {}) const
1197 {
1198 if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1199 {
1200 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1201 {
1202 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v")
1203 }
1204 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1205 {
1206 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v")
1207 }
1208 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1209 {
1210 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v")
1211 }
1212 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1213 {
1214 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v")
1215 }
1216 }
1217 else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1218 {
1219 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1220 {
1221 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v")
1222 }
1223 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1224 {
1225 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v")
1226 }
1227 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1228 {
1229 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v")
1230 }
1231 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1232 {
1233 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v")
1234 }
1235 }
1236 else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1237 {
1238 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1239 {
1240 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v")
1241 }
1242 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1243 {
1244 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v")
1245 }
1246 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1247 {
1248 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v")
1249 }
1250 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1251 {
1252 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v")
1253 }
1254 }
1255 else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1256 {
1257 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1258 {
1259 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v")
1260 }
1261 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1262 {
1263 DISPATCH_MFMA_("v_mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v")
1264 }
1265 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1266 {
1267 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v")
1268 }
1269 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1270 {
1271 DISPATCH_MFMA_("v_mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v")
1272 }
1273 }
1274 else
1275 {
// Builtin path: operands are reinterpreted as 64-bit integers before the call. On targets
// other than gfx94/gfx95 nothing is computed and c_vec is left untouched.
1276#if defined(__gfx94__) or defined(__gfx95__)
1277 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1278 c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1279 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1280 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1281 c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1282 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1283 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1284 c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1285 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1286 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1287 c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1288 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1289#else
1290 ck_tile::ignore = c_vec;
1291 ck_tile::ignore = a_vec;
1292 ck_tile::ignore = b_vec;
1293#endif
1294 }
1295 }
1296
1297 // c_vec = a_vec * b_vec
// Non-accumulating overload: returns a_vec * b_vec computed against a zero accumulator.
// On gfx94/gfx95 the (ADataType, BDataType) pair over {fp8_t, bf8_t} selects one of four
// 16x16x32 builtins; both operands are reinterpreted as int64_t before the call.
// On every other target the operands are ignored and an all-zero CVecType is returned.
// NOTE(review): the if-constexpr chain has no trailing else, so a type pair outside
// {fp8_t, bf8_t} would reach the end of the function without a return on gfx94/gfx95.
1298 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1299 {
1300#if defined(__gfx94__) or defined(__gfx95__)
1301 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1302 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1303 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1304 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1305 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1306 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1307 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1308 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1309 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1310 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1311 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1312 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1313#else
1314 ck_tile::ignore = a_vec;
1315 ck_tile::ignore = b_vec;
1316 return CVecType{0.f};
1317#endif
1318 }
1319};
1320
1321template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1323{
1324 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1325 using ADataType = AType_;
1326 using BDataType = BType_;
1327 using CDataType = float;
1328
1332
1333 static constexpr index_t kM = 32;
1334 static constexpr index_t kN = 32;
1335 static constexpr index_t kK = 16;
1336
1337 static constexpr index_t kAMBlock = 1;
1338 static constexpr index_t kBNBlock = 1;
1339
1340 static constexpr index_t kAMLane = 32;
1341 static constexpr index_t kBNLane = 32;
1342 static constexpr index_t kABKLane = 2;
1343 static constexpr index_t kABKPerLane = 8;
1344
1345 static constexpr index_t kCMLane = 2;
1346 static constexpr index_t kCNLane = 32;
1347 static constexpr index_t kCM0PerLane = 4;
1348 static constexpr index_t kCM1PerLane = 4;
1349
1350 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place.
// NOTE(review): the listing skips original line 1352, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// Dispatch is two-level and resolved at compile time:
//   1. Ctrl selects the inline-asm register-constraint set (Raw_vvv / Raw_vaa / Raw_vav /
//      Raw_vva); any other Ctrl value takes the final else branch.
//   2. The (ADataType, BDataType) pair over {fp8_t, bf8_t} selects the instruction variant.
// post_nop_ is consumed inside DISPATCH_MFMA_, which emits "s_nop 3" after the MFMA when
// it is true.
// The mnemonics passed to DISPATCH_MFMA_ are pasted verbatim into the asm string, so they
// carry the "v_" prefix, the same spelling as the "v_mfma_i32_*" strings the int8
// implementations in this file hand to DISPATCH_MFMA_CTRL_.
1351 template <bool post_nop_ = false>
1353 const AVecType& a_vec,
1354 const BVecType& b_vec,
1355 bool_constant<post_nop_> = {}) const
1356 {
1357 if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1358 {
1359 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1360 {
1361 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
1362 }
1363 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1364 {
1365 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
1366 }
1367 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1368 {
1369 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
1370 }
1371 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1372 {
1373 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
1374 }
1375 }
1376 else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1377 {
1378 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1379 {
1380 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
1381 }
1382 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1383 {
1384 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
1385 }
1386 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1387 {
1388 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
1389 }
1390 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1391 {
1392 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
1393 }
1394 }
1395 else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1396 {
1397 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1398 {
1399 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
1400 }
1401 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1402 {
1403 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
1404 }
1405 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1406 {
1407 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
1408 }
1409 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1410 {
1411 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
1412 }
1413 }
1414 else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1415 {
1416 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1417 {
1418 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
1419 }
1420 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1421 {
1422 DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
1423 }
1424 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1425 {
1426 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
1427 }
1428 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1429 {
1430 DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
1431 }
1432 }
1433 else
1434 {
// Builtin path. gfx94/gfx95 use the native 32x32x16 fp8/bf8 builtins on operands
// reinterpreted as int64_t.
1435#if defined(__gfx94__) or defined(__gfx95__)
1436 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1437 c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1438 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1439 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1440 c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1441 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1442 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1443 c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1444 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1445 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1446 c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1447 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1448#elif defined(__gfx908__) || defined(__gfx90a__)
// Fallback for gfx908/gfx90a: each of the 8 operand elements is converted to float and
// accumulated with one 32x32x2 f32 MFMA per element.
1449 static_for<0, 8, 1>{}([&](auto k) {
1450 float a_f32 =
1451 type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1452 .template get_as<ADataType>()[number<k>{}]);
1453 float b_f32 =
1454 type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1455 .template get_as<BDataType>()[number<k>{}]);
1456
1457 c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1458 });
1459#else
1460 ck_tile::ignore = c_vec;
1461 ck_tile::ignore = a_vec;
1462 ck_tile::ignore = b_vec;
1463#endif
1464 }
1465 }
1466
1467 // c_vec = a_vec * b_vec
// Non-accumulating overload: returns a_vec * b_vec computed against a zero accumulator.
//   - gfx94/gfx95: the (ADataType, BDataType) pair over {fp8_t, bf8_t} selects one of four
//     32x32x16 builtins; operands are reinterpreted as int64_t.
//   - gfx908/gfx90a: each of the 8 operand elements is converted to float and accumulated
//     with one 32x32x2 f32 MFMA per element, starting from a zeroed c_vec.
//   - otherwise: the operands are ignored and an all-zero CVecType is returned.
// NOTE(review): on gfx94/gfx95 the if-constexpr chain has no trailing else, so a type pair
// outside {fp8_t, bf8_t} would reach the end of the function without a return.
1468 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1469 {
1470#if defined(__gfx94__) or defined(__gfx95__)
1471 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1472 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1473 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1474 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1475 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1476 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1477 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1478 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1479 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1480 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1481 return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1482 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1483#elif defined(__gfx908__) || defined(__gfx90a__)
1484 CVecType c_vec{0.f};
1485 static_for<0, 8, 1>{}([&](auto k) {
1486 float a_f32 =
1487 type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1488 .template get_as<ADataType>()[number<k>{}]);
1489 float b_f32 =
1490 type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1491 .template get_as<BDataType>()[number<k>{}]);
1492
1493 c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1494 });
1495 return c_vec;
1496#else
1497 ck_tile::ignore = a_vec;
1498 ck_tile::ignore = b_vec;
1499 return CVecType{0.f};
1500#endif
1501 }
1502};
1503
1504template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1507template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1510template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1513template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1516
1517template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1520
1521template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1524
1525template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1528
1529template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1531{
1532 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1533 using ADataType = AType_;
1534 using BDataType = BType_;
1535 using CDataType = float;
1536
1540
1541 static constexpr index_t kM = 16;
1542 static constexpr index_t kN = 16;
1543 static constexpr index_t kK = 128;
1544
1545 static constexpr index_t kAMBlock = 1;
1546 static constexpr index_t kBNBlock = 1;
1547
1548 static constexpr index_t kAMLane = 16;
1549 static constexpr index_t kBNLane = 16;
1550 static constexpr index_t kABKLane = 4;
1551 static constexpr index_t kABKPerLane = 32;
1552
1553 static constexpr index_t kCMLane = 4;
1554 static constexpr index_t kCNLane = 16;
1555 static constexpr index_t kCM0PerLane = 1;
1556 static constexpr index_t kCM1PerLane = 4;
1557
1558 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place via the 16x16x128 f8f6f4
// builtin, which is only used when __gfx950__ is defined; otherwise all arguments are
// ignored and c_vec is left untouched.
// NOTE(review): the listing skips original line 1560, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// Following the argument order documented in the comment below, the 4th argument (cbsz)
// is 1 exactly when ADataType is bf8_t and the 5th (blgp) is 1 exactly when BDataType is
// bf8_t; both opsel and both scale arguments are always 0.
// post_nop_ is accepted but not referenced in the body.
1559 template <bool post_nop_ = false>
1561 const AVecType& a_vec,
1562 const BVecType& b_vec,
1563 bool_constant<post_nop_> = {}) const
1564 {
1565 //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1566 // opsel, scale_b)
1567#if defined(__gfx950__)
1568 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1569 c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1570 a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1571 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1572 c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1573 a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1574 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1575 c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1576 a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1577 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1578 c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1579 a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1580#else
1581 ck_tile::ignore = c_vec;
1582 ck_tile::ignore = a_vec;
1583 ck_tile::ignore = b_vec;
1584#endif
1585 }
1586
1587 // c_vec = a_vec * b_vec
// Non-accumulating overload: returns a_vec * b_vec computed against a zero accumulator
// using the 16x16x128 f8f6f4 builtin when __gfx950__ is defined; otherwise the operands
// are ignored and an all-zero CVecType is returned.
// The 4th builtin argument is 1 exactly when ADataType is bf8_t and the 5th is 1 exactly
// when BDataType is bf8_t; the remaining four trailing arguments are always 0.
// NOTE(review): the if-constexpr chain has no trailing else, so a type pair outside
// {fp8_t, bf8_t} would reach the end of the function without a return on gfx950.
1588 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1589 {
1590#if defined(__gfx950__)
1591 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1592 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1593 a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1594 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1595 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1596 a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1597 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1598 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1599 a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1600 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1601 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1602 a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1603#else
1604 ck_tile::ignore = a_vec;
1605 ck_tile::ignore = b_vec;
1606 return CVecType{0.f};
1607#endif
1608 }
1609};
1610
1611template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1614
1615template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1618
1619template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1622
1623template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1626
1627template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1629{
1630 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1633 using CDataType = float;
1634
1638
1639 static constexpr index_t kM = 16;
1640 static constexpr index_t kN = 16;
1641 static constexpr index_t kK = 128;
1642
1643 static constexpr index_t kAMBlock = 1;
1644 static constexpr index_t kBNBlock = 1;
1645
1646 static constexpr index_t kAMLane = 16;
1647 static constexpr index_t kBNLane = 16;
1648 static constexpr index_t kABKLane = 4;
1649 static constexpr index_t kABKPerLane = 32;
1650
1651 static constexpr index_t kCMLane = 4;
1652 static constexpr index_t kCNLane = 16;
1653 static constexpr index_t kCM0PerLane = 1;
1654 static constexpr index_t kCM1PerLane = 4;
1655
1656 // c_vec += a_vec * b_vec
// Scaled accumulating overload: folds a_vec * b_vec into c_vec in place, forwarding the
// caller-supplied a_scale / b_scale and the compile-time opselA / opselB selectors to the
// 16x16x128 f8f6f4 scale builtin. Only active when __gfx950__ is defined; otherwise every
// argument is ignored and c_vec is left untouched.
// NOTE(review): the listing skips original line 1658, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// Each operand is reinterpreted as four 32-bit words and widened to the eight-word form
// passed to the builtin by zero-filling the upper four words.
// The 4th and 5th builtin arguments (cbsz / blgp per the comment below) are fixed at 4 --
// presumably the 4-bit operand format selector; confirm against the ISA documentation.
// post_nop_ is accepted but not referenced in the body.
1657 template <index_t opselA, index_t opselB, bool post_nop_ = false>
1659 const AVecType& a_vec,
1660 const int32_t& a_scale,
1661 const BVecType& b_vec,
1662 const int32_t& b_scale,
1663 bool_constant<post_nop_> = {}) const
1664 {
1665 //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1666 // opsel, scale_b)
1667#if defined(__gfx950__)
1668 auto arg_a = bit_cast<int32x4_t>(a_vec);
1669 auto arg_b = bit_cast<int32x4_t>(b_vec);
1670 c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1671 int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1672 int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1673 c_vec,
1674 4,
1675 4,
1676 opselA,
1677 a_scale,
1678 opselB,
1679 b_scale);
1680#else
1681 ck_tile::ignore = c_vec;
1682 ck_tile::ignore = a_vec;
1683 ck_tile::ignore = b_vec;
1684 ck_tile::ignore = a_scale;
1685 ck_tile::ignore = b_scale;
1686#endif
1687 }
1688
1689 // c_vec = a_vec * b_vec
// Scaled non-accumulating overload: returns a_vec * b_vec computed against a zero
// accumulator, forwarding a_scale / b_scale and the compile-time opselA / opselB selectors
// to the 16x16x128 f8f6f4 scale builtin. Only active when __gfx950__ is defined; otherwise
// every argument is ignored and an all-zero CVecType is returned.
// NOTE(review): the listing skips original line 1691, which holds the start of the
// declarator (a_vec is used below but introduced on that missing line) -- confirm against
// the real header.
// Each operand is reinterpreted as four 32-bit words and widened to eight words by
// zero-filling the upper half; the 4th and 5th builtin arguments are fixed at 4.
1690 template <index_t opselA, index_t opselB>
1692 const int32_t& a_scale,
1693 const BVecType& b_vec,
1694 const int32_t& b_scale) const
1695 {
1696#if defined(__gfx950__)
1697 auto arg_a = bit_cast<int32x4_t>(a_vec);
1698 auto arg_b = bit_cast<int32x4_t>(b_vec);
1699 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1700 int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1701 int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1702 CVecType{0.f},
1703 4,
1704 4,
1705 opselA,
1706 a_scale,
1707 opselB,
1708 b_scale));
1709#else
1710 ck_tile::ignore = a_vec;
1711 ck_tile::ignore = b_vec;
1712 ck_tile::ignore = a_scale;
1713 ck_tile::ignore = b_scale;
1714 return CVecType{0.f};
1715#endif
1716 }
1717};
1718
1719template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1721{
1722 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1723 using ADataType = AType_;
1724 using BDataType = BType_;
1725 using CDataType = float;
1726
1730
1731 static constexpr index_t kM = 32;
1732 static constexpr index_t kN = 32;
1733 static constexpr index_t kK = 64;
1734
1735 static constexpr index_t kAMBlock = 1;
1736 static constexpr index_t kBNBlock = 1;
1737
1738 static constexpr index_t kAMLane = 32;
1739 static constexpr index_t kBNLane = 32;
1740 static constexpr index_t kABKLane = 2;
1741 static constexpr index_t kABKPerLane = 32;
1742
1743 static constexpr index_t kCMLane = 2;
1744 static constexpr index_t kCNLane = 32;
1745 static constexpr index_t kCM0PerLane = 4;
1746 static constexpr index_t kCM1PerLane = 4;
1747
1748 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place via the 32x32x64 f8f6f4
// builtin, which is only used when __gfx950__ is defined; otherwise all arguments are
// ignored and c_vec is left untouched.
// NOTE(review): the listing skips original line 1750, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// Following the argument order documented in the comment below, the 4th argument (cbsz)
// is 1 exactly when ADataType is bf8_t and the 5th (blgp) is 1 exactly when BDataType is
// bf8_t; both opsel and both scale arguments are always 0.
// post_nop_ is accepted but not referenced in the body.
1749 template <bool post_nop_ = false>
1751 const AVecType& a_vec,
1752 const BVecType& b_vec,
1753 bool_constant<post_nop_> = {}) const
1754 {
1755 //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1756 // opsel, scale_b)
1757#if defined(__gfx950__)
1758 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1759 c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1760 a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1761 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1762 c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1763 a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1764 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1765 c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1766 a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1767 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1768 c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1769 a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1770#else
1771 ck_tile::ignore = c_vec;
1772 ck_tile::ignore = a_vec;
1773 ck_tile::ignore = b_vec;
1774#endif
1775 }
1776
1777 // c_vec = a_vec * b_vec
// Non-accumulating overload: returns a_vec * b_vec computed against a zero accumulator
// using the 32x32x64 f8f6f4 builtin when __gfx950__ is defined; otherwise the operands
// are ignored and an all-zero CVecType is returned.
// The 4th builtin argument is 1 exactly when ADataType is bf8_t and the 5th is 1 exactly
// when BDataType is bf8_t; the remaining four trailing arguments are always 0.
// NOTE(review): the if-constexpr chain has no trailing else, so a type pair outside
// {fp8_t, bf8_t} would reach the end of the function without a return on gfx950.
1778 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1779 {
1780#if defined(__gfx950__)
1781 if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1782 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1783 a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1784 else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1785 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1786 a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1787 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1788 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1789 a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1790 else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1791 return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1792 a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1793#else
1794 ck_tile::ignore = a_vec;
1795 ck_tile::ignore = b_vec;
1796 return CVecType{0.f};
1797#endif
1798 }
1799};
1800
1801template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1804
1805template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1808
1809template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1812
1813template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1816
1817// int8
1818template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1820{
1821 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1825
1829
1830 static constexpr index_t kM = 32;
1831 static constexpr index_t kN = 32;
1832 static constexpr index_t kK = 16;
1833
1834 static constexpr index_t kAMBlock = 1;
1835 static constexpr index_t kBNBlock = 1;
1836
1837 static constexpr index_t kAMLane = 32;
1838 static constexpr index_t kBNLane = 32;
1839 static constexpr index_t kABKLane = 2;
1840 static constexpr index_t kABKPerLane = 8;
1841
1842 static constexpr index_t kCMLane = 2;
1843 static constexpr index_t kCNLane = 32;
1844 static constexpr index_t kCM0PerLane = 4;
1845 static constexpr index_t kCM1PerLane = 4;
1846
1847 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place.
// NOTE(review): the listing skips original line 1849, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// DISPATCH_MFMA_CTRL_ expands to an if-constexpr chain over the Raw_* values of Ctrl that
// emits the named instruction as inline asm; the else block written below is the tail of
// that chain and is taken for every other Ctrl value.
1848 template <bool post_nop_ = false>
1850 const AVecType& a_vec,
1851 const BVecType& b_vec,
1852 bool_constant<post_nop_> = {}) const
1853 {
1854 DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
1855 else
1856 {
// gfx94/gfx95: native int8 builtin on operands reinterpreted as int64_t.
1857#if defined(__gfx94__) or defined(__gfx95__)
1858 c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
1859 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1860#elif defined(__gfx908__) || defined(__gfx90a__)
// gfx908/gfx90a fallback: each of the 8 operand elements is converted to float and fed to
// one 32x32x2 f32 MFMA per element.
// NOTE(review): c_vec is handed to an f32 builtin here, while the i32 mnemonic above and
// the `CVecType c_vec{0}` initialiser in the two-argument overload suggest an integer
// accumulator -- confirm the accumulator type and whether this path is intended.
1861 static_for<0, 8, 1>{}([&](auto k) {
1862 float a_f32 =
1863 type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1864 .template get_as<ADataType>()[number<k>{}]);
1865 float b_f32 =
1866 type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1867 .template get_as<BDataType>()[number<k>{}]);
1868
1869 c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1870 });
1871#else
1872 ck_tile::ignore = c_vec;
1873 ck_tile::ignore = a_vec;
1874 ck_tile::ignore = b_vec;
1875#endif
1876 }
1877 }
1878
1879 // c_vec = a_vec * b_vec
// Non-accumulating overload: zero-initialises a local accumulator, delegates to the
// three-argument accumulating overload (post_nop_ left at its default), and returns the
// accumulator by value.
1880 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1881 {
1882 CVecType c_vec{0};
1883 operator()(c_vec, a_vec, b_vec);
1884 return c_vec;
1885 }
1886};
1887
1888template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1890{
1891 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1895
1899
1900 static constexpr index_t kM = 16;
1901 static constexpr index_t kN = 16;
1902 static constexpr index_t kK = 32;
1903
1904 static constexpr index_t kAMBlock = 1;
1905 static constexpr index_t kBNBlock = 1;
1906
1907 static constexpr index_t kAMLane = 16;
1908 static constexpr index_t kBNLane = 16;
1909 static constexpr index_t kABKLane = 4;
1910 static constexpr index_t kABKPerLane = 8;
1911
1912 static constexpr index_t kCMLane = 4;
1913 static constexpr index_t kCNLane = 16;
1914 static constexpr index_t kCM0PerLane = 1;
1915 static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1916
1917 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place.
// NOTE(review): the listing skips original line 1919, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// DISPATCH_MFMA_CTRL_ expands to an if-constexpr chain over the Raw_* values of Ctrl that
// emits the named instruction as inline asm; the else block written below is the tail of
// that chain and is taken for every other Ctrl value.
// In that else block gfx94/gfx95 call the int8 builtin on operands reinterpreted as
// int64_t; on any other target the arguments are ignored and c_vec is left untouched.
1918 template <bool post_nop_ = false>
1920 const AVecType& a_vec,
1921 const BVecType& b_vec,
1922 bool_constant<post_nop_> = {}) const
1923 {
1924 DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
1925 else
1926 {
1927#if defined(__gfx94__) or defined(__gfx95__)
1928 c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
1929 bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1930#else
1931 ck_tile::ignore = c_vec;
1932 ck_tile::ignore = a_vec;
1933 ck_tile::ignore = b_vec;
1934#endif
1935 }
1936 }
1937
1938 // c_vec = a_vec * b_vec
// Non-accumulating overload: zero-initialises a local accumulator, delegates to the
// three-argument accumulating overload (post_nop_ left at its default), and returns the
// accumulator by value.
1939 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1940 {
1941 CVecType c_vec{0};
1942 operator()(c_vec, a_vec, b_vec);
1943 return c_vec;
1944 }
1945};
1946
1947template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1949{
1950 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1954
1958
1959 static constexpr index_t kM = 16;
1960 static constexpr index_t kN = 16;
1961 static constexpr index_t kK = 64;
1962
1963 static constexpr index_t kAMBlock = 1;
1964 static constexpr index_t kBNBlock = 1;
1965
1966 static constexpr index_t kAMLane = 16;
1967 static constexpr index_t kBNLane = 16;
1968 static constexpr index_t kABKLane = 4;
1969 static constexpr index_t kABKPerLane = 16;
1970
1971 static constexpr index_t kCMLane = 4;
1972 static constexpr index_t kCNLane = 16;
1973 static constexpr index_t kCM0PerLane = 1;
1974 static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1975
1976 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place.
// NOTE(review): the listing skips original line 1978, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// DISPATCH_MFMA_CTRL_ expands to an if-constexpr chain over the Raw_* values of Ctrl that
// emits the named instruction as inline asm; the else block written below is the tail of
// that chain and is taken for every other Ctrl value.
// In that else block gfx95 calls the int8 builtin; on any other target the arguments are
// ignored and c_vec is left untouched.
1977 template <bool post_nop_ = false>
1979 const AVecType& a_vec,
1980 const BVecType& b_vec,
1981 bool_constant<post_nop_> = {}) const
1982 {
1983 DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
1984 else
1985 {
1986#if defined(__gfx95__)
// kABKPerLane is 16 one-byte elements, i.e. a 128-bit operand per lane, so the operands
// are reinterpreted as int32x4_t (an 8-byte int64_t cannot hold them).
// Assumes AVecType / BVecType are exactly 16 bytes wide; their declarations are not
// visible in this listing -- confirm.
1987 c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
1988 bit_cast<int32x4_t>(a_vec), bit_cast<int32x4_t>(b_vec), c_vec, 0, 0, 0);
1989#else
1990 ck_tile::ignore = c_vec;
1991 ck_tile::ignore = a_vec;
1992 ck_tile::ignore = b_vec;
1993#endif
1994 }
1995 }
1996
1997 // c_vec = a_vec * b_vec
// Non-accumulating overload: zero-initialises a local accumulator, delegates to the
// three-argument accumulating overload (post_nop_ left at its default), and returns the
// accumulator by value.
1998 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1999 {
2000 CVecType c_vec{0};
2001 operator()(c_vec, a_vec, b_vec);
2002 return c_vec;
2003 }
2004};
2005
2006template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
2008{
2009 static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
2013
2017
2018 static constexpr index_t kM = 32;
2019 static constexpr index_t kN = 32;
2020 static constexpr index_t kK = 32;
2021
2022 static constexpr index_t kAMBlock = 1;
2023 static constexpr index_t kBNBlock = 1;
2024
2025 static constexpr index_t kAMLane = 32;
2026 static constexpr index_t kBNLane = 32;
2027 static constexpr index_t kABKLane = 2;
2028 static constexpr index_t kABKPerLane = 16;
2029
2030 static constexpr index_t kCMLane = 2;
2031 static constexpr index_t kCNLane = 32;
2032 static constexpr index_t kCM0PerLane = 4;
2033 static constexpr index_t kCM1PerLane = 4;
2034
2035 // c_vec += a_vec * b_vec
// Accumulating overload: folds a_vec * b_vec into c_vec in place.
// NOTE(review): the listing skips original line 2037, which holds the start of the
// declarator (where c_vec is introduced) -- confirm against the real header.
// DISPATCH_MFMA_CTRL_ expands to an if-constexpr chain over the Raw_* values of Ctrl that
// emits the named instruction as inline asm; the else block written below is the tail of
// that chain and is taken for every other Ctrl value.
// In that else block gfx95 calls the int8 builtin; on any other target the arguments are
// ignored and c_vec is left untouched.
2036 template <bool post_nop_ = false>
2038 const AVecType& a_vec,
2039 const BVecType& b_vec,
2040 bool_constant<post_nop_> = {}) const
2041 {
2042 DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
2043 else
2044 {
2045#if defined(__gfx95__)
// Both operands are treated identically: kABKPerLane is 16 one-byte elements, i.e. a
// 128-bit operand per lane, so each is reinterpreted as int32x4_t (an 8-byte int64_t
// cannot hold them).
// Assumes AVecType / BVecType are exactly 16 bytes wide; their declarations are not
// visible in this listing -- confirm.
2046 c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
2047 bit_cast<int32x4_t>(a_vec), bit_cast<int32x4_t>(b_vec), c_vec, 0, 0, 0);
2048#else
2049 ck_tile::ignore = c_vec;
2050 ck_tile::ignore = a_vec;
2051 ck_tile::ignore = b_vec;
2052#endif
2053 }
2054 }
2055
2056 // c_vec = a_vec * b_vec
// Non-accumulating overload: zero-initialises a local accumulator, delegates to the
// three-argument accumulating overload (post_nop_ left at its default), and returns the
// accumulator by value.
2057 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
2058 {
2059 CVecType c_vec{0};
2060 operator()(c_vec, a_vec, b_vec);
2061 return c_vec;
2062 }
2063};
2064
2065#undef DISPATCH_MFMA_
2066
2067} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition warp_gemm_attribute_mfma_impl.hpp:15
@ Raw_vav
Definition warp_gemm_attribute_mfma_impl.hpp:19
@ Raw_avv
Definition warp_gemm_attribute_mfma_impl.hpp:21
@ Raw_vva
Definition warp_gemm_attribute_mfma_impl.hpp:20
@ Default_
Definition warp_gemm_attribute_mfma_impl.hpp:16
@ Raw_vvv
Definition warp_gemm_attribute_mfma_impl.hpp:17
@ Raw_vaa
Definition warp_gemm_attribute_mfma_impl.hpp:18
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base< bf8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1814
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base< fp8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1514
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< fp8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1511
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< bf8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1526
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< fp8_t, fp8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8
Definition warp_gemm_attribute_mfma_impl.hpp:1505
int8_t int8_t
Definition int8.hpp:20
int32_t int32x8_t
Definition vector_type.hpp:156
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
tuple_array< T, N > thread_buffer
Definition thread_buffer.hpp:14
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base< fp8_t, fp8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8
Definition warp_gemm_attribute_mfma_impl.hpp:1802
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base< fp8_t, fp8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8
Definition warp_gemm_attribute_mfma_impl.hpp:1508
pk_float4_e2m1_t pk_fp4_t
Definition pk_fp4.hpp:151
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base< fp8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1806
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base< fp8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1616
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base< bf8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1624
int32_t int32_t
Definition integer.hpp:10
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base< bf8_t, fp8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8
Definition warp_gemm_attribute_mfma_impl.hpp:1522
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
float fp32x4_t
Definition vector_type.hpp:128
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base< fp8_t, fp8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8
Definition warp_gemm_attribute_mfma_impl.hpp:1612
int32_t index_t
Definition integer.hpp:9
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base< bf8_t, fp8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8
Definition warp_gemm_attribute_mfma_impl.hpp:1620
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base< bf8_t, fp8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8
Definition warp_gemm_attribute_mfma_impl.hpp:1810
float fp32x16_t
Definition vector_type.hpp:130
WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base< bf8_t, bf8_t, Ctrl_ > WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8
Definition warp_gemm_attribute_mfma_impl.hpp:1518
Definition warp_gemm_attribute_mfma_impl.hpp:1531
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1541
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1555
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1546
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1532
ext_vector_t< CDataType, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1539
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1550
ext_vector_t< ADataType, 32 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1537
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1553
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1535
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1548
BType_ BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1534
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1556
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1543
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1551
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1545
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1554
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1549
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1542
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1560
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1588
ext_vector_t< BDataType, 32 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1538
AType_ ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1533
Definition warp_gemm_attribute_mfma_impl.hpp:1164
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1178
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1179
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1188
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1174
AType_ ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1166
ext_vector_t< CDataType, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1172
ext_vector_t< ADataType, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1170
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1165
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1186
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1183
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1193
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1189
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1168
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1298
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1182
BType_ BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1167
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1187
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1184
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1176
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1181
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1175
ext_vector_t< BDataType, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1171
Definition warp_gemm_attribute_mfma_impl.hpp:1323
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1345
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1324
ext_vector_t< ADataType, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1329
BType_ BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1326
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1348
ext_vector_t< CDataType, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1331
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1343
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1327
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1337
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1346
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1352
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1468
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1340
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1333
AType_ ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1325
ext_vector_t< BDataType, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1330
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1335
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1334
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1347
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1338
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1342
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1341
Definition warp_gemm_attribute_mfma_impl.hpp:1721
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1778
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1732
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1743
BType_ BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1724
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1741
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1731
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1739
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1738
ext_vector_t< CDataType, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1729
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1725
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1750
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1745
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1736
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1735
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1744
ext_vector_t< BDataType, 32 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1728
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1722
AType_ ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1723
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1740
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1746
ext_vector_t< ADataType, 32 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1727
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1733
Definition warp_gemm_attribute_mfma_impl.hpp:1890
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1939
ext_vector_t< CDataType, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1898
ext_vector_t< BDataType, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1897
ext_vector_t< ADataType, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1896
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1907
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1919
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1891
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1901
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1905
int8_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1892
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1908
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1910
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1909
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1904
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1912
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1914
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1900
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1915
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1913
int32_t CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1894
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1902
int8_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1893
Definition warp_gemm_attribute_mfma_impl.hpp:1949
int32_t CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1953
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1974
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1961
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1960
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1967
ext_vector_t< ADataType, 16 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1955
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1969
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1973
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1959
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1966
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1978
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1972
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1964
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1968
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1950
int8_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1952
ext_vector_t< BDataType, 16 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1956
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1998
int8_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1951
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1971
ext_vector_t< CDataType, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1957
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1963
Definition warp_gemm_attribute_mfma_impl.hpp:1820
int8_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1822
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1849
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1842
ext_vector_t< ADataType, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1826
ext_vector_t< BDataType, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1827
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1844
int32_t CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1824
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1821
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1880
int8_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1823
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1837
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1831
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1835
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1832
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1845
ext_vector_t< CDataType, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1828
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1843
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1840
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1834
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1838
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1839
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1830
Definition warp_gemm_attribute_mfma_impl.hpp:2008
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:2018
int8_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:2010
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:2032
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:2027
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:2037
int8_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:2011
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:2031
int32_t CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:2012
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:2026
ext_vector_t< BDataType, 16 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:2015
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:2023
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:2025
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:2030
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:2020
ext_vector_t< ADataType, 16 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:2014
ext_vector_t< CDataType, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:2016
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:2057
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:2019
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:2033
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:2028
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:2009
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:2022
Definition warp_gemm_attribute_mfma_impl.hpp:666
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:674
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:684
ext_vector_t< bf16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:673
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:695
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:688
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:725
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:678
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:683
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:667
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:689
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:690
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:685
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:680
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:670
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:691
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:677
bf16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:669
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:681
bf16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:668
ext_vector_t< bf16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:672
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:686
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:676
Definition warp_gemm_attribute_mfma_impl.hpp:196
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:219
bf16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:199
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:204
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:207
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:210
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:220
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:215
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:206
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:208
ext_vector_t< bf16_t, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:202
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:197
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:216
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:200
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:214
ext_vector_t< bf16_t, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:203
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:218
bf16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:198
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:225
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:221
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:244
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:211
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:213
Definition warp_gemm_attribute_mfma_impl.hpp:1049
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1072
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1078
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1053
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1050
bf16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1052
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1063
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1064
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1067
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1074
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1121
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1068
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1073
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1066
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1059
ext_vector_t< bf16_t, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1055
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1060
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1071
ext_vector_t< bf16_t, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1056
bf16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1051
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1069
ext_vector_t< float, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1057
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1061
Definition warp_gemm_attribute_mfma_impl.hpp:577
ext_vector_t< bf16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:584
bf16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:580
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:581
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:592
ext_vector_t< float, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:585
ext_vector_t< bf16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:583
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:637
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:588
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:601
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:587
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:600
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:591
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:596
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:595
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:594
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:589
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:599
bf16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:579
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:578
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:606
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:602
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:597
Definition warp_gemm_attribute_mfma_impl.hpp:754
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:775
bf16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:756
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:778
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:777
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:779
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:815
bf16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:757
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:766
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:755
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:773
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:784
ext_vector_t< bf16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:760
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:765
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:764
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:774
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:769
ext_vector_t< bf16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:761
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:780
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:762
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:772
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:768
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:758
Definition warp_gemm_attribute_mfma_impl.hpp:844
bf16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:847
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:867
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:862
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:905
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:845
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:868
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:855
bf16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:846
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:858
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:869
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:859
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:863
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:856
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:848
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:864
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:870
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:865
ext_vector_t< bf16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:850
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:874
ext_vector_t< bf16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:851
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:854
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:852
Definition warp_gemm_attribute_mfma_impl.hpp:322
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:323
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:339
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:326
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:345
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:336
fp16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:325
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:344
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:347
ext_vector_t< fp16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:328
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:337
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:333
fp16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:324
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:341
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:340
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:330
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:346
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:342
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:370
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:351
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:332
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:334
ext_vector_t< fp16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:329
Definition warp_gemm_attribute_mfma_impl.hpp:385
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:396
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:408
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:410
fp16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:387
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:407
ext_vector_t< fp16_t, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:391
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:404
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:405
ext_vector_t< fp16_t, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:392
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:395
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:393
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:433
fp16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:388
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:389
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:402
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:409
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:400
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:399
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:386
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:403
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:397
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:414
Definition warp_gemm_attribute_mfma_impl.hpp:935
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:936
fp16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:937
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:949
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:957
ext_vector_t< fp16_t, 8 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:941
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:958
ext_vector_t< fp16_t, 8 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:942
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:945
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:939
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:964
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:946
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:955
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:1007
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:954
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:953
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:960
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:950
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:947
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:952
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:959
ext_vector_t< float, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:943
fp16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:938
Definition warp_gemm_attribute_mfma_impl.hpp:259
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:282
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:276
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:283
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:281
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:288
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:284
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:263
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:271
fp16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:262
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:277
ext_vector_t< float, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:267
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:274
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:260
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:269
ext_vector_t< fp16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:266
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:270
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:307
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:273
ext_vector_t< fp16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:265
fp16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:261
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:278
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:279
Definition warp_gemm_attribute_mfma_impl.hpp:448
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:459
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:471
fp16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:450
ext_vector_t< fp16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:454
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:497
ext_vector_t< fp16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:455
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:463
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:452
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:478
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:456
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:466
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:467
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:472
fp16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:451
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:460
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:468
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:458
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:473
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:449
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:469
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:474
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:462
Definition warp_gemm_attribute_mfma_impl.hpp:512
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:533
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:516
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:522
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:530
fp16_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:515
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:513
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:527
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:542
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:524
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:526
ext_vector_t< float, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:520
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:523
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:561
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:535
ext_vector_t< fp16_t, 4 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:518
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:531
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:536
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:537
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:532
fp16_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:514
ext_vector_t< fp16_t, 4 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:519
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:538
Definition warp_gemm_attribute_mfma_impl.hpp:67
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:68
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:87
ext_vector_t< ADataType, 1 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:74
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:80
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:78
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:97
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:90
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:93
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:85
float ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:70
ext_vector_t< CDataType, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:76
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:83
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:82
float BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:71
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:92
ext_vector_t< BDataType, 1 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:75
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:88
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:72
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:116
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:91
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:79
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:86
Definition warp_gemm_attribute_mfma_impl.hpp:131
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:156
float BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:135
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:142
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:157
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:155
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:132
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:161
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:136
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:149
ext_vector_t< CDataType, 16 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:140
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:154
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:150
ext_vector_t< ADataType, 1 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:138
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma_impl.hpp:180
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:152
ext_vector_t< BDataType, 1 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:139
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:143
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:146
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:151
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:147
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:144
float ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:134
Definition warp_gemm_attribute_mfma_impl.hpp:1629
static constexpr index_t kABKLane
Definition warp_gemm_attribute_mfma_impl.hpp:1648
pk_fp4_t BDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1632
static constexpr index_t kK
Definition warp_gemm_attribute_mfma_impl.hpp:1641
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale) const
Definition warp_gemm_attribute_mfma_impl.hpp:1691
ext_vector_t< BDataType, 16 > BVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1636
float CDataType
Definition warp_gemm_attribute_mfma_impl.hpp:1633
ext_vector_t< CDataType, 4 > CVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1637
static constexpr index_t kN
Definition warp_gemm_attribute_mfma_impl.hpp:1640
static constexpr index_t kABKPerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1649
static constexpr index_t kM
Definition warp_gemm_attribute_mfma_impl.hpp:1639
static constexpr index_t kCNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1652
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1651
static constexpr index_t kBNLane
Definition warp_gemm_attribute_mfma_impl.hpp:1647
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma_impl.hpp:1658
ext_vector_t< ADataType, 16 > AVecType
Definition warp_gemm_attribute_mfma_impl.hpp:1635
static constexpr index_t kCM0PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1653
static constexpr index_t kBNBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1644
static constexpr index_t kAMLane
Definition warp_gemm_attribute_mfma_impl.hpp:1646
static constexpr WGAttrCtlEnum Ctrl
Definition warp_gemm_attribute_mfma_impl.hpp:1630
static constexpr index_t kAMBlock
Definition warp_gemm_attribute_mfma_impl.hpp:1643
static constexpr index_t kCM1PerLane
Definition warp_gemm_attribute_mfma_impl.hpp:1654
pk_fp4_t ADataType
Definition warp_gemm_attribute_mfma_impl.hpp:1631
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
Definition warp_gemm_attribute_mfma_impl.hpp:25
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
Definition warp_gemm_attribute_mfma_impl.hpp:42