amd_gemm_dpp.hpp Source File

amd_gemm_dpp.hpp Source File

Composable Kernel: amd_gemm_dpp.hpp Source File
amd_gemm_dpp.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/math.hpp"
9
10namespace ck {
11
12namespace dpp8 {
13
// Compile-time description of the element types and per-instruction K width used by the
// DPP GEMM code in this file.
// NOTE(review): listing lines 15 and 18 (the declaration lines that follow each template
// header) are absent from this extract. The cross-reference index names the template
// `dpp_datatypes`; the block at 17-26 is presumably its specialization for `half_t`
// inputs -- confirm against the original header.
14template <class ABDataType>
16
17template <>
19{
20 // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
21 // single instruction.
22 using a_dtype = half_t; // A-side element type (read below as `datatypes_conf::a_dtype`)
23 using b_dtype = half_t; // B-side element type (read below as `datatypes_conf::b_dtype`)
24 using c_dtype = float; // accumulator element type (read below as `datatypes_conf::c_dtype`)
25 static constexpr index_t k_per_instr = 2; // K elements reduced per instruction; divides KPerThread in Run
26};
27
// Per-thread GEMM step. For every accumulator element held by the thread, `Run` walks the K
// dimension in chunks of `datatypes_conf::k_per_instr` and folds the matching A and B chunks
// into that accumulator.
//
// Template parameters (as used by the visible code):
//   MPerThread / NPerThread - accumulator element count; `ShareA` selects which one is used.
//   KPerThread              - element count of the local A and B vectors.
//   BaseInputType           - selects the datatype configuration (`datatypes_conf`).
//   AVecDataType / BVecDataType / CVecDataType - types of the `Run` arguments; the A/B types
//                             are also the chunk types read out of the local vectors.
//   ShareA                  - picks MPerThread (true) or NPerThread (false) as the
//                             accumulator count.
//
// NOTE(review): listing lines 36 and 38 are absent from this extract. Line 36 is presumably
// the struct declaration; the cross-reference index gives line 38 as
// `dpp_datatypes< BaseInputType > datatypes_conf` -- confirm against the original header.
28template <index_t MPerThread,
29 index_t NPerThread,
30 index_t KPerThread,
31 class BaseInputType,
32 class AVecDataType,
33 class BVecDataType,
34 class CVecDataType,
35 bool ShareA>
37{
39 using ADataType = typename datatypes_conf::a_dtype;
40 using BDataType = typename datatypes_conf::b_dtype;
41 using CDataType = typename datatypes_conf::c_dtype;
42
 // Accumulates the contribution of `a_vec` and `b_vec` into `c_vec` in place.
 // `a_vec` / `b_vec` are read-only inputs; `c_vec` is read and then overwritten
 // element by element.
43 __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
44 {
 // Number of accumulator elements this thread updates.
45 constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
46
 // Wrap the inputs so they can be re-viewed as chunks via AsType<>() below.
47 const vector_type<ADataType, KPerThread> a_vector{a_vec};
48 const vector_type<BDataType, KPerThread> b_vector{b_vec};
49
 // NOTE(review): listing line 50 is absent from this extract; it presumably opens the
 // compile-time loop over `c_idx` that is closed at listing line 62 (likely bounded by
 // `num_c_elems_per_thread`) -- confirm against the original header.
 // NOTE(review): the accumulator is declared `float` although it is loaded and stored
 // through `CDataType`; the two agree only while `c_dtype` is `float`.
51 float c = c_vec.template AsType<CDataType>()(c_idx);
52 // Next `c_idx` implies that we need to pull data from the next lane.
53 constexpr index_t source_lane = c_idx; // not used on any visible line; presumably consumed by the missing call at 57-58
 // One iteration per group of `k_per_instr` K elements.
54 static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
55 const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
56 const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
 // NOTE(review): listing lines 57-58 (the callee name and its template arguments) are
 // absent from this extract; presumably the `inner_product_dpp` listed in the
 // cross-reference index, updating `c` in place -- confirm.
59 a_k_vec, b_k_vec, c);
60 });
 // Write the updated accumulator back to the same element it was read from.
61 c_vec.template AsType<CDataType>()(c_idx) = c;
62 });
63 }
64};
65
66} // namespace dpp8
67
68} // namespace ck
Definition amd_gemm_dpp.hpp:12
__device__ void inner_product_dpp(const TA &a, const TB &b, TC &c)
Definition inner_product_dpp8.hpp:135
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
_Float16 half_t
Definition data_type.hpp:31
Definition amd_gemm_dpp.hpp:37
typename datatypes_conf::c_dtype CDataType
Definition amd_gemm_dpp.hpp:41
typename datatypes_conf::a_dtype ADataType
Definition amd_gemm_dpp.hpp:39
dpp_datatypes< BaseInputType > datatypes_conf
Definition amd_gemm_dpp.hpp:38
__device__ void Run(const AVecDataType &a_vec, const BVecDataType &b_vec, CVecDataType &c_vec)
Definition amd_gemm_dpp.hpp:43
typename datatypes_conf::b_dtype BDataType
Definition amd_gemm_dpp.hpp:40
half_t a_dtype
Definition amd_gemm_dpp.hpp:22
half_t b_dtype
Definition amd_gemm_dpp.hpp:23
float c_dtype
Definition amd_gemm_dpp.hpp:24
static constexpr index_t k_per_instr
Definition amd_gemm_dpp.hpp:25
Definition amd_gemm_dpp.hpp:15
Definition functional2.hpp:33
Definition dtype_vector.hpp:10