mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File

mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File

Composable Kernel: mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp Source File
mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
11{
12 static constexpr auto I0 = number<0>{};
13 static constexpr auto I1 = number<1>{};
14 static constexpr auto I2 = number<2>{};
15
16 static constexpr index_t KBPerLoad = 32;
17
18 static constexpr int MXdlPack = 2;
19 static constexpr int NXdlPack = 2;
20 static constexpr int KXdlPack = 2;
21
22 template <typename Problem>
24 {
25 using namespace ck_tile;
26
29 constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
30 constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
31
32 static_assert(MPerXdl == 16 && NPerXdl == 16);
33 static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
34
35 /* Uses fewer transform layers compared with the old CK implementation. */
36 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
37 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
38 constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
39 constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
40
41 constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
45 number<1>{});
46
47 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
48 a_lds_block_desc_0,
50 make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
54
55 constexpr auto a_lds_block_desc = transform_tensor_descriptor(
56 a_lds_block_desc_permuted,
62
63 // return a_lds_block_desc_permuted;
64 return a_lds_block_desc;
65 }
66
67 template <typename Problem>
69 {
71
72 constexpr index_t BlockSize = Problem::kBlockSize;
73
74 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
75 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
76
77 constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
78 constexpr index_t K0 = KPerBlock / K1;
79 constexpr index_t M2 = get_warp_size() / K0;
80
81 constexpr index_t M1 = BlockSize / get_warp_size();
82 static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
83 static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
84 constexpr index_t M0 = MPerBlock / (M2 * M1);
85 static_assert(M0 * M1 * M2 == MPerBlock,
86 "Incorrect M0, M2, M1 configuration! "
87 "M0, M1, M2 must cover whole MPerBlock!");
88
96 }
97
98 template <typename Problem>
100 {
101 using TileShape = typename Problem::BlockGemmShape;
102
103 static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
104 static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
105
106 constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
107 constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
108 constexpr int M_Lane = TileShape::WarpTile::at(I0);
109
110 constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
111
112 constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
113
121 sequence<1>>{});
122 }
123
124 template <typename Problem>
126 {
127 using TileShape = typename Problem::BlockGemmShape;
128
129 static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
130
131 constexpr index_t BlockSize = Problem::kBlockSize;
132 constexpr index_t WaveSize = get_warp_size();
133 constexpr index_t WaveNum = BlockSize / WaveSize;
134
135 constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
136 constexpr index_t KWavePerBlk = 1;
137
138 constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
139
140 constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
141
147 // wave in blk, // thd in wave
148 // <M, K> // <M, K>
149 tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
150 tuple<sequence<0, 0, 0>, sequence<1>>, // which index
151 // <repeat, vec_load>
153 sequence<2>>{});
154 }
155
156 template <typename Problem>
158 {
159 using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
160
161 constexpr index_t BlockSize = Problem::kBlockSize;
162 constexpr index_t WaveSize = get_warp_size();
163 constexpr index_t WaveNum = BlockSize / WaveSize;
164
165 constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
166
167 constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
168 constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
169
170 static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
171
172 constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
173 constexpr index_t K_Lanes = 64 / M_Lanes;
174
175 // Y dimension (M) decomposition
176 constexpr index_t Y2 = M_Lanes;
177 constexpr index_t Y1 = M_Warps;
178 constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
179
180 // X dimension (K) decomposition
181 constexpr index_t X0 = K_Lanes;
182 constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
183
190 sequence<0, 1>>{});
191 }
192
193 template <typename Problem>
195 {
196 using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
197
198 constexpr index_t BlockSize = Problem::kBlockSize;
199 constexpr index_t WaveSize = get_warp_size();
200 constexpr index_t WaveNum = BlockSize / WaveSize;
201
202 constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1);
203
204 constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
205 constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
206
207 static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
208
209 constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
210 constexpr index_t K_Lanes = 64 / N_Lanes;
211
212 // Y dimension (N) decomposition
213 constexpr index_t Y2 = N_Lanes;
214 constexpr index_t Y1 = N_Warps;
215 constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
216
217 // X dimension (K) decomposition
218 constexpr index_t X0 = K_Lanes;
219 constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
220
227 sequence<0, 1>>{});
228 }
229
230 template <typename Problem>
232 {
233 using TileShape = typename Problem::BlockGemmShape;
234
235 constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
236 constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
237 constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
238 constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{});
239 constexpr index_t MWavePerBlk = M_Warp;
240
243 tuple<sequence<MWavePerBlk, M_Lane>, // second direction
244 sequence<K_Lane, 1>>, // first direction
245 tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
246 tuple<sequence<0, 0>, sequence<0, 1>>, // which index
247 // <repeat, vec_load>
249 sequence<1>>{});
250 }
251
252 template <typename Problem>
254 {
255 using TileShape = typename Problem::BlockGemmShape;
256
257 constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
258 constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
259 constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
260 constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{});
261 constexpr index_t NWavePerBlk = N_Warp;
262
265 tuple<sequence<NWavePerBlk, N_Lane>, // second direction
266 sequence<K_Lane, 1>>, // first direction
267 tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
268 tuple<sequence<0, 0>, sequence<0, 1>>, // which index
269 // <repeat, vec_load>
271 sequence<1>>{});
272 }
273};
274
275} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:11
static CK_TILE_HOST_DEVICE constexpr auto MakeMXFP4_ADramTileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:68
static constexpr auto I2
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:231
static constexpr auto I0
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:12
static CK_TILE_HOST_DEVICE constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:253
static constexpr index_t KBPerLoad
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:16
static CK_TILE_HOST_DEVICE constexpr auto MakeMXFP4_BFlatDramTileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:125
static CK_TILE_HOST_DEVICE constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:194
static constexpr int NXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:19
static constexpr int KXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:20
static CK_TILE_HOST_DEVICE constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:157
static CK_TILE_HOST_DEVICE constexpr auto MakeMXFP4_ALdsBlockDescriptor()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:23
static constexpr auto I1
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:13
static CK_TILE_HOST_DEVICE constexpr auto MakeMXF4_ALDS_TileDistribution()
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:99
static constexpr int MXdlPack
Definition mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:18
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetSmemPackA()
Definition flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp:233
static constexpr int PackedSize
Definition tile/core/numeric/numeric.hpp:82
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192