topk_softmax_warp_per_row_problem.hpp Source File

topk_softmax_warp_per_row_problem.hpp Source File

Composable Kernel: topk_softmax_warp_per_row_problem.hpp Source File
topk_softmax_warp_per_row_problem.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7#include <string>
8#include <type_traits>
9
10namespace ck_tile {
11
12template <typename InputType_,
13 typename WeightType_,
14 typename IndexType_,
15 index_t Experts_,
16 bool ActivationIsSoftmax_ = true, // false: sigmoid
17 index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
18 index_t BytesPerIssue_ = sizeof(InputType_),
19 index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
20 index_t BlockSize_ = 256>
22{
23 // TODO: this kernel only support warp per row
27
28 static constexpr index_t LaunchType = LaunchType_;
29 static constexpr index_t Experts = Experts_;
30 static constexpr index_t BytesPerIssue = BytesPerIssue_;
31 static constexpr index_t IssuesPerCol = IssuesPerCol_;
32 static constexpr index_t BlockSize = BlockSize_;
33 static constexpr index_t WarpSize = get_warp_size();
34
35 static constexpr bool ActivationIsSoftmax = ActivationIsSoftmax_;
36
37 static_assert(BytesPerIssue % sizeof(InputType) == 0);
38 static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
39 static_assert(Experts % VectorSize == 0);
41 static_assert(WarpSize % LanesPerRow == 0);
45
48};
49} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
Definition topk_softmax_warp_per_row_problem.hpp:22
remove_cvref_t< InputType_ > InputType
Definition topk_softmax_warp_per_row_problem.hpp:24
static constexpr index_t RowsPerWarp
Definition topk_softmax_warp_per_row_problem.hpp:43
static constexpr bool ActivationIsSoftmax
Definition topk_softmax_warp_per_row_problem.hpp:35
static constexpr index_t VectorSize
Definition topk_softmax_warp_per_row_problem.hpp:38
static constexpr index_t LanesPerRow
Definition topk_softmax_warp_per_row_problem.hpp:40
remove_cvref_t< IndexType_ > IndexType
Definition topk_softmax_warp_per_row_problem.hpp:26
static constexpr index_t RowsPerWarpPerColIssue
Definition topk_softmax_warp_per_row_problem.hpp:42
static constexpr index_t IssuesPerRow
Definition topk_softmax_warp_per_row_problem.hpp:44
static constexpr index_t BlockSize
Definition topk_softmax_warp_per_row_problem.hpp:32
static constexpr index_t WarpSize
Definition topk_softmax_warp_per_row_problem.hpp:33
static constexpr index_t RowsPerBlock
Definition topk_softmax_warp_per_row_problem.hpp:47
static constexpr index_t WarpsPerBlock
Definition topk_softmax_warp_per_row_problem.hpp:46
static constexpr index_t LaunchType
Definition topk_softmax_warp_per_row_problem.hpp:28
static constexpr index_t BytesPerIssue
Definition topk_softmax_warp_per_row_problem.hpp:30
static constexpr index_t Experts
Definition topk_softmax_warp_per_row_problem.hpp:29
static constexpr index_t IssuesPerCol
Definition topk_softmax_warp_per_row_problem.hpp:31
remove_cvref_t< WeightType_ > WeightType
Definition topk_softmax_warp_per_row_problem.hpp:25