inner_product_dpp8.hpp Source File

inner_product_dpp8.hpp Source File

Composable Kernel: inner_product_dpp8.hpp Source File
inner_product_dpp8.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "amd_gemm_dpp.hpp"
7#include "data_type.hpp"
8#include "type_convert.hpp"
9
10namespace ck {
11
12namespace dpp8 {
13
16
17template <int SrcLaneIdx>
18__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
19
20// clang-format off
21template <>
22__device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){
23 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
24}
25template <>
26__device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){
27 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
28}
29template <>
30__device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){
31 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
32}
33template <>
34__device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){
35 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
36}
37template <>
38__device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){
39 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
40}
41template <>
42__device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){
43 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
44}
45template <>
46__device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){
47 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
48}
49template <>
50__device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){
51 asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
52}
53// clang-format on
54
58template <int SrcLaneIdx, bool ShareA>
59__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c)
60{
61 static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
62 "DPP8 src broadcast lane out of range <0, 7>.");
63 if constexpr(ShareA)
64 {
66 }
67 else
68 {
70 }
71}
72
77constexpr std::array<int, dpp8::lane_group_size> IntrinsicMaskDpp8 = {
78 0, // 0, 0, 0, 0, 0, 0, 0, 0
79 2396745, // 1, 1, 1, 1, 1, 1, 1, 1
80 4793490, // 2, 2, 2, 2, 2, 2, 2, 2
81 7190235, // 3, 3, 3, 3, 3, 3, 3, 3
82 9586980, // 4, 4, 4, 4, 4, 4, 4, 4
83 11983725, // 5, 5, 5, 5, 5, 5, 5, 5
84 14380470, // 6, 6, 6, 6, 6, 6, 6, 6
85 16777215, // 7, 7, 7, 7, 7, 7, 7, 7
86};
87
91template <int SrcLaneIdx>
93{
94 static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
95 "DPP8 src broadcast lane out of range <0, 7>.");
96 return IntrinsicMaskDpp8[SrcLaneIdx];
97}
98
99template <int SrcLaneIdx>
100__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c)
101{
102 constexpr int sel_mask = get_dpp_sel_mask_broadcast<SrcLaneIdx>();
103 const half2_t val_from_other_lane =
104 bit_cast<half2_t>(__builtin_amdgcn_mov_dpp8(bit_cast<int>(a), sel_mask));
105 c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false);
106}
107
111template <int SrcLaneIdx, bool ShareA>
112__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c)
113{
114 if constexpr(ShareA)
115 {
117 }
118 else
119 {
121 }
122}
123
134template <typename TA, typename TB, typename TC, int SrcLaneIdx, bool ShareA>
135__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c)
136{
137#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM
139#else
141#endif
142}
143
144} // namespace dpp8
145
146} // namespace ck
Definition amd_gemm_dpp.hpp:12
__device__ void inline_v_dot2c_dpp8_instr< 5 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:42
__device__ void inline_v_dot2c_dpp8_instr< 4 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:38
__device__ void inner_product_dpp(const TA &a, const TB &b, TC &c)
Definition inner_product_dpp8.hpp:135
__device__ void inline_v_dot2c_dpp8(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:59
__device__ void inline_v_dot2c_dpp8_instr(const half2_t &a, const half2_t &b, float &c)
__device__ void inline_v_dot2c_dpp8_instr< 3 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:34
__device__ void inline_v_dot2c_dpp8_instr< 2 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:30
constexpr std::array< int, dpp8::lane_group_size > IntrinsicMaskDpp8
Definition inner_product_dpp8.hpp:77
__device__ void intrinsic_fdot2(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:112
__device__ void intrinsic_fdot2_impl(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:100
__device__ void inline_v_dot2c_dpp8_instr< 7 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:50
__device__ void inline_v_dot2c_dpp8_instr< 6 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:46
constexpr int get_dpp_sel_mask_broadcast()
Definition inner_product_dpp8.hpp:92
__device__ void inline_v_dot2c_dpp8_instr< 0 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:22
constexpr index_t lane_group_size
Number of lanes that can share data using DPP8 modifiers.
Definition inner_product_dpp8.hpp:15
__device__ void inline_v_dot2c_dpp8_instr< 1 >(const half2_t &a, const half2_t &b, float &c)
Definition inner_product_dpp8.hpp:26
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
typename vector_type< half_t, 2 >::type half2_t
Definition dtype_vector.hpp:2153
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517