#include <fmha_bwd_kernel.hpp>
|
| static CK_TILE_HOST std::string | GetName () |
| template<bool Cond = !kIsGroupMode> |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargs (const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, ck_tile::index_t batch_stride_d) |
| template<bool Cond = kIsGroupMode> |
| static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > | MakeKargs (const void *o_ptr, const void *do_ptr, void *d_ptr, float p_undrop, const void *seqstart_q_ptr, const void *seqlen_q_ptr, const void *cu_seqlen_q_ptr, ck_tile::index_t hdim_v, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_d) |
| static CK_TILE_HOST constexpr auto | GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) |
| static CK_TILE_DEVICE constexpr auto | GetTileIndex () |
| static CK_TILE_HOST dim3 | BlockSize () |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
◆ DDataType
template<typename FmhaBwdOGradDotO_>
◆ FmhaBwdOGradDotO
template<typename FmhaBwdOGradDotO_>
◆ Kargs
template<typename FmhaBwdOGradDotO_>
Initial value: std::conditional_t<kIsGroupMode, FmhaBwdOGradDotOGroupModeKargs, FmhaBwdOGradDotOBatchModeKargs>
◆ ODataType
template<typename FmhaBwdOGradDotO_>
◆ OGradDataType
template<typename FmhaBwdOGradDotO_>
◆ BlockSize()
template<typename FmhaBwdOGradDotO_>
◆ GetName()
template<typename FmhaBwdOGradDotO_>
◆ GetSmemSize()
template<typename FmhaBwdOGradDotO_>
◆ GetTileIndex()
template<typename FmhaBwdOGradDotO_>
◆ GridSize()
template<typename FmhaBwdOGradDotO_>
◆ MakeKargs() [1/2]
template<typename FmhaBwdOGradDotO_>
template<bool Cond = !kIsGroupMode>
| CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::MakeKargs |
( |
const void * | o_ptr, |
|
|
const void * | do_ptr, |
|
|
void * | d_ptr, |
|
|
float | p_undrop, |
|
|
ck_tile::index_t | seqlen_q, |
|
|
ck_tile::index_t | hdim_v, |
|
|
ck_tile::index_t | stride_do, |
|
|
ck_tile::index_t | stride_o, |
|
|
ck_tile::index_t | nhead_stride_do, |
|
|
ck_tile::index_t | nhead_stride_o, |
|
|
ck_tile::index_t | nhead_stride_d, |
|
|
ck_tile::index_t | batch_stride_do, |
|
|
ck_tile::index_t | batch_stride_o, |
|
|
ck_tile::index_t | batch_stride_d ) |
|
inline static constexpr |
◆ MakeKargs() [2/2]
template<typename FmhaBwdOGradDotO_>
template<bool Cond = kIsGroupMode>
| CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdOGradDotOKernel< FmhaBwdOGradDotO_ >::MakeKargs |
( |
const void * | o_ptr, |
|
|
const void * | do_ptr, |
|
|
void * | d_ptr, |
|
|
float | p_undrop, |
|
|
const void * | seqstart_q_ptr, |
|
|
const void * | seqlen_q_ptr, |
|
|
const void * | cu_seqlen_q_ptr, |
|
|
ck_tile::index_t | hdim_v, |
|
|
ck_tile::index_t | stride_do, |
|
|
ck_tile::index_t | stride_o, |
|
|
ck_tile::index_t | nhead_stride_do, |
|
|
ck_tile::index_t | nhead_stride_o, |
|
|
ck_tile::index_t | nhead_stride_d ) |
|
inline static constexpr |
◆ operator()()
template<typename FmhaBwdOGradDotO_>
◆ kBlockPerCu
template<typename FmhaBwdOGradDotO_>
◆ kBlockSize
template<typename FmhaBwdOGradDotO_>
◆ kIsGroupMode
template<typename FmhaBwdOGradDotO_>
◆ kM0
template<typename FmhaBwdOGradDotO_>
◆ kPadHeadDimV
template<typename FmhaBwdOGradDotO_>
◆ kPadSeqLenQ
template<typename FmhaBwdOGradDotO_>
◆ kVHeaddim
template<typename FmhaBwdOGradDotO_>
The documentation for this struct was generated from the following file: