fused_moegemm_tile_partitioner.hpp Source File

fused_moegemm_tile_partitioner.hpp Source File

Composable Kernel: fused_moegemm_tile_partitioner.hpp Source File
fused_moegemm_tile_partitioner.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6namespace ck_tile {
7
8template <typename BlockShape_>
10{
11 // FusedMoeGemmShape
// The BlockShape type is expected to supply the tile extents Block_M0 / Block_N0,
// which GridSize below divides the problem size by.
13
// Short identifier string for this tile partitioner.
// NOTE(review): "lin" presumably abbreviates "linear" (block index maps directly to
// tile index); confirm against wherever `name` is consumed.
14 static constexpr const char* name = "lin";
15
16 CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
17 ck_tile::index_t /*intermediate_size*/)
18 {
19 index_t i_n = blockIdx.x;
20 index_t i_m = blockIdx.y;
21
22 return ck_tile::make_tuple(i_m, i_n);
23 }
24
25 CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
26 {
27 // TODO: this may need tuning
28 index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
29 index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
30 return dim3(ns, ms, 1);
31 }
32};
33} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition fused_moegemm_tile_partitioner.hpp:10
CK_TILE_DEVICE auto operator()(ck_tile::index_t, ck_tile::index_t)
Definition fused_moegemm_tile_partitioner.hpp:16
ck_tile::remove_cvref_t< BlockShape_ > BlockShape
Definition fused_moegemm_tile_partitioner.hpp:12
static constexpr const char * name
Definition fused_moegemm_tile_partitioner.hpp:14
static CK_TILE_HOST constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
Definition fused_moegemm_tile_partitioner.hpp:25