tile_fmha_shape.hpp Source File#
tile_fmha_shape.hpp
Go to the documentation of this file.
111 static constexpr index_t kK4 = BlockTile::at(number<6>{}); // tile size along gemm4(dS@K) unroll
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x)
Definition tile/core/numeric/math.hpp:462
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_HOST_DEVICE constexpr index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition tile/core/container/sequence.hpp:982
Definition tile_fmha_shape.hpp:82
static constexpr index_t kQKHeaddim
Definition tile_fmha_shape.hpp:112
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition tile_fmha_shape.hpp:84
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition tile_fmha_shape.hpp:87
static constexpr index_t kMaxSeqLenQ
Definition tile_fmha_shape.hpp:118
remove_cvref_t< Gemm4WarpTile_ > Gemm4WarpTile
Definition tile_fmha_shape.hpp:93
remove_cvref_t< Gemm4BlockWarps_ > Gemm4BlockWarps
Definition tile_fmha_shape.hpp:92
static constexpr index_t kVHeaddim
Definition tile_fmha_shape.hpp:115
remove_cvref_t< Gemm2BlockWarps_ > Gemm2BlockWarps
Definition tile_fmha_shape.hpp:88
remove_cvref_t< Gemm3BlockWarps_ > Gemm3BlockWarps
Definition tile_fmha_shape.hpp:90
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition tile_fmha_shape.hpp:85
remove_cvref_t< Gemm2WarpTile_ > Gemm2WarpTile
Definition tile_fmha_shape.hpp:89
remove_cvref_t< BlockTile_ > BlockTile
Definition tile_fmha_shape.hpp:83
static constexpr index_t NumWarps
Definition tile_fmha_shape.hpp:95
remove_cvref_t< Gemm3WarpTile_ > Gemm3WarpTile
Definition tile_fmha_shape.hpp:91
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition tile_fmha_shape.hpp:86
Definition tile_fmha_shape.hpp:35
static constexpr bool IsVLayoutRowMajor
Definition tile_fmha_shape.hpp:63
static constexpr index_t kQKHeaddim
Definition tile_fmha_shape.hpp:55
remove_cvref_t< Gemm1BlockWarps_ > Gemm1BlockWarps
Definition tile_fmha_shape.hpp:39
static constexpr index_t NumGemm0Warps
Definition tile_fmha_shape.hpp:42
remove_cvref_t< Gemm1WarpTile_ > Gemm1WarpTile
Definition tile_fmha_shape.hpp:40
remove_cvref_t< Gemm0WarpTile_ > Gemm0WarpTile
Definition tile_fmha_shape.hpp:38
static constexpr index_t kSubQKHeaddim
Definition tile_fmha_shape.hpp:60
remove_cvref_t< Gemm0BlockWarps_ > Gemm0BlockWarps
Definition tile_fmha_shape.hpp:37
std::conditional_t< IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition tile_fmha_shape.hpp:64
remove_cvref_t< BlockTile_ > BlockTile
Definition tile_fmha_shape.hpp:36
static constexpr index_t NumGemm1Warps
Definition tile_fmha_shape.hpp:44
Definition tile/core/numeric/math.hpp:98
Definition tile/ops/common/tensor_layout.hpp:22
Definition tile/ops/common/tensor_layout.hpp:17