grouped_flatmm_kernel.hpp Source File

grouped_flatmm_kernel.hpp Source File#

Composable Kernel: grouped_flatmm_kernel.hpp Source File
grouped_flatmm_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <string>
8
9#include "ck_tile/core.hpp"
12
13namespace ck_tile {
14
// Host-side argument bundle for the grouped flat-MM kernel; the symbol index at the end of this
// page names this struct GroupedFlatmmHostArgs. Per-group quantities (M, N, K, the A/B/C
// pointers, their strides and the scales) are stored as pointers that the kernel indexes with
// the group index.
// NOTE(review): this is a rendered listing; the `struct` line, the constructor name line (with
// the leading `index_t group_count_` parameter), the defaulted constructor and several member
// declarations (group_count, M, N, K, stride_A, stride_B, stride_C, k_batch) are elided here and
// only visible in the symbol index.
15template <class ScaleM = FlatmmScalePointer<-1>,
16 class ScaleN = FlatmmScalePointer<-1>,
17 index_t NumDTensor = 0>
19{
// Constructor parameter list (continues from an elided line). The scale pointers are optional
// and default to nullptr.
22 index_t* M_,
23 index_t* N_,
24 index_t* K_,
25 const void** a_ptr_,
26 index_t* stride_A_,
27 const void** b_shuffle_ptr_,
28 index_t* stride_B_,
29 const std::array<const void*, NumDTensor>& ds_ptr_,
30 const std::array<index_t, NumDTensor>& stride_Ds_,
31 void** c_ptr_,
32 index_t* stride_C_,
33 index_t k_batch_,
34 ScaleM* scale_m_ = nullptr,
35 ScaleN* scale_n_ = nullptr)
// Member-wise copy of the arguments; pointers are stored as given (no deep copy), while the
// two std::array arguments are copied by value into the members below.
36 : group_count(group_count_),
37 M(M_),
38 N(N_),
39 K(K_),
40 a_ptr(a_ptr_),
41 stride_A(stride_A_),
42 b_shuffle_ptr(b_shuffle_ptr_),
43 stride_B(stride_B_),
44 ds_ptr(ds_ptr_),
45 stride_Ds(stride_Ds_),
46 c_ptr(c_ptr_),
47 stride_C(stride_C_),
48 k_batch(k_batch_),
49 scale_m(scale_m_),
50 scale_n(scale_n_)
51 {
52 }
53
// One A pointer per group.
58 const void** a_ptr;
// One B pointer per group (presumably pre-shuffled, going by the name — confirm).
60 const void** b_shuffle_ptr;
// D-tensor pointers and strides are NOT per group: a single array is shared by all groups.
// NOTE(review): the const data members make this struct non-assignable.
62 const std::array<const void*, NumDTensor> ds_ptr;
63 const std::array<index_t, NumDTensor> stride_Ds;
// e_ptr and c_ptr are two names for the same storage (output pointer array).
64 union
65 {
66 void** e_ptr;
67 void** c_ptr;
68 };
// Optional per-group scale objects; nullptr when not supplied.
71 ScaleM* scale_m = nullptr;
72 ScaleN* scale_n = nullptr;
73};
74
// Host-side argument bundle for the "contiguous" grouped variant; the symbol index at the end
// of this page names this struct ContiguousGroupedFlatmmHostArgs. A, C and the problem sizes
// are single values/pointers; the group of each row tile is looked up through M_indices by the
// kernel.
// NOTE(review): this is a rendered listing; the `struct` line, the constructor name line (with
// the leading `index_t* M_indices_` parameter), the defaulted constructor and the scalar member
// declarations (group_count, M_indices, M, N, K, stride_A, stride_B, stride_C, k_batch) are
// elided here and only visible in the symbol index.
75template <class ScaleM = FlatmmScalePointer<-1>,
76 class ScaleN = FlatmmScalePointer<-1>,
77 index_t NumDTensor = 0>
79{
// Constructor parameter list (continues from an elided line). Scales are taken by value here
// and default to nullptr.
82 index_t M_,
83 index_t N_,
84 index_t K_,
85 const void* a_ptr_,
86 index_t stride_A_,
87 const void* b_shuffle_ptr_,
88 index_t stride_B_,
89 const std::array<const void*, NumDTensor>& ds_ptr_,
90 const std::array<index_t, NumDTensor>& stride_Ds_,
91 void* c_ptr_,
92 index_t stride_C_,
93 index_t k_batch_,
94 ScaleM scale_m_ = nullptr,
95 ScaleN scale_n_ = nullptr)
// group_count is fixed to 1 for this variant; everything else is copied from the arguments.
96 : group_count(1),
97 M_indices(M_indices_),
98 M(M_),
99 N(N_),
100 K(K_),
101 a_ptr(a_ptr_),
102 stride_A(stride_A_),
103 b_shuffle_ptr(b_shuffle_ptr_),
104 stride_B(stride_B_),
105 ds_ptr(ds_ptr_),
106 stride_Ds(stride_Ds_),
107 c_ptr(c_ptr_),
108 stride_C(stride_C_),
109 k_batch(k_batch_),
110 scale_m(scale_m_),
111 scale_n(scale_n_)
112 {
113 }
// Single A pointer shared by all groups.
119 const void* a_ptr;
// Base B pointer; the kernel offsets it by group_idx * N * K elements.
121 const void* b_shuffle_ptr;
// D-tensor pointers and strides, stored by value.
// NOTE(review): the const data members make this struct non-assignable.
123 const std::array<const void*, NumDTensor> ds_ptr;
124 const std::array<index_t, NumDTensor> stride_Ds;
// e_ptr and c_ptr are two names for the same storage (output pointer).
125 union
126 {
127 void* e_ptr;
128 void* c_ptr;
129 };
// Scale objects held by value; initialised from nullptr when not supplied.
132 ScaleM scale_m = nullptr;
133 ScaleN scale_n = nullptr;
134};
135
// Host-side argument bundle for the "masked" grouped variant; the symbol index at the end of
// this page names this struct MaskedGroupedFlatmmHostArgs. All groups share one A/B/C base
// pointer and one (Max_M, N, K) shape; the kernel reads the number of valid rows of each group
// from M_indices[group].
// NOTE(review): this is a rendered listing; the `struct` line, the constructor name line (with
// the leading `index_t* M_indices_` parameter), the defaulted constructor and the scalar member
// declarations (M_indices, group_count, M, N, K, stride_A, stride_B, stride_C, k_batch) are
// elided here and only visible in the symbol index.
136template <class ScaleM = FlatmmScalePointer<-1>,
137 class ScaleN = FlatmmScalePointer<-1>,
138 index_t NumDTensor = 0>
140{
// Constructor parameter list (continues from an elided line). Scales are taken by value and
// default to nullptr.
143 index_t group_count_,
144 index_t Max_M_,
145 index_t N_,
146 index_t K_,
147 const void* a_ptr_,
148 index_t stride_A_,
149 const void* b_shuffle_ptr_,
150 index_t stride_B_,
151 const std::array<const void*, NumDTensor>& ds_ptr_,
152 const std::array<index_t, NumDTensor>& stride_Ds_,
153 void* c_ptr_,
154 index_t stride_C_,
155 index_t k_batch_,
156 ScaleM scale_m_ = nullptr,
157 ScaleN scale_n_ = nullptr)
// Note that the member M is initialised from Max_M_, i.e. it holds the per-group row capacity,
// not the number of valid rows.
158 : M_indices(M_indices_),
159 group_count(group_count_),
160 M(Max_M_),
161 N(N_),
162 K(K_),
163 a_ptr(a_ptr_),
164 stride_A(stride_A_),
165 b_shuffle_ptr(b_shuffle_ptr_),
166 stride_B(stride_B_),
167 ds_ptr(ds_ptr_),
168 stride_Ds(stride_Ds_),
169 c_ptr(c_ptr_),
170 stride_C(stride_C_),
171 k_batch(k_batch_),
172 scale_m(scale_m_),
173 scale_n(scale_n_)
174 {
175 }
176
// Base A pointer; the kernel offsets it by group_idx * M * K elements.
182 const void* a_ptr;
// Base B pointer; the kernel offsets it by group_idx * N * K elements.
184 const void* b_shuffle_ptr;
// D-tensor pointers and strides, stored by value.
// NOTE(review): the const data members make this struct non-assignable.
186 const std::array<const void*, NumDTensor> ds_ptr;
187 const std::array<index_t, NumDTensor> stride_Ds;
// e_ptr and c_ptr are two names for the same storage (output pointer).
188 union
189 {
190 void* e_ptr;
191 void* c_ptr;
192 };
// Scale objects held by value; initialised from nullptr when not supplied.
195 ScaleM scale_m = nullptr;
196 ScaleN scale_n = nullptr;
197};
198
199template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
200struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
201{
204
207
209
212 // The type below is actually the accumulation data type - the output of the block GEMM.
216
217 static constexpr index_t NumDTensor = DsDataType::size();
218 static constexpr index_t kBlockSize = FlatmmPipeline_::BlockSize;
219
220 static constexpr auto I0 = number<0>();
221 static constexpr auto I1 = number<1>();
222 static constexpr auto I2 = number<2>();
223 static constexpr auto I3 = number<3>();
224
225 static_assert(DsLayout::size() == DsDataType::size(),
226 "The size of DsLayout and DsDataType should be the same");
227
// Builds the kernel's display name by joining, with '_' as separator, the literal
// "grouped_flatmm", the A/B precision string and the pipeline's own name.
// NOTE(review): the symbol index lists gemm_prec_str as a function (`std::string
// gemm_prec_str()`), but it is passed to concat below without call parentheses — this looks
// like a missing `()`; confirm what concat receives here.
228 CK_TILE_HOST static const std::string GetName()
229 {
230 return concat(
231 '_', "grouped_flatmm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
232 }
233
// Grid-size query for the GroupedFlatmmHostArgs overload (signature per the symbol index; the
// name/parameter lines and part of the kentry<...> expression are elided in this listing).
// Returns dim3(persistent_block_size, 1, k_batch), where persistent_block_size is
// multiProcessorCount times the occupancy reported for the kernel entry point.
// NOTE(review): declared CK_TILE_HOST_DEVICE although the body calls HIP runtime queries —
// presumably only ever used on the host; confirm.
234 template <class ScaleM = FlatmmScalePointer<-1>,
235 class ScaleN = FlatmmScalePointer<-1>,
237 CK_TILE_HOST_DEVICE static auto
239 {
240 hipDeviceProp_t prop;
// Device 0 is hard-coded; the currently selected device is not queried.
241 int deviceId = 0; // default device
242
243 constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
// No dynamic shared memory is assumed for the occupancy query.
244 int dync_smem_size = 0;
// NOTE(review): left uninitialized; it is only written by the occupancy query below, and the
// status stored in `e` is never inspected.
245 int maxActiveBlocksPerCU;
246
247 [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
248
249 e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
250 &maxActiveBlocksPerCU,
251 reinterpret_cast<void*>(
253 block_size,
254 dync_smem_size);
255
// One block per available residency slot across all multiprocessors.
256 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
257
258 // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
259 // << ", persistent_block_size: " << persistent_block_size << std::endl;
260
// Only k_batch == 1 is expected; the value is still forwarded as the z dimension.
261 assert(kernelArgs.k_batch == 1);
262 return dim3(persistent_block_size, 1, kernelArgs.k_batch);
263 }
264
// Grid-size query for the ContiguousGroupedFlatmmHostArgs overload (signature per the symbol
// index; the name/parameter lines and part of the kentry<...> expression are elided in this
// listing). Same occupancy-based computation as the other overloads, but the x dimension is
// additionally clamped to the number of output tiles for (M, N).
// NOTE(review): declared CK_TILE_HOST_DEVICE although the body calls HIP runtime queries —
// presumably only ever used on the host; confirm.
265 template <class ScaleM = FlatmmScalePointer<-1>,
266 class ScaleN = FlatmmScalePointer<-1>,
268 CK_TILE_HOST_DEVICE static auto
270 kernelArgs)
271 {
272 hipDeviceProp_t prop;
// Device 0 is hard-coded; the currently selected device is not queried.
273 int deviceId = 0; // default device
274
275 constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
// No dynamic shared memory is assumed for the occupancy query.
276 int dync_smem_size = 0;
// NOTE(review): left uninitialized; it is only written by the occupancy query below, and the
// status stored in `e` is never inspected.
277 int maxActiveBlocksPerCU;
278
279 [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
280
281 e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
282 &maxActiveBlocksPerCU,
283 reinterpret_cast<void*>(
284 kentry<1,
287 block_size,
288 dync_smem_size);
289
// Upper bound from occupancy, and the actual amount of work in tiles.
290 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
291 const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);
292
293 // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
294 // << ", persistent_block_size: " << persistent_block_size
295 // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
296
// Only k_batch == 1 is expected; never launch more blocks than there are tiles.
297 assert(kernelArgs.k_batch == 1);
298 return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kernelArgs.k_batch);
299 }
300
// Grid-size query for the MaskedGroupedFlatmmHostArgs overload (the function-name line and part
// of the kentry<...> expression are elided in this listing). Returns
// dim3(persistent_block_size, 1, k_batch) based purely on occupancy; unlike the contiguous
// overload there is no clamp to the tile count (that computation is commented out below).
// NOTE(review): the body calls HIP runtime queries — presumably host-only; confirm.
301 template <class ScaleM = FlatmmScalePointer<-1>,
302 class ScaleN = FlatmmScalePointer<-1>,
305 [[maybe_unused]] const MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN, NumDTensor>& kernelArgs)
306 {
307 hipDeviceProp_t prop;
// Device 0 is hard-coded; the currently selected device is not queried.
308 int deviceId = 0; // default device
309
310 constexpr int block_size = UnderlyingGemmKernel::BlockSize().x;
// No dynamic shared memory is assumed for the occupancy query.
311 int dync_smem_size = 0;
// NOTE(review): left uninitialized; it is only written by the occupancy query below, and the
// status stored in `e` is never inspected.
312 int maxActiveBlocksPerCU;
313
314 [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
315
316 e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
317 &maxActiveBlocksPerCU,
318 reinterpret_cast<void*>(
319 kentry<1,
322 block_size,
323 dync_smem_size);
324
// One block per available residency slot across all multiprocessors.
325 const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
326 // const int total_work_tile_cnt = TilePartitioner::GridSize(kernelArgs.M, kernelArgs.N);
327
328 // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
329 // << ", persistent_block_size: " << persistent_block_size << std::endl;
330
// Only k_batch == 1 is expected; the value is still forwarded as the z dimension.
331 assert(kernelArgs.k_batch == 1);
332 return dim3(persistent_block_size, 1, kernelArgs.k_batch);
333 }
334
// Identity conversion: the host-argument struct is used directly as the kernel-argument struct,
// so this simply returns a copy of whatever HostArgs type it is given.
335 template <typename HostArgs>
336 CK_TILE_HOST static constexpr auto MakeKernelArgs(const HostArgs& hostArgs)
337 {
338 return hostArgs;
339 }
340 // CK_TILE_HOST static constexpr auto
341 // MakeKernelArgs(const ContiguousGroupedFlatmmHostArgs& hostArgs)
342 // {
343 // return hostArgs;
344 // }
345 // CK_TILE_HOST static constexpr auto
346 // MakeKernelArgs(const MaskedGroupedFlatmmHostArgs& hostArgs)
347 // {
348 // return hostArgs;
349 // }
350
// Device entry point for GroupedFlatmmHostArgs (signature per the symbol index; the signature
// lines and the `impl_kargs` declaration line are elided in this listing).
// Each block starts at its blockIdx.x and walks the groups in order. Within a group it runs the
// underlying flat-MM kernel on tile indices block_linear_idx, block_linear_idx + gridDim.x, ...
// while they are below the group's tile count, then subtracts that tile count so the index is
// relative to the next group.
351 template <class ScaleM = FlatmmScalePointer<-1>,
352 class ScaleN = FlatmmScalePointer<-1>,
355 {
356 int group_idx = 0;
// Tile index of this block, relative to the current group, and the step between its tiles.
357 int block_linear_idx = blockIdx.x;
358 int total_block_cnt = gridDim.x;
359
360 UnderlyingGemmKernel underlying_kernel{};
361 for(; group_idx < kargs.group_count; ++group_idx)
362 {
363 const index_t M = kargs.M[group_idx];
364 const index_t N = kargs.N[group_idx];
// Number of output tiles in this group.
365 const index_t group_block_cnt = TilePartitioner::GridSize(M, N);
366
367 while(block_linear_idx < group_block_cnt)
368 {
369 // Found the group this block belongs to
370 // create the kernel args for the underlying flatmm kernel
// ds_ptr / stride_Ds are passed unindexed: the same D tensors are used for every group.
// NOTE(review): scale_m / scale_n default to nullptr in the host-args constructor; indexing
// them here assumes the caller supplied per-group arrays — confirm.
372 kargs.a_ptr[group_idx],
373 kargs.b_shuffle_ptr[group_idx],
374 kargs.ds_ptr,
375 kargs.c_ptr[group_idx],
376 kargs.M[group_idx],
377 kargs.N[group_idx],
378 kargs.K[group_idx],
379 kargs.stride_A[group_idx],
380 kargs.stride_B[group_idx],
381 kargs.stride_Ds,
382 kargs.stride_C[group_idx],
383 kargs.k_batch,
384 kargs.scale_m[group_idx],
385 kargs.scale_n[group_idx]};
386 // call the underlying flatmm kernel
387 underlying_kernel(impl_kargs, block_linear_idx);
388 block_linear_idx += total_block_cnt;
389 }
// Rebase the index so it counts from the first tile of the next group.
390 block_linear_idx -= group_block_cnt;
391 }
392 }
393
// Device entry point for ContiguousGroupedFlatmmHostArgs (signature per the symbol index; the
// parameter line and the `impl_kargs` declaration line are elided in this listing).
// All tiles of the single (M, N) problem are distributed over the blocks: a block handles tile
// indices blockIdx.x, blockIdx.x + gridDim.x, ... below the total tile count. For each tile the
// group is read from M_indices and only the B pointer is offset per group.
394 template <class ScaleM = FlatmmScalePointer<-1>,
395 class ScaleN = FlatmmScalePointer<-1>,
397 CK_TILE_DEVICE void
399 {
400 int block_linear_idx = blockIdx.x;
401 int total_block_cnt = gridDim.x;
402 int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
403
404 UnderlyingGemmKernel underlying_kernel{};
405 for(; block_linear_idx < total_work_tile_cnt; block_linear_idx += total_block_cnt)
406 {
// Only the M tile index is used below; block_n_idx is not referenced in this function.
407 auto [block_m_idx, block_n_idx] =
408 TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(block_linear_idx);
409 // get the group index from the M_indices
// The lookup uses the first row of the tile (block_m_idx * kM); this presumably relies on all
// rows of a tile belonging to the same group — confirm against the caller.
410 int group_idx = kargs.M_indices[block_m_idx * BlockGemmShape::kM];
411
// A, C, sizes, strides and scales are passed through unchanged; B is advanced by
// group_idx * N * K elements of BDataType.
// NOTE(review): the offset is computed in 32-bit integer arithmetic (index_t is int32_t per the
// symbol index) and could overflow for large group_idx * N * K — confirm expected sizes.
413 kargs.a_ptr,
414 static_cast<const BDataType*>(kargs.b_shuffle_ptr) + group_idx * kargs.N * kargs.K,
415 kargs.ds_ptr,
416 kargs.c_ptr,
417 kargs.M,
418 kargs.N,
419 kargs.K,
420 kargs.stride_A,
421 kargs.stride_B,
422 kargs.stride_Ds,
423 kargs.stride_C,
424 kargs.k_batch,
425 kargs.scale_m,
426 kargs.scale_n};
427 // call the underlying flatmm kernel
428 underlying_kernel(impl_kargs, block_linear_idx);
429 }
430 }
431
// Device entry point for MaskedGroupedFlatmmHostArgs (signature per the symbol index; the
// parameter line and the `impl_kargs` declaration line are elided in this listing).
// Same group-walking scheme as the GroupedFlatmmHostArgs overload: a block processes tile
// indices block_linear_idx, block_linear_idx + gridDim.x, ... within a group and then rebases
// the index for the next group. Here every group shares one base pointer per tensor, and the
// per-group tile count is derived from the valid row count M_indices[group_idx].
432 template <class ScaleM = FlatmmScalePointer<-1>,
433 class ScaleN = FlatmmScalePointer<-1>,
435 CK_TILE_DEVICE void
437 {
438 int group_idx = 0;
// Tile index of this block, relative to the current group, and the step between its tiles.
439 int block_linear_idx = blockIdx.x;
440 int total_block_cnt = gridDim.x;
441
442 UnderlyingGemmKernel underlying_kernel{};
443 for(; group_idx < kargs.group_count; ++group_idx)
444 {
// Number of rows actually used in this group; kargs.M holds the per-group capacity (Max_M).
445 const index_t valid_M = kargs.M_indices[group_idx];
446 const index_t N = kargs.N;
447 const index_t group_block_cnt = TilePartitioner::GridSize(valid_M, N);
448
449 while(block_linear_idx < group_block_cnt)
450 {
451 // Found the group this block belongs to
452 // create the kernel args for the underlying flatmm kernel
// Per-group offsets use the capacity M (not valid_M): A by group_idx*M*K, B by group_idx*N*K,
// C by group_idx*M*N, scale_m by group_idx*M and scale_n by group_idx*N. The problem's M passed
// to the underlying kernel is valid_M.
// NOTE(review): the offsets are computed in 32-bit integer arithmetic (index_t is int32_t per
// the symbol index) and could overflow for large products — confirm expected sizes.
454 static_cast<const ADataType*>(kargs.a_ptr) + group_idx * kargs.M * kargs.K,
455 static_cast<const BDataType*>(kargs.b_shuffle_ptr) +
456 group_idx * kargs.N * kargs.K,
457 kargs.ds_ptr,
458 static_cast<CDataType*>(kargs.c_ptr) + group_idx * kargs.M * kargs.N,
459 valid_M,
460 kargs.N,
461 kargs.K,
462 kargs.stride_A,
463 kargs.stride_B,
464 kargs.stride_Ds,
465 kargs.stride_C,
466 kargs.k_batch,
467 kargs.scale_m + group_idx * kargs.M,
468 kargs.scale_n + group_idx * kargs.N};
469 // call the underlying flatmm kernel
470 underlying_kernel(impl_kargs, block_linear_idx);
471 block_linear_idx += total_block_cnt;
472 }
// Rebase the index so it counts from the first tile of the next group.
473 block_linear_idx -= group_block_cnt;
474 }
475 }
476};
477
478} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
__global__ void kentry(Args... args)
Definition tile/host/kernel_launch.hpp:22
std::string gemm_prec_str()
Definition utils.hpp:31
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
Definition grouped_flatmm_kernel.hpp:79
void * c_ptr
Definition grouped_flatmm_kernel.hpp:128
void * e_ptr
Definition grouped_flatmm_kernel.hpp:127
const std::array< index_t, NumDTensor > stride_Ds
Definition grouped_flatmm_kernel.hpp:124
const std::array< const void *, NumDTensor > ds_ptr
Definition grouped_flatmm_kernel.hpp:123
const void * b_shuffle_ptr
Definition grouped_flatmm_kernel.hpp:121
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t *M_indices_, index_t M_, index_t N_, index_t K_, const void *a_ptr_, index_t stride_A_, const void *b_shuffle_ptr_, index_t stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void *c_ptr_, index_t stride_C_, index_t k_batch_, ScaleM scale_m_=nullptr, ScaleN scale_n_=nullptr)
Definition grouped_flatmm_kernel.hpp:81
const void * a_ptr
Definition grouped_flatmm_kernel.hpp:119
index_t N
Definition grouped_flatmm_kernel.hpp:117
index_t M
Definition grouped_flatmm_kernel.hpp:116
index_t k_batch
Definition grouped_flatmm_kernel.hpp:131
ScaleM scale_m
Definition grouped_flatmm_kernel.hpp:132
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs()=default
index_t group_count
Definition grouped_flatmm_kernel.hpp:114
index_t K
Definition grouped_flatmm_kernel.hpp:118
index_t stride_B
Definition grouped_flatmm_kernel.hpp:122
index_t stride_A
Definition grouped_flatmm_kernel.hpp:120
ScaleN scale_n
Definition grouped_flatmm_kernel.hpp:133
index_t stride_C
Definition grouped_flatmm_kernel.hpp:130
index_t * M_indices
Definition grouped_flatmm_kernel.hpp:115
Definition flatmm_kernel.hpp:229
Definition flatmm_kernel.hpp:249
static CK_TILE_HOST constexpr auto BlockSize()
Definition flatmm_kernel.hpp:330
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition flatmm_kernel.hpp:252
Definition flatmm_kernel.hpp:33
Definition grouped_flatmm_kernel.hpp:19
index_t * stride_A
Definition grouped_flatmm_kernel.hpp:59
index_t * N
Definition grouped_flatmm_kernel.hpp:56
const void ** b_shuffle_ptr
Definition grouped_flatmm_kernel.hpp:60
CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_, index_t *M_, index_t *N_, index_t *K_, const void **a_ptr_, index_t *stride_A_, const void **b_shuffle_ptr_, index_t *stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void **c_ptr_, index_t *stride_C_, index_t k_batch_, ScaleM *scale_m_=nullptr, ScaleN *scale_n_=nullptr)
Definition grouped_flatmm_kernel.hpp:21
index_t * stride_B
Definition grouped_flatmm_kernel.hpp:61
const std::array< index_t, NumDTensor > stride_Ds
Definition grouped_flatmm_kernel.hpp:63
ScaleM * scale_m
Definition grouped_flatmm_kernel.hpp:71
const std::array< const void *, NumDTensor > ds_ptr
Definition grouped_flatmm_kernel.hpp:62
index_t k_batch
Definition grouped_flatmm_kernel.hpp:70
index_t * stride_C
Definition grouped_flatmm_kernel.hpp:69
CK_TILE_HOST GroupedFlatmmHostArgs()=default
const void ** a_ptr
Definition grouped_flatmm_kernel.hpp:58
ScaleN * scale_n
Definition grouped_flatmm_kernel.hpp:72
index_t group_count
Definition grouped_flatmm_kernel.hpp:54
index_t * K
Definition grouped_flatmm_kernel.hpp:57
index_t * M
Definition grouped_flatmm_kernel.hpp:55
void ** e_ptr
Definition grouped_flatmm_kernel.hpp:66
void ** c_ptr
Definition grouped_flatmm_kernel.hpp:67
Definition grouped_flatmm_kernel.hpp:201
static constexpr index_t NumDTensor
Definition grouped_flatmm_kernel.hpp:217
static CK_TILE_HOST_DEVICE auto GridSize(const ContiguousGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition grouped_flatmm_kernel.hpp:269
static constexpr index_t kBlockSize
Definition grouped_flatmm_kernel.hpp:218
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_flatmm_kernel.hpp:205
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition grouped_flatmm_kernel.hpp:206
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition grouped_flatmm_kernel.hpp:210
FlatmmKernel< TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_ > UnderlyingGemmKernel
Definition grouped_flatmm_kernel.hpp:202
static CK_TILE_HOST const std::string GetName()
Definition grouped_flatmm_kernel.hpp:228
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition grouped_flatmm_kernel.hpp:213
static CK_TILE_HOST constexpr auto MakeKernelArgs(const HostArgs &hostArgs)
Definition grouped_flatmm_kernel.hpp:336
CK_TILE_DEVICE void operator()(ContiguousGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition grouped_flatmm_kernel.hpp:398
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition grouped_flatmm_kernel.hpp:354
static constexpr auto I1
Definition grouped_flatmm_kernel.hpp:221
static CK_TILE_HOST_DEVICE auto GridSize(const MaskedGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition grouped_flatmm_kernel.hpp:304
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition grouped_flatmm_kernel.hpp:214
static constexpr auto I3
Definition grouped_flatmm_kernel.hpp:223
static CK_TILE_HOST_DEVICE auto GridSize(const GroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition grouped_flatmm_kernel.hpp:238
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition grouped_flatmm_kernel.hpp:215
static constexpr auto I0
Definition grouped_flatmm_kernel.hpp:220
CK_TILE_DEVICE void operator()(MaskedGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition grouped_flatmm_kernel.hpp:436
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_flatmm_kernel.hpp:208
static constexpr auto I2
Definition grouped_flatmm_kernel.hpp:222
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition grouped_flatmm_kernel.hpp:211
typename UnderlyingGemmKernel::BlockGemmShape BlockGemmShape
Definition grouped_flatmm_kernel.hpp:203
Definition grouped_flatmm_kernel.hpp:140
index_t group_count
Definition grouped_flatmm_kernel.hpp:178
void * e_ptr
Definition grouped_flatmm_kernel.hpp:190
ScaleM scale_m
Definition grouped_flatmm_kernel.hpp:195
ScaleN scale_n
Definition grouped_flatmm_kernel.hpp:196
CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t *M_indices_, index_t group_count_, index_t Max_M_, index_t N_, index_t K_, const void *a_ptr_, index_t stride_A_, const void *b_shuffle_ptr_, index_t stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void *c_ptr_, index_t stride_C_, index_t k_batch_, ScaleM scale_m_=nullptr, ScaleN scale_n_=nullptr)
Definition grouped_flatmm_kernel.hpp:142
CK_TILE_HOST MaskedGroupedFlatmmHostArgs()=default
index_t * M_indices
Definition grouped_flatmm_kernel.hpp:177
index_t N
Definition grouped_flatmm_kernel.hpp:180
const std::array< const void *, NumDTensor > ds_ptr
Definition grouped_flatmm_kernel.hpp:186
index_t k_batch
Definition grouped_flatmm_kernel.hpp:194
index_t K
Definition grouped_flatmm_kernel.hpp:181
const void * b_shuffle_ptr
Definition grouped_flatmm_kernel.hpp:184
index_t stride_C
Definition grouped_flatmm_kernel.hpp:193
const void * a_ptr
Definition grouped_flatmm_kernel.hpp:182
index_t stride_A
Definition grouped_flatmm_kernel.hpp:183
const std::array< index_t, NumDTensor > stride_Ds
Definition grouped_flatmm_kernel.hpp:187
index_t M
Definition grouped_flatmm_kernel.hpp:179
index_t stride_B
Definition grouped_flatmm_kernel.hpp:185
void * c_ptr
Definition grouped_flatmm_kernel.hpp:191