12#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
13 defined(__gfx1103__) || defined(__gfx11_generic__)
17#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
24template <index_t MPerWave, index_t NPerWave>
30 template <
class FloatC>
38 reg_c.template AsType<float8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
39 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
49template <index_t MPerWave, index_t NPerWave>
55 template <
class FloatC>
59 reg_c.template AsType<float8_t>()(
Number<0>{}) =
60 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
61 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
71template <index_t MPerWave, index_t NPerWave, index_t Opsel>
74template <index_t Opsel>
77 template <
class FloatC>
84 reg_c.template AsType<half16_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
85 reg_a, reg_b, reg_c.template AsType<half16_t>()[
Number<0>{}], Opsel);
95template <index_t MPerWave, index_t NPerWave, index_t Opsel>
98template <index_t Opsel>
101 template <
class FloatC>
107#if defined(__gfx11__)
108 reg_c.template AsType<bhalf16_t>()(
Number<0>{}) =
109 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
110 reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[
Number<0>{}], Opsel);
120template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
123template <
bool neg_a,
bool neg_b,
bool clamp>
126 template <
class FloatC>
129#if defined(__gfx11__)
130 reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
131 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
136 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
148template <index_t MPerWave, index_t NPerWave>
154 template <
class FloatC>
157#if defined(__gfx11__)
158 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
159 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
169template <index_t MPerWave, index_t NPerWave>
175 template <
class FloatC>
178#if defined(__gfx11__)
179 reg_c.template AsType<float4_t>()(
Number<0>{}) =
180 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
181 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}]);
191template <index_t MPerWave, index_t NPerWave, index_t Opsel>
194template <index_t Opsel>
197 template <
class FloatC>
203#if defined(__gfx11__)
204 reg_c.template AsType<half8_t>()(
Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
205 reg_a, reg_b, reg_c.template AsType<half8_t>()[
Number<0>{}], Opsel);
215template <index_t MPerWave, index_t NPerWave, index_t Opsel>
218template <index_t Opsel>
221 template <
class FloatC>
227#if defined(__gfx11__)
228 reg_c.template AsType<bhalf8_t>()(
Number<0>{}) =
229 __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
230 reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[
Number<0>{}], Opsel);
240template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
243template <
bool neg_a,
bool neg_b,
bool clamp>
246 template <
class FloatC>
249#if defined(__gfx11__)
250 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
251 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
256 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
270template <index_t MPerWave, index_t NPerWave>
276 template <
class FloatC>
283#if defined(__gfx12__)
284 reg_c.template AsType<float8_t>()(
Number<0>{}) =
285 __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
286 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
296template <index_t MPerWave, index_t NPerWave>
302 template <
class FloatC>
305#if defined(__gfx12__)
306 reg_c.template AsType<float8_t>()(
Number<0>{}) =
307 __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
308 reg_a, reg_b, reg_c.template AsType<float8_t>()[
Number<0>{}]);
318template <index_t MPerWave, index_t NPerWave,
bool neg_a,
bool neg_b,
bool clamp>
321template <
bool neg_a,
bool neg_b,
bool clamp>
324 template <
class FloatC>
327#if defined(__gfx12__)
328 reg_c.template AsType<int32x8_t>()(
Number<0>{}) =
329 __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
334 reg_c.template AsType<int32x8_t>()[
Number<0>{}],
345template <index_t MPerWave, index_t NPerWave>
351 template <
class FloatC>
352 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
354#if defined(__gfx12__)
355 reg_c.template AsType<float8_t>()(
Number<0>{}) =
356 __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(
359 reg_c.template AsType<float8_t>()[
Number<0>{}]);
369template <index_t MPerWave, index_t NPerWave>
375 template <
class FloatC>
376 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
378#if defined(__gfx12__)
379 reg_c.template AsType<float8_t>()(
Number<0>{}) =
380 __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12(
383 reg_c.template AsType<float8_t>()[
Number<0>{}]);
393template <index_t MPerWave, index_t NPerWave>
399 template <
class FloatC>
400 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
402#if defined(__gfx12__)
403 reg_c.template AsType<float8_t>()(
Number<0>{}) =
404 __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12(
407 reg_c.template AsType<float8_t>()[
Number<0>{}]);
417template <index_t MPerWave, index_t NPerWave>
423 template <
class FloatC>
424 __device__
static void Run(
const bf8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
426#if defined(__gfx12__)
427 reg_c.template AsType<float8_t>()(
Number<0>{}) =
428 __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12(
431 reg_c.template AsType<float8_t>()[
Number<0>{}]);
typename vector_type< int8_t, 8 >::type int8x8_t
Definition dtype_vector.hpp:2178
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition dtype_vector.hpp:2162
typename vector_type< half_t, 16 >::type half16_t
Definition dtype_vector.hpp:2156
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< half_t, 8 >::type half8_t
Definition dtype_vector.hpp:2155
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
typename vector_type< int8_t, 16 >::type int8x16_t
Definition dtype_vector.hpp:2179
typename vector_type< bhalf_t, 16 >::type bhalf16_t
Definition dtype_vector.hpp:2163
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:102
Definition amd_wmma.hpp:96
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:222
Definition amd_wmma.hpp:216
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:78
Definition amd_wmma.hpp:72
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:198
Definition amd_wmma.hpp:192
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:56
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:303
Definition amd_wmma.hpp:297
Definition amd_wmma.hpp:50
static __device__ void Run(const bhalf16_t &reg_a, const bhalf16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:176
Definition amd_wmma.hpp:170
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:424
Definition amd_wmma.hpp:418
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:400
Definition amd_wmma.hpp:394
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:31
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:277
Definition amd_wmma.hpp:271
Definition amd_wmma.hpp:25
static __device__ void Run(const half16_t &reg_a, const half16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:155
Definition amd_wmma.hpp:149
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:376
Definition amd_wmma.hpp:370
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:352
Definition amd_wmma.hpp:346
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:127
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:325
Definition amd_wmma.hpp:319
Definition amd_wmma.hpp:121
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition amd_wmma.hpp:247
Definition amd_wmma.hpp:241